From 27375e2869d2f8895ae53084e444167e689dfaa7 Mon Sep 17 00:00:00 2001 From: yerzhaisang Date: Tue, 18 Jul 2023 23:39:32 +0600 Subject: [PATCH] Enabled auto-truncation format also for ONNX Signed-off-by: yerzhaisang --- .../ml_models/sentencetransformermodel.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/opensearch_py_ml/ml_models/sentencetransformermodel.py b/opensearch_py_ml/ml_models/sentencetransformermodel.py index cbd1c09d..f6be90ad 100644 --- a/opensearch_py_ml/ml_models/sentencetransformermodel.py +++ b/opensearch_py_ml/ml_models/sentencetransformermodel.py @@ -760,17 +760,17 @@ def save_as_pt( # save tokenizer.json in save_json_folder_name model.save(save_json_folder_path) - with open(save_json_folder_path + "/tokenizer.json") as user_file: - file_contents = user_file.read() - parsed_json = json.loads(file_contents) - if not parsed_json["truncation"]: + tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json") + with open(tokenizer_file_path) as user_file: + parsed_json = json.load(user_file) + if "truncation" not in parsed_json or parsed_json["truncation"] is None: parsed_json["truncation"] = { "direction": "Right", "max_length": model.tokenizer.model_max_length, "strategy": "LongestFirst", "stride": 0, } - with open(save_json_folder_path + "/tokenizer.json", "w") as file: + with open(tokenizer_file_path, "w") as file: json.dump(parsed_json, file, indent=2) # convert to pt format will need to be in cpu, @@ -863,6 +863,18 @@ def save_as_onnx( # save tokenizer.json in output_path model.save(save_json_folder_path) + tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json") + with open(tokenizer_file_path) as user_file: + parsed_json = json.load(user_file) + if "truncation" not in parsed_json or parsed_json["truncation"] is None: + parsed_json["truncation"] = { + "direction": "Right", + "max_length": model.tokenizer.model_max_length, + "strategy": "LongestFirst", + "stride": 0, + } + with open(tokenizer_file_path, "w") as file: + json.dump(parsed_json, file, indent=2) convert( framework="pt",