From e0d17501c10086a867fcbd15a0cb5eb192b88eba Mon Sep 17 00:00:00 2001 From: Yerzhaisang <55043014+Yerzhaisang@users.noreply.github.com> Date: Thu, 20 Jul 2023 00:07:43 +0600 Subject: [PATCH] Enabled auto-truncation for any pretrained models (#192) * Made truncation parameter automatically processed Signed-off-by: Yerzhaisang Taskali * Made max_length parameter dynamic Signed-off-by: yerzhaisang * Added unit test for checking truncation parameter Signed-off-by: yerzhaisang * Updated CHANGELOG.md Signed-off-by: yerzhaisang * Included the test of max_length parameter value Signed-off-by: yerzhaisang * Slightly modeified the test of max_length parameter value Signed-off-by: yerzhaisang * Modified CHANGELOG.md and removed the duplicate Signed-off-by: yerzhaisang * Enabled auto-truncation format also for ONNX Signed-off-by: yerzhaisang * Implemented reusable function Signed-off-by: yerzhaisang * Fixed the lint Signed-off-by: yerzhaisang * Change tokenizer.json only if truncation is null Signed-off-by: yerzhaisang * Removed function which had been accidentally added Signed-off-by: yerzhaisang * Renamed reusable function and added the description Signed-off-by: yerzhaisang * Fixed the lint Signed-off-by: yerzhaisang --------- Signed-off-by: Yerzhaisang Taskali Signed-off-by: yerzhaisang --- CHANGELOG.md | 1 + .../ml_models/sentencetransformermodel.py | 37 +++++++++++++++++++ .../test_sentencetransformermodel_pytest.py | 32 ++++++++++++++++ 3 files changed, 70 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccd3e535..9db83c06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Fix ModelUploader bug & Update model tracing demo notebook by @thanawan-atc in ([#185](https://github.com/opensearch-project/opensearch-py-ml/pull/185)) - Fix make_model_config_json function by @thanawan-atc in ([#188](https://github.com/opensearch-project/opensearch-py-ml/pull/188)) - Make 
make_model_config_json function more concise by @thanawan-atc in ([#191](https://github.com/opensearch-project/opensearch-py-ml/pull/191)) +- Enabled auto-truncation for any pretrained models ([#192](https://github.com/opensearch-project/opensearch-py-ml/pull/192)) ## [1.0.0] diff --git a/opensearch_py_ml/ml_models/sentencetransformermodel.py b/opensearch_py_ml/ml_models/sentencetransformermodel.py index 057cfc4e..fb166f28 100644 --- a/opensearch_py_ml/ml_models/sentencetransformermodel.py +++ b/opensearch_py_ml/ml_models/sentencetransformermodel.py @@ -701,6 +701,37 @@ def zip_model( ) print("zip file is saved to " + zip_file_path + "\n") + def _fill_null_truncation_field( + self, + save_json_folder_path: str, + max_length: int, + ) -> None: + """ + Description: + Fill truncation field in tokenizer.json when it is null + + :param save_json_folder_path: + path to save model json file, e.g, "home/save_pre_trained_model_json/") + :type save_json_folder_path: string + :param max_length: + maximum sequence length for model + :type max_length: int + :return: no return value expected + :rtype: None + """ + tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json") + with open(tokenizer_file_path) as user_file: + parsed_json = json.load(user_file) + if "truncation" not in parsed_json or parsed_json["truncation"] is None: + parsed_json["truncation"] = { + "direction": "Right", + "max_length": max_length, + "strategy": "LongestFirst", + "stride": 0, + } + with open(tokenizer_file_path, "w") as file: + json.dump(parsed_json, file, indent=2) + def save_as_pt( self, sentences: [str], @@ -760,6 +791,9 @@ def save_as_pt( # save tokenizer.json in save_json_folder_name model.save(save_json_folder_path) + self._fill_null_truncation_field( + save_json_folder_path, model.tokenizer.model_max_length + ) # convert to pt format will need to be in cpu, # set the device to cpu, convert its input_ids and attention_mask in cpu and save as .pt format @@ -851,6 +885,9 @@ def 
save_as_onnx( # save tokenizer.json in output_path model.save(save_json_folder_path) + self._fill_null_truncation_field( + save_json_folder_path, model.tokenizer.model_max_length + ) convert( framework="pt", diff --git a/tests/ml_models/test_sentencetransformermodel_pytest.py b/tests/ml_models/test_sentencetransformermodel_pytest.py index ae16681f..de76f1a7 100644 --- a/tests/ml_models/test_sentencetransformermodel_pytest.py +++ b/tests/ml_models/test_sentencetransformermodel_pytest.py @@ -372,5 +372,37 @@ def test_overwrite_fields_in_model_config(): clean_test_folder(TEST_FOLDER) +def test_truncation_parameter(): + model_id = "sentence-transformers/msmarco-distilbert-base-tas-b" + MAX_LENGTH_TASB = 512 + + clean_test_folder(TEST_FOLDER) + test_model10 = SentenceTransformerModel( + folder_path=TEST_FOLDER, + model_id=model_id, + ) + + test_model10.save_as_pt(model_id=model_id, sentences=["today is sunny"]) + + tokenizer_json_file_path = os.path.join(TEST_FOLDER, "tokenizer.json") + try: + with open(tokenizer_json_file_path, "r") as json_file: + tokenizer_json = json.load(json_file) + except Exception as exec: + assert ( + False + ), f"Creating tokenizer.json file for tracing raised an exception {exec}" + + assert tokenizer_json[ + "truncation" + ], "truncation parameter in tokenizer.json is null" + + assert ( + tokenizer_json["truncation"]["max_length"] == MAX_LENGTH_TASB + ), "max_length is not properly set" + + clean_test_folder(TEST_FOLDER) + + clean_test_folder(TEST_FOLDER) clean_test_folder(TESTDATA_UNZIP_FOLDER)