Skip to content

Commit

Permalink
Implemented reusable function
Browse files Browse the repository at this point in the history
Signed-off-by: yerzhaisang <[email protected]>
  • Loading branch information
Yerzhaisang committed Jul 19, 2023
1 parent 27375e2 commit 9f3ddbb
Showing 1 changed file with 35 additions and 24 deletions.
59 changes: 35 additions & 24 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,39 @@ def zip_model(
)
print("zip file is saved to " + zip_file_path + "\n")

def fix_truncation(
    self,
    save_json_folder_path: str,
    max_length: int,
) -> None:
    """
    Description:
    Fix the truncation parameter in the tokenizer.json file.
    If the "truncation" entry is missing or null, it is replaced with a
    default right-side, longest-first truncation config capped at
    max_length, and the file is written back. A file that already has a
    non-null "truncation" entry is left untouched.

    :param save_json_folder_path:
        path to the folder containing tokenizer.json,
        e.g., "home/save_pre_trained_model_json/"
    :type save_json_folder_path: string
    :param max_length:
        maximum sequence length for model
    :type max_length: int
    :return: no return value expected
    :rtype: None
    """
    tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")

    # tokenizer.json is UTF-8 and usually contains non-ASCII vocabulary
    # entries; read it with an explicit encoding so the result does not
    # depend on the platform's locale (e.g. cp1252 on Windows).
    with open(tokenizer_file_path, encoding="utf-8") as user_file:
        parsed_json = json.load(user_file)

    # dict.get returns None both for a missing key and for an explicit null.
    if parsed_json.get("truncation") is not None:
        return

    parsed_json["truncation"] = {
        "direction": "Right",
        "max_length": max_length,
        "strategy": "LongestFirst",
        "stride": 0,
    }
    with open(tokenizer_file_path, "w", encoding="utf-8") as file:
        json.dump(parsed_json, file, indent=2)


def save_as_pt(
self,
sentences: [str],
Expand Down Expand Up @@ -760,18 +793,7 @@ def save_as_pt(

# save tokenizer.json in save_json_folder_name
model.save(save_json_folder_path)
tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
with open(tokenizer_file_path) as user_file:
parsed_json = json.load(user_file)
if "truncation" not in parsed_json or parsed_json["truncation"] is None:
parsed_json["truncation"] = {
"direction": "Right",
"max_length": model.tokenizer.model_max_length,
"strategy": "LongestFirst",
"stride": 0,
}
with open(tokenizer_file_path, "w") as file:
json.dump(parsed_json, file, indent=2)
self.fix_truncation(save_json_folder_path, model.tokenizer.model_max_length)

# convert to pt format will need to be in cpu,
# set the device to cpu, convert its input_ids and attention_mask in cpu and save as .pt format
Expand Down Expand Up @@ -863,18 +885,7 @@ def save_as_onnx(

# save tokenizer.json in output_path
model.save(save_json_folder_path)
tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
with open(tokenizer_file_path) as user_file:
parsed_json = json.load(user_file)
if "truncation" not in parsed_json or parsed_json["truncation"] is None:
parsed_json["truncation"] = {
"direction": "Right",
"max_length": model.tokenizer.model_max_length,
"strategy": "LongestFirst",
"stride": 0,
}
with open(tokenizer_file_path, "w") as file:
json.dump(parsed_json, file, indent=2)
self.fix_truncation(save_json_folder_path, model.tokenizer.model_max_length)

convert(
framework="pt",
Expand Down

0 comments on commit 9f3ddbb

Please sign in to comment.