From 9f3ddbb9f62c8e17297a5863806757e33955be43 Mon Sep 17 00:00:00 2001
From: yerzhaisang
Date: Wed, 19 Jul 2023 12:53:13 +0600
Subject: [PATCH] Implemented reusable function

Signed-off-by: yerzhaisang
---
 .../ml_models/sentencetransformermodel.py | 59 +++++++++++--------
 1 file changed, 35 insertions(+), 24 deletions(-)

diff --git a/opensearch_py_ml/ml_models/sentencetransformermodel.py b/opensearch_py_ml/ml_models/sentencetransformermodel.py
index f6be90ad..54e308de 100644
--- a/opensearch_py_ml/ml_models/sentencetransformermodel.py
+++ b/opensearch_py_ml/ml_models/sentencetransformermodel.py
@@ -701,6 +701,39 @@ def zip_model(
         )
         print("zip file is saved to " + zip_file_path + "\n")
 
+    def fix_truncation(
+        self,
+        save_json_folder_path: str,
+        max_length: int,
+    ) -> None:
+        """
+        Description:
+        Fix the truncation parameter in the tokenizer.json file.
+        If this parameter's value is null, it results in an error.
+
+        :param save_json_folder_path:
+            path to save the model json file, e.g. "home/save_pre_trained_model_json/"
+        :type save_json_folder_path: string
+        :param max_length:
+            maximum sequence length for the model
+        :type max_length: int
+        :return: no return value expected
+        :rtype: None
+        """
+        tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
+        with open(tokenizer_file_path) as user_file:
+            parsed_json = json.load(user_file)
+        if "truncation" not in parsed_json or parsed_json["truncation"] is None:
+            parsed_json["truncation"] = {
+                "direction": "Right",
+                "max_length": max_length,
+                "strategy": "LongestFirst",
+                "stride": 0,
+            }
+            with open(tokenizer_file_path, "w") as file:
+                json.dump(parsed_json, file, indent=2)
+
+
     def save_as_pt(
         self,
         sentences: [str],
@@ -760,18 +793,7 @@ def save_as_pt(
 
         # save tokenizer.json in save_json_folder_name
         model.save(save_json_folder_path)
-        tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
-        with open(tokenizer_file_path) as user_file:
-            parsed_json = json.load(user_file)
-        if "truncation" not in parsed_json or parsed_json["truncation"] is None:
-            parsed_json["truncation"] = {
-                "direction": "Right",
-                "max_length": model.tokenizer.model_max_length,
-                "strategy": "LongestFirst",
-                "stride": 0,
-            }
-            with open(tokenizer_file_path, "w") as file:
-                json.dump(parsed_json, file, indent=2)
+        self.fix_truncation(save_json_folder_path, model.tokenizer.model_max_length)
 
         # convert to pt format will need to be in cpu,
         # set the device to cpu, convert its input_ids and attention_mask in cpu and save as .pt format
@@ -863,18 +885,7 @@ def save_as_onnx(
 
         # save tokenizer.json in output_path
         model.save(save_json_folder_path)
-        tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
-        with open(tokenizer_file_path) as user_file:
-            parsed_json = json.load(user_file)
-        if "truncation" not in parsed_json or parsed_json["truncation"] is None:
-            parsed_json["truncation"] = {
-                "direction": "Right",
-                "max_length": model.tokenizer.model_max_length,
-                "strategy": "LongestFirst",
-                "stride": 0,
-            }
-            with open(tokenizer_file_path, "w") as file:
-                json.dump(parsed_json, file, indent=2)
+        self.fix_truncation(save_json_folder_path, model.tokenizer.model_max_length)
 
         convert(
             framework="pt",
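
Note (not part of the patch): below is a minimal standalone sketch of the behaviour the new fix_truncation() helper centralises for both save_as_pt and save_as_onnx: if tokenizer.json has a missing or null "truncation" field, it is filled in with a right-side LongestFirst truncation config. The logic mirrors the method added above but is reproduced outside the class so it runs on its own; the temporary folder, toy tokenizer.json contents, and max_length=512 are illustrative assumptions, not values taken from the library.

    import json
    import os
    import tempfile

    def fix_truncation(save_json_folder_path: str, max_length: int) -> None:
        # Same logic as the method added by this patch, standalone for illustration.
        tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
        with open(tokenizer_file_path) as user_file:
            parsed_json = json.load(user_file)
        if "truncation" not in parsed_json or parsed_json["truncation"] is None:
            parsed_json["truncation"] = {
                "direction": "Right",
                "max_length": max_length,
                "strategy": "LongestFirst",
                "stride": 0,
            }
            with open(tokenizer_file_path, "w") as file:
                json.dump(parsed_json, file, indent=2)

    # Toy tokenizer.json with a null "truncation" field (illustrative only).
    folder = tempfile.mkdtemp()
    with open(os.path.join(folder, "tokenizer.json"), "w") as f:
        json.dump({"version": "1.0", "truncation": None, "padding": None}, f)

    fix_truncation(folder, max_length=512)

    with open(os.path.join(folder, "tokenizer.json")) as f:
        print(json.load(f)["truncation"])
    # -> {'direction': 'Right', 'max_length': 512, 'strategy': 'LongestFirst', 'stride': 0}

Because the write happens only inside the if-branch, a tokenizer.json that already carries a valid truncation block is left untouched.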