Skip to content

Commit

Permalink
Correct linting
Browse files Browse the repository at this point in the history
Signed-off-by: thanawan-atc <[email protected]>
  • Loading branch information
thanawan-atc committed Sep 17, 2023
1 parent d2e6a1d commit 9b460fd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
5 changes: 1 addition & 4 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,8 +687,7 @@ def zip_model(

if zip_file_name is None:
zip_file_name = str(self.model_id.split("/")[-1] + ".zip")



zip_file_path = os.path.join(self.folder_path, zip_file_name)
zip_file_name_without_extension = zip_file_name.split(".")[0]

Expand All @@ -697,7 +696,6 @@ def zip_model(

tokenizer_json_path = os.path.join(self.folder_path, "tokenizer.json")
print("tokenizer_json_path: ", tokenizer_json_path)


if not os.path.exists(tokenizer_json_path):
raise Exception(
Expand All @@ -720,7 +718,6 @@ def zip_model(
zipObj.write(APACHE_LICENSE_PATH, arcname="LICENSE")

print("zip file is saved to " + zip_file_path + "\n")


def _fill_null_truncation_field(
self,
Expand Down
10 changes: 9 additions & 1 deletion tests/ml_models/test_sentencetransformermodel_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
os.path.dirname(os.path.abspath("__file__")), "tests", "sample_zip"
)


def clean_test_folder(TEST_FOLDER):
if os.path.exists(TEST_FOLDER):
for files in os.listdir(TEST_FOLDER):
Expand Down Expand Up @@ -156,6 +157,7 @@ def test_missing_files():
clean_test_folder(temp_path)
assert "Cannot find config.json" in str(exc_info.value)


def test_save_as_pt():
try:
test_model.save_as_pt(sentences=["today is sunny"])
Expand Down Expand Up @@ -294,6 +296,7 @@ def test_make_model_config_json_for_onnx():

clean_test_folder(TEST_FOLDER)


def test_overwrite_fields_in_model_config():
model_id = "sentence-transformers/all-distilroberta-v1"
expected_model_config_data = {
Expand Down Expand Up @@ -635,7 +638,11 @@ def test_save_as_pt_with_license():
model_id=model_id,
)

test_model15.save_as_pt(model_id=model_id, sentences=["today is sunny"], license_to_be_zipped="apache-2.0")
test_model15.save_as_pt(
model_id=model_id,
sentences=["today is sunny"],
license_to_be_zipped="apache-2.0",
)
with ZipFile(torch_script_zip_file_path, "r") as f:
filenames = set(f.namelist())
assert (
Expand Down Expand Up @@ -698,5 +705,6 @@ def test_zip_model_with_license():
), f"The content in the model zip file does not match the expected content: {filenames} != {expected_filenames_with_license}"
clean_test_folder(TEST_FOLDER)


clean_test_folder(TEST_FOLDER)
clean_test_folder(TESTDATA_UNZIP_FOLDER)
13 changes: 6 additions & 7 deletions utils/model_uploader/model_autotracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,23 @@ def trace_sentence_transformer_model(
model_config_path = None
try:
model_config_path = pre_trained_model.make_model_config_json(
version_number=model_version,
model_format=model_format,
embedding_dimension=embedding_dimension,
pooling_mode=pooling_mode,
description=model_description,
version_number=model_version,
model_format=model_format,
embedding_dimension=embedding_dimension,
pooling_mode=pooling_mode,
description=model_description,
)
except Exception as e:
assert (
False
), f"Raised Exception during making model config file for {model_format} model: {e}"


# 4.) Preview model config
print(f"\n+++++ {model_format} Model Config +++++\n")
with open(model_config_path, "r") as f:
model_config = json.load(f)
print(json.dumps(model_config, indent=4))
print(f"\n+++++++++++++++++++++++++++++++++++++++\n")
print("\n+++++++++++++++++++++++++++++++++++++++\n")

# 5.) Return model_path & model_config_path for model registration
return model_path, model_config_path
Expand Down

0 comments on commit 9b460fd

Please sign in to comment.