Skip to content

Commit

Permalink
Improve cov (2)
Browse files Browse the repository at this point in the history
Signed-off-by: Thanawan Atchariyachanvanit <[email protected]>
  • Loading branch information
thanawan-atc committed Jul 14, 2023
1 parent c2fd054 commit c685cd1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 30 deletions.
43 changes: 15 additions & 28 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,32 +1024,25 @@ def make_model_config_json(

# If the user provides model_type/embedding_dimension/pooling_mode, this step is skipped.
model = SentenceTransformer(self.model_id)
if model_type is None:
if len(model._modules) >= 1 and isinstance(
model._modules["0"], Transformer
):
try:
if model_type is None or embedding_dimension is None or pooling_mode is None or normalize_result is None:
try:
if model_type is None and len(model._modules) >= 1 and isinstance(
model._modules["0"], Transformer
):
model_type = model._modules["0"].auto_model.__class__.__name__
model_type = model_type.lower().rstrip("model")
except Exception as e:
raise Exception(f"Raised exception while getting model_type: {e}")

if embedding_dimension is None:
try:
embedding_dimension = model.get_sentence_embedding_dimension()
if embedding_dimension is None:
embedding_dimension = model.get_sentence_embedding_dimension()
if pooling is None and len(model._modules) >= 2 and isinstance(model._modules["1"], Pooling):
pooling_mode = model._modules["1"].get_pooling_mode_str().upper()
if normalize_result is None:
if len(model._modules) >= 3 and isinstance(model._modules["2"], Normalize):
normalize_result = True
else:
normalize_result = False
except Exception as e:
raise Exception(
f"Raised exception while calling get_sentence_embedding_dimension(): {e}"
)
raise Exception(f"Raised exception while getting model data from pre-trained hugging-face model object: {e}")

if pooling_mode is None:
if len(model._modules) >= 2 and isinstance(model._modules["1"], Pooling):
try:
pooling_mode = model._modules["1"].get_pooling_mode_str().upper()
except Exception as e:
raise Exception(
f"Raised exception while calling get_pooling_mode_str(): {e}"
)

if all_config is None:
if not os.path.exists(config_json_file_path):
Expand All @@ -1075,12 +1068,6 @@ def make_model_config_json(
"file in the path.",
)

if normalize_result is None:
if len(model._modules) >= 3 and isinstance(model._modules["2"], Normalize):
normalize_result = True
else:
normalize_result = False

model_config_content = {
"name": model_name,
"version": version_number,
Expand Down
4 changes: 2 additions & 2 deletions tests/ml_models/test_sentencetransformermodel_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_missing_files():

# test no tokenizer.json file
with pytest.raises(Exception) as exc_info:
test_model.zip_model()
test_model.zip_model(verbose=True)
assert "Cannot find tokenizer.json file" in str(exc_info.value)

# test no model file
Expand All @@ -137,7 +137,7 @@ def test_missing_files():
test_model3 = SentenceTransformerModel(folder_path=temp_path)
test_model3.save_as_pt(sentences=["today is sunny"])
os.remove(os.path.join(temp_path, "msmarco-distilbert-base-tas-b.pt"))
test_model3.zip_model()
test_model3.zip_model(verbose=True)
clean_test_folder(temp_path)
assert "Cannot find model in the model path" in str(exc_info.value)

Expand Down

0 comments on commit c685cd1

Please sign in to comment.