Skip to content

Commit

Permalink
Enabled auto-truncation for any pretrained models (#192)
Browse files Browse the repository at this point in the history
* Made truncation parameter automatically processed

Signed-off-by: Yerzhaisang Taskali <[email protected]>

* Made max_length parameter dynamic

Signed-off-by: yerzhaisang <[email protected]>

* Added unit test for checking truncation parameter

Signed-off-by: yerzhaisang <[email protected]>

* Updated CHANGELOG.md

Signed-off-by: yerzhaisang <[email protected]>

* Included the test of max_length parameter value

Signed-off-by: yerzhaisang <[email protected]>

* Slightly modified the test of max_length parameter value

Signed-off-by: yerzhaisang <[email protected]>

* Modified CHANGELOG.md and removed the duplicate

Signed-off-by: yerzhaisang <[email protected]>

* Enabled auto-truncation format also for ONNX

Signed-off-by: yerzhaisang <[email protected]>

* Implemented reusable function

Signed-off-by: yerzhaisang <[email protected]>

* Fixed the lint

Signed-off-by: yerzhaisang <[email protected]>

* Change tokenizer.json only if truncation is null

Signed-off-by: yerzhaisang <[email protected]>

* Removed function which had been accidentally added

Signed-off-by: yerzhaisang <[email protected]>

* Renamed reusable function and added the description

Signed-off-by: yerzhaisang <[email protected]>

* Fixed the lint

Signed-off-by: yerzhaisang <[email protected]>

---------

Signed-off-by: Yerzhaisang Taskali <[email protected]>
Signed-off-by: yerzhaisang <[email protected]>
  • Loading branch information
Yerzhaisang authored Jul 19, 2023
1 parent 7622af5 commit e0d1750
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Fix ModelUploader bug & Update model tracing demo notebook by @thanawan-atc in ([#185](https://github.com/opensearch-project/opensearch-py-ml/pull/185))
- Fix make_model_config_json function by @thanawan-atc in ([#188](https://github.com/opensearch-project/opensearch-py-ml/pull/188))
- Make make_model_config_json function more concise by @thanawan-atc in ([#191](https://github.com/opensearch-project/opensearch-py-ml/pull/191))
- Enabled auto-truncation for any pretrained models by @Yerzhaisang in ([#192](https://github.com/opensearch-project/opensearch-py-ml/pull/192))

## [1.0.0]

Expand Down
37 changes: 37 additions & 0 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,37 @@ def zip_model(
)
print("zip file is saved to " + zip_file_path + "\n")

def _fill_null_truncation_field(
self,
save_json_folder_path: str,
max_length: int,
) -> None:
"""
Description:
Fill truncation field in tokenizer.json when it is null
:param save_json_folder_path:
path to save model json file, e.g, "home/save_pre_trained_model_json/")
:type save_json_folder_path: string
:param max_length:
maximum sequence length for model
:type max_length: int
:return: no return value expected
:rtype: None
"""
tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
with open(tokenizer_file_path) as user_file:
parsed_json = json.load(user_file)
if "truncation" not in parsed_json or parsed_json["truncation"] is None:
parsed_json["truncation"] = {
"direction": "Right",
"max_length": max_length,
"strategy": "LongestFirst",
"stride": 0,
}
with open(tokenizer_file_path, "w") as file:
json.dump(parsed_json, file, indent=2)

def save_as_pt(
self,
sentences: [str],
Expand Down Expand Up @@ -760,6 +791,9 @@ def save_as_pt(

# save tokenizer.json in save_json_folder_name
model.save(save_json_folder_path)
self._fill_null_truncation_field(
save_json_folder_path, model.tokenizer.model_max_length
)

# convert to pt format will need to be in cpu,
# set the device to cpu, convert its input_ids and attention_mask in cpu and save as .pt format
Expand Down Expand Up @@ -851,6 +885,9 @@ def save_as_onnx(

# save tokenizer.json in output_path
model.save(save_json_folder_path)
self._fill_null_truncation_field(
save_json_folder_path, model.tokenizer.model_max_length
)

convert(
framework="pt",
Expand Down
32 changes: 32 additions & 0 deletions tests/ml_models/test_sentencetransformermodel_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,5 +372,37 @@ def test_overwrite_fields_in_model_config():
clean_test_folder(TEST_FOLDER)


def test_truncation_parameter():
    """Verify that save_as_pt produces a tokenizer.json whose truncation
    field is populated with the model's max sequence length (512 for TAS-B),
    rather than being left null."""
    model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
    MAX_LENGTH_TASB = 512

    clean_test_folder(TEST_FOLDER)
    test_model10 = SentenceTransformerModel(
        folder_path=TEST_FOLDER,
        model_id=model_id,
    )

    test_model10.save_as_pt(model_id=model_id, sentences=["today is sunny"])

    tokenizer_json_file_path = os.path.join(TEST_FOLDER, "tokenizer.json")
    try:
        with open(tokenizer_json_file_path, "r") as json_file:
            tokenizer_json = json.load(json_file)
    # Renamed from `exec`, which shadowed the builtin of the same name.
    except Exception as exc:
        assert (
            False
        ), f"Creating tokenizer.json file for tracing raised an exception {exc}"

    assert tokenizer_json[
        "truncation"
    ], "truncation parameter in tokenizer.json is null"

    assert (
        tokenizer_json["truncation"]["max_length"] == MAX_LENGTH_TASB
    ), "max_length is not properly set"

    clean_test_folder(TEST_FOLDER)


clean_test_folder(TEST_FOLDER)
clean_test_folder(TESTDATA_UNZIP_FOLDER)

0 comments on commit e0d1750

Please sign in to comment.