diff --git a/opensearch_py_ml/ml_models/sentencetransformermodel.py b/opensearch_py_ml/ml_models/sentencetransformermodel.py index 93600487..892a8d19 100644 --- a/opensearch_py_ml/ml_models/sentencetransformermodel.py +++ b/opensearch_py_ml/ml_models/sentencetransformermodel.py @@ -687,8 +687,7 @@ def zip_model( if zip_file_name is None: zip_file_name = str(self.model_id.split("/")[-1] + ".zip") - - + zip_file_path = os.path.join(self.folder_path, zip_file_name) zip_file_name_without_extension = zip_file_name.split(".")[0] @@ -697,7 +696,6 @@ def zip_model( tokenizer_json_path = os.path.join(self.folder_path, "tokenizer.json") print("tokenizer_json_path: ", tokenizer_json_path) - if not os.path.exists(tokenizer_json_path): raise Exception( @@ -720,7 +718,6 @@ def zip_model( zipObj.write(APACHE_LICENSE_PATH, arcname="LICENSE") print("zip file is saved to " + zip_file_path + "\n") - def _fill_null_truncation_field( self, diff --git a/tests/ml_models/test_sentencetransformermodel_pytest.py b/tests/ml_models/test_sentencetransformermodel_pytest.py index 95ca168f..da759e64 100644 --- a/tests/ml_models/test_sentencetransformermodel_pytest.py +++ b/tests/ml_models/test_sentencetransformermodel_pytest.py @@ -24,6 +24,7 @@ os.path.dirname(os.path.abspath("__file__")), "tests", "sample_zip" ) + def clean_test_folder(TEST_FOLDER): if os.path.exists(TEST_FOLDER): for files in os.listdir(TEST_FOLDER): @@ -156,6 +157,7 @@ def test_missing_files(): clean_test_folder(temp_path) assert "Cannot find config.json" in str(exc_info.value) + def test_save_as_pt(): try: test_model.save_as_pt(sentences=["today is sunny"]) @@ -294,6 +296,7 @@ def test_make_model_config_json_for_onnx(): clean_test_folder(TEST_FOLDER) + def test_overwrite_fields_in_model_config(): model_id = "sentence-transformers/all-distilroberta-v1" expected_model_config_data = { @@ -635,7 +638,11 @@ def test_save_as_pt_with_license(): model_id=model_id, ) - test_model15.save_as_pt(model_id=model_id, sentences=["today is sunny"], license_to_be_zipped="apache-2.0") + test_model15.save_as_pt( + model_id=model_id, + sentences=["today is sunny"], + license_to_be_zipped="apache-2.0", + ) with ZipFile(torch_script_zip_file_path, "r") as f: filenames = set(f.namelist()) assert ( @@ -698,5 +705,6 @@ def test_zip_model_with_license(): ), f"The content in the model zip file does not match the expected content: {filenames} != {expected_filenames_with_license}" clean_test_folder(TEST_FOLDER) + clean_test_folder(TEST_FOLDER) clean_test_folder(TESTDATA_UNZIP_FOLDER) diff --git a/utils/model_uploader/model_autotracing.py b/utils/model_uploader/model_autotracing.py index 4d283490..a0578ff9 100644 --- a/utils/model_uploader/model_autotracing.py +++ b/utils/model_uploader/model_autotracing.py @@ -143,24 +143,23 @@ def trace_sentence_transformer_model( model_config_path = None try: model_config_path = pre_trained_model.make_model_config_json( - version_number=model_version, - model_format=model_format, - embedding_dimension=embedding_dimension, - pooling_mode=pooling_mode, - description=model_description, + version_number=model_version, + model_format=model_format, + embedding_dimension=embedding_dimension, + pooling_mode=pooling_mode, + description=model_description, ) except Exception as e: assert ( False ), f"Raised Exception during making model config file for {model_format} model: {e}" - # 4.) Preview model config print(f"\n+++++ {model_format} Model Config +++++\n") with open(model_config_path, "r") as f: model_config = json.load(f) print(json.dumps(model_config, indent=4)) - print(f"\n+++++++++++++++++++++++++++++++++++++++\n") + print("\n+++++++++++++++++++++++++++++++++++++++\n") # 5.) Return model_path & model_config_path for model registration return model_path, model_config_path