From 4736d1a0722de2195eda5f11fa1ec0168bfd6a0d Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 13 Feb 2024 15:36:20 -0800 Subject: [PATCH 01/14] add cross-encoder tracing, config-generating, and uploading Signed-off-by: HenryL27 --- opensearch_py_ml/ml_models/__init__.py | 3 +- .../ml_models/crossencodermodel.py | 289 ++++++++++++++++++ requirements.txt | 1 + 3 files changed, 292 insertions(+), 1 deletion(-) create mode 100644 opensearch_py_ml/ml_models/crossencodermodel.py diff --git a/opensearch_py_ml/ml_models/__init__.py b/opensearch_py_ml/ml_models/__init__.py index 3ec96ebd5..790c90efe 100644 --- a/opensearch_py_ml/ml_models/__init__.py +++ b/opensearch_py_ml/ml_models/__init__.py @@ -7,5 +7,6 @@ from .metrics_correlation.mcorr import MCorr from .sentencetransformermodel import SentenceTransformerModel +from .crossencodermodel import CrossEncoderModel -__all__ = ["SentenceTransformerModel", "MCorr"] +__all__ = ["SentenceTransformerModel", "MCorr", "CrossEncoderModel"] diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py new file mode 100644 index 000000000..3cfc6dc48 --- /dev/null +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +import json +from opensearch_py_ml.ml_commons import ModelUploader +from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig +from pathlib import Path +from zipfile import ZipFile +import shutil +import os +import requests +import torch +from opensearch_py_ml.ml_commons.ml_common_utils import ( + _generate_model_content_hash_value, +) +from opensearchpy import OpenSearch + + +def _fix_tokenizer(max_len: int, path: Path): + """ + Add truncation parameters to tokenizer file. Edits the file in place + + :param max_len: max number of tokens to truncate to + :type max_len: int + :param path: path to tokenizer file + :type path: str + """ + with open(Path(path) / "tokenizer.json", "r") as f: + parsed = json.load(f) + if "truncation" not in parsed or parsed['truncation'] is None: + parsed['truncation'] = { + "direction": "Right", + "max_length": max_len, + "strategy": "LongestFirst", + "stride": 0, + } + with open(Path(path) / "tokenizer.json", "w") as f: + json.dump(parsed, f, indent=2) + + +class CrossEncoderModel: + """ + Class for configuring and uploading cross encoder models for opensearch + """ + def __init__( + self, + hf_model_id: str, + folder_path: str = None, + overwrite: bool = False + ) -> None: + """ + Initialize a new CrossEncoder model from a huggingface id + + :param hf_model_id: huggingface id of the model to load + :type hf_model_id: str + :param folder_path: folder path to save the model + default is /tmp/models/hf_model_id + :type folder_path: str + :param overwrite: whether to overwrite the existing model + :type overwrite: bool + :return: None + """ + default_folder_path = Path(f"/tmp/models/{hf_model_id}") + + if folder_path is None: + self._folder_path = default_folder_path + else: + self._folder_path = Path(folder_path) + + if self._folder_path.exists() and not overwrite: + raise Exception(f"Folder {self._folder_path} already exists. 
To overwrite it, set `overwrite=True`.") + + self._hf_model_id = hf_model_id + self._framework = None + self._folder_path.mkdir(parents=True, exist_ok=True) + + + def zip_model(self, framework: str = "pt") -> Path: + """ + Compiles and zips the model to {self._folder_path}/model.zip + + :param framework: one of "pt", "onnx". The framework to zip the model as. + default: "pt" + :type framework: str + :return: the path with the zipped model + :rtype: Path + """ + if framework == "pt": + self._framework = "pt" + return self._zip_model_pytorch() + if framework == "onnx": + self._framework = "onnx" + return self._zip_model_onnx() + raise Exception(f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`") + + + def _zip_model_pytorch(self) -> Path: + """ + Compiles the model to TORCHSCRIPT format. + """ + tk = AutoTokenizer.from_pretrained(self._hf_model_id) + model = AutoModelForSequenceClassification.from_pretrained(self._hf_model_id) + features = tk([["dummy sentence 1", "dummy sentence 2"]], return_tensors="pt") + mname = Path(self._hf_model_id).name + + # bge models don't generate token type ids + if mname.startswith("bge"): + features['token_type_ids'] = torch.zeros_like(features['input_ids']) + + # compile + compiled = torch.jit.trace(model, example_kwarg_inputs={ + 'input_ids': features['input_ids'], + 'attention_mask': features['attention_mask'], + 'token_type_ids': features['token_type_ids'] + }, strict=False) + torch.jit.save(compiled, f"/tmp/{mname}.pt") + + # save tokenizer file + tk_path = f"/tmp/{mname}-tokenizer" + tk.save_pretrained(tk_path) + _fix_tokenizer(tk.model_max_length, tk_path) + + # get apache license + r = requests.get("https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE") + with ZipFile(self._folder_path / "model.zip", "w") as f: + f.write(f"/tmp/{mname}.pt", arcname=f"{mname}.pt") + f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json") + f.writestr("LICENSE", r.content) + + # clean up temp files + shutil.rmtree(f"/tmp/{mname}-tokenizer") + os.remove(f"/tmp/{mname}.pt") + return self._folder_path / "model.zip" + + def _zip_model_onnx(self): + """ + Compiles the model to ONNX format. 
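+
+        :return: the path to the zipped model
+        :rtype: Path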
+ """ + tk = AutoTokenizer.from_pretrained(self._hf_model_id) + model = AutoModelForSequenceClassification.from_pretrained(self._hf_model_id) + features = tk([["dummy sentence 1", "dummy sentence 2"]], return_tensors="pt") + mname = Path(self._hf_model_id).name + + # bge models don't generate token type ids + if mname.startswith("bge"): + features['token_type_ids'] = torch.zeros_like(features['input_ids']) + + # export to onnx + onnx_model_path = f"/tmp/{mname}.onnx" + torch.onnx.export( + model=model, + args=(features['input_ids'], features['attention_mask'], features['token_type_ids']), + f=onnx_model_path, + input_names=['input_ids', 'attention_mask', 'token_type_ids'], + output_names=['output'], + dynamic_axes={ + 'input_ids': {0: 'batch_size', 1: 'sequence_length'}, + 'attention_mask': {0: 'batch_size', 1: 'sequence_length'}, + 'token_type_ids': {0: 'batch_size', 1: 'sequence_length'}, + 'output': {0: 'batch_size'} + }, + verbose=True + ) + + # save tokenizer file + tk_path = f"/tmp/{mname}-tokenizer" + tk.save_pretrained(tk_path) + _fix_tokenizer(tk.model_max_length, tk_path) + + # get apache license + r = requests.get("https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE") + with ZipFile(self._folder_path / "model.zip", "w") as f: + f.write(onnx_model_path, arcname=f"{mname}.pt") + f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json") + f.writestr("LICENSE", r.content) + + # clean up temp files + shutil.rmtree(f"/tmp/{mname}-tokenizer") + os.remove(onnx_model_path) + return self._folder_path / "model.zip" + + + def make_model_config_json( + self, + model_name: str = None, + version_number: str = 1, + description: str = None, + all_config: str = None, + model_type: str = None, + verbose: bool = False, + ): + """ + Parse from config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file. + If all required fields are given by users, use the given parameters and will skip reading the config.json + + :param model_name: + Optional, The name of the model. If None, default is model id, for example, + 'sentence-transformers/msmarco-distilbert-base-tas-b' + :type model_name: string + :param version_number: + Optional, The version number of the model. Default is 1 + :type version_number: string + :param description: Optional, the description of the model. If None, get description from the README.md + file in the model folder. + :type description: str + :param all_config: + Optional, the all_config of the model. If None, parse all contents from the config file of pre-trained + hugging-face model + :type all_config: dict + :param model_type: + Optional, the model_type of the model. If None, parse model_type from the config file of pre-trained + hugging-face model + :type model_type: string + :param verbose: + optional, use printing more logs. Default as false + :type verbose: bool + :return: model config file path. 
The file path where the model config file is being saved + :rtype: string + """ + if not (self._folder_path / "model.zip").exists(): + raise Exception("Generate the model zip before generating the config") + hash_value = _generate_model_content_hash_value(str(self._folder_path / "model.zip")) + if model_name is None: + model_name = Path(self._hf_model_id).name + if description is None: + description = f"Cross Encoder Model {model_name}" + if all_config is None: + cfg = AutoConfig.from_pretrained(self._hf_model_id) + all_config = cfg.to_json_string() + if model_type is None: + model_type = "bert" + model_format = None + if self._framework is not None: + model_format = { + 'pt': 'TORCH_SCRIPT', + 'onnx': 'ONNX' + }.get(self._framework) + if model_format is None: + raise Exception("Model format either not found or not supported. Zip the model before generating the config") + model_config_content = { + "name": model_name, + "version": f"1.0.{version_number}", + "description": description, + "model_format": model_format, + "function_name": "TEXT_SIMILARITY", + "model_content_hash_value": hash_value, + "model_config": { + "model_type": model_type, + "embedding_dimension": 1, + "framework_type": "huggingface_transformers", + "all_config": all_config, + } + } + if verbose: + print(json.dumps(model_config_content, indent=2)) + with open(self._folder_path / "config.json", "w") as f: + json.dump(model_config_content, f) + return self._folder_path / "config.json" + + def upload(self, client: OpenSearch, framework: str = 'pt', model_group_id: str = "", verbose: bool = False): + """ + Upload the model to OpenSearch + + :param client: OpenSearch client + :type client: OpenSearch + :param framework: either 'pt' or 'onnx' + :type framework: str + :param model_group_id: model group id to upload this model to + :type model_group_id: str + :param verbose: log a bunch or not + :type verbose: bool + """ + config_path = self._folder_path / "config.json" + model_path = self._folder_path / "model.zip" + gen_cfg = False + if not model_path.exists() or self._framework != framework: + gen_cfg = True + self.zip_model(framework) + if not config_path.exists() or gen_cfg: + self.make_model_config_json() + uploader = ModelUploader(client) + uploader._register_model(str(model_path), str(config_path), model_group_id, verbose) + + + diff --git a/requirements.txt b/requirements.txt index ab4e94821..b3976aee4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ sentence_transformers tqdm transformers deprecated +requests \ No newline at end of file From b36f9b95f6f27c56e7f195301a544a9fb3a2540c Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 13 Feb 2024 16:22:47 -0800 Subject: [PATCH 02/14] run nox format Signed-off-by: HenryL27 --- opensearch_py_ml/ml_models/__init__.py | 2 +- .../ml_models/crossencodermodel.py | 123 ++++++++++-------- 2 files changed, 72 insertions(+), 53 deletions(-) diff --git a/opensearch_py_ml/ml_models/__init__.py b/opensearch_py_ml/ml_models/__init__.py index 790c90efe..77e802e94 100644 --- a/opensearch_py_ml/ml_models/__init__.py +++ b/opensearch_py_ml/ml_models/__init__.py @@ -5,8 +5,8 @@ # Any modifications Copyright OpenSearch Contributors. See # GitHub history for details. 
+from .crossencodermodel import CrossEncoderModel from .metrics_correlation.mcorr import MCorr from .sentencetransformermodel import SentenceTransformerModel -from .crossencodermodel import CrossEncoderModel __all__ = ["SentenceTransformerModel", "MCorr", "CrossEncoderModel"] diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index 3cfc6dc48..c61cb26cd 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -6,18 +6,20 @@ # GitHub history for details. import json -from opensearch_py_ml.ml_commons import ModelUploader -from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig +import os +import shutil from pathlib import Path from zipfile import ZipFile -import shutil -import os + import requests import torch +from opensearchpy import OpenSearch +from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer + +from opensearch_py_ml.ml_commons import ModelUploader from opensearch_py_ml.ml_commons.ml_common_utils import ( _generate_model_content_hash_value, ) -from opensearchpy import OpenSearch def _fix_tokenizer(max_len: int, path: Path): @@ -31,8 +33,8 @@ def _fix_tokenizer(max_len: int, path: Path): """ with open(Path(path) / "tokenizer.json", "r") as f: parsed = json.load(f) - if "truncation" not in parsed or parsed['truncation'] is None: - parsed['truncation'] = { + if "truncation" not in parsed or parsed["truncation"] is None: + parsed["truncation"] = { "direction": "Right", "max_length": max_len, "strategy": "LongestFirst", @@ -46,11 +48,9 @@ class CrossEncoderModel: """ Class for configuring and uploading cross encoder models for opensearch """ + def __init__( - self, - hf_model_id: str, - folder_path: str = None, - overwrite: bool = False + self, hf_model_id: str, folder_path: str = None, overwrite: bool = False ) -> None: """ Initialize a new CrossEncoder model from a huggingface id @@ -72,13 +72,14 @@ def __init__( self._folder_path = Path(folder_path) if self._folder_path.exists() and not overwrite: - raise Exception(f"Folder {self._folder_path} already exists. To overwrite it, set `overwrite=True`.") + raise Exception( + f"Folder {self._folder_path} already exists. To overwrite it, set `overwrite=True`." + ) self._hf_model_id = hf_model_id self._framework = None self._folder_path.mkdir(parents=True, exist_ok=True) - def zip_model(self, framework: str = "pt") -> Path: """ Compiles and zips the model to {self._folder_path}/model.zip @@ -95,8 +96,9 @@ def zip_model(self, framework: str = "pt") -> Path: if framework == "onnx": self._framework = "onnx" return self._zip_model_onnx() - raise Exception(f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`") - + raise Exception( + f"Unrecognized framework {framework}. 
Accepted values are `pt`, `onnx`" + ) def _zip_model_pytorch(self) -> Path: """ @@ -109,14 +111,18 @@ def _zip_model_pytorch(self) -> Path: # bge models don't generate token type ids if mname.startswith("bge"): - features['token_type_ids'] = torch.zeros_like(features['input_ids']) + features["token_type_ids"] = torch.zeros_like(features["input_ids"]) # compile - compiled = torch.jit.trace(model, example_kwarg_inputs={ - 'input_ids': features['input_ids'], - 'attention_mask': features['attention_mask'], - 'token_type_ids': features['token_type_ids'] - }, strict=False) + compiled = torch.jit.trace( + model, + example_kwarg_inputs={ + "input_ids": features["input_ids"], + "attention_mask": features["attention_mask"], + "token_type_ids": features["token_type_ids"], + }, + strict=False, + ) torch.jit.save(compiled, f"/tmp/{mname}.pt") # save tokenizer file @@ -125,7 +131,9 @@ def _zip_model_pytorch(self) -> Path: _fix_tokenizer(tk.model_max_length, tk_path) # get apache license - r = requests.get("https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE") + r = requests.get( + "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE" + ) with ZipFile(self._folder_path / "model.zip", "w") as f: f.write(f"/tmp/{mname}.pt", arcname=f"{mname}.pt") f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json") @@ -147,23 +155,27 @@ def _zip_model_onnx(self): # bge models don't generate token type ids if mname.startswith("bge"): - features['token_type_ids'] = torch.zeros_like(features['input_ids']) + features["token_type_ids"] = torch.zeros_like(features["input_ids"]) # export to onnx onnx_model_path = f"/tmp/{mname}.onnx" torch.onnx.export( model=model, - args=(features['input_ids'], features['attention_mask'], features['token_type_ids']), + args=( + features["input_ids"], + features["attention_mask"], + features["token_type_ids"], + ), f=onnx_model_path, - input_names=['input_ids', 'attention_mask', 'token_type_ids'], - output_names=['output'], + input_names=["input_ids", "attention_mask", "token_type_ids"], + output_names=["output"], dynamic_axes={ - 'input_ids': {0: 'batch_size', 1: 'sequence_length'}, - 'attention_mask': {0: 'batch_size', 1: 'sequence_length'}, - 'token_type_ids': {0: 'batch_size', 1: 'sequence_length'}, - 'output': {0: 'batch_size'} + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "token_type_ids": {0: "batch_size", 1: "sequence_length"}, + "output": {0: "batch_size"}, }, - verbose=True + verbose=True, ) # save tokenizer file @@ -172,7 +184,9 @@ def _zip_model_onnx(self): _fix_tokenizer(tk.model_max_length, tk_path) # get apache license - r = requests.get("https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE") + r = requests.get( + "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE" + ) with ZipFile(self._folder_path / "model.zip", "w") as f: f.write(onnx_model_path, arcname=f"{mname}.pt") f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json") @@ -183,15 +197,14 @@ def _zip_model_onnx(self): os.remove(onnx_model_path) return self._folder_path / "model.zip" - def make_model_config_json( - self, - model_name: str = None, - version_number: str = 1, - description: str = None, - all_config: str = None, - model_type: str = None, - verbose: bool = False, + self, + model_name: str = None, + version_number: str = 1, + description: str = None, + all_config: str = None, + model_type: str = None, + verbose: bool = False, ): """ Parse from 
config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file. @@ -223,7 +236,9 @@ def make_model_config_json( """ if not (self._folder_path / "model.zip").exists(): raise Exception("Generate the model zip before generating the config") - hash_value = _generate_model_content_hash_value(str(self._folder_path / "model.zip")) + hash_value = _generate_model_content_hash_value( + str(self._folder_path / "model.zip") + ) if model_name is None: model_name = Path(self._hf_model_id).name if description is None: @@ -235,12 +250,11 @@ def make_model_config_json( model_type = "bert" model_format = None if self._framework is not None: - model_format = { - 'pt': 'TORCH_SCRIPT', - 'onnx': 'ONNX' - }.get(self._framework) + model_format = {"pt": "TORCH_SCRIPT", "onnx": "ONNX"}.get(self._framework) if model_format is None: - raise Exception("Model format either not found or not supported. Zip the model before generating the config") + raise Exception( + "Model format either not found or not supported. Zip the model before generating the config" + ) model_config_content = { "name": model_name, "version": f"1.0.{version_number}", @@ -253,7 +267,7 @@ def make_model_config_json( "embedding_dimension": 1, "framework_type": "huggingface_transformers", "all_config": all_config, - } + }, } if verbose: print(json.dumps(model_config_content, indent=2)) @@ -261,7 +275,13 @@ def make_model_config_json( json.dump(model_config_content, f) return self._folder_path / "config.json" - def upload(self, client: OpenSearch, framework: str = 'pt', model_group_id: str = "", verbose: bool = False): + def upload( + self, + client: OpenSearch, + framework: str = "pt", + model_group_id: str = "", + verbose: bool = False, + ): """ Upload the model to OpenSearch @@ -283,7 +303,6 @@ def upload(self, client: OpenSearch, framework: str = 'pt', model_group_id: str if not config_path.exists() or gen_cfg: self.make_model_config_json() uploader = ModelUploader(client) - uploader._register_model(str(model_path), str(config_path), model_group_id, verbose) - - - + uploader._register_model( + str(model_path), str(config_path), model_group_id, verbose + ) From fdcdb4bc6940f22dfd081f89e4391134e73990d7 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 13 Feb 2024 16:44:50 -0800 Subject: [PATCH 03/14] update changelog Signed-off-by: HenryL27 --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 133b0abcb..aa69c8520 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Add support for model profiles by @rawwar in ([#358](https://github.com/opensearch-project/opensearch-py-ml/pull/358)) - Support for security default admin credential changes in 2.12.0 in ([#365](https://github.com/opensearch-project/opensearch-py-ml/pull/365)) - adding cross encoder models in the pre-trained traced list ([#378](https://github.com/opensearch-project/opensearch-py-ml/pull/378)) +- Add support for Cross Encoders - Trace, Config, Upload by @HenryL27 in ([#375](https://github.com/opensearch-project/opensearch-py-ml/pull/375)) ### Changed From fe1109b5f3b3a45049c98f24838a2a0278fa2cf7 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 14 Feb 2024 14:36:49 -0800 Subject: [PATCH 04/14] condense common zipping logic; allow more configurable file names Signed-off-by: HenryL27 --- .../ml_models/crossencodermodel.py | 154 +++++++++--------- 1 file changed, 77 insertions(+), 77 deletions(-) diff --git 
a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index c61cb26cd..0bd548103 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -79,31 +79,24 @@ def __init__( self._hf_model_id = hf_model_id self._framework = None self._folder_path.mkdir(parents=True, exist_ok=True) + self._model_zip = None + self._model_config = None - def zip_model(self, framework: str = "pt") -> Path: + def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path: """ - Compiles and zips the model to {self._folder_path}/model.zip + Compiles and zips the model to {self._folder_path}/{zip_fname} :param framework: one of "pt", "onnx". The framework to zip the model as. default: "pt" :type framework: str + :param zip_fname: path to place resulting zip file inside of self._folder_path. + Example: if folder_path is "/tmp/models" and zip_path is "zipped_up.zip" then + the file can be found at "/tmp/models/zipped_up.zip" + Default: "model.zip" + :type zip_fname: str :return: the path with the zipped model :rtype: Path """ - if framework == "pt": - self._framework = "pt" - return self._zip_model_pytorch() - if framework == "onnx": - self._framework = "onnx" - return self._zip_model_onnx() - raise Exception( - f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`" - ) - - def _zip_model_pytorch(self) -> Path: - """ - Compiles the model to TORCHSCRIPT format. - """ tk = AutoTokenizer.from_pretrained(self._hf_model_id) model = AutoModelForSequenceClassification.from_pretrained(self._hf_model_id) features = tk([["dummy sentence 1", "dummy sentence 2"]], return_tensors="pt") @@ -113,20 +106,19 @@ def _zip_model_pytorch(self) -> Path: if mname.startswith("bge"): features["token_type_ids"] = torch.zeros_like(features["input_ids"]) - # compile - compiled = torch.jit.trace( - model, - example_kwarg_inputs={ - "input_ids": features["input_ids"], - "attention_mask": features["attention_mask"], - "token_type_ids": features["token_type_ids"], - }, - strict=False, - ) - torch.jit.save(compiled, f"/tmp/{mname}.pt") + if framework == "pt": + self._framework = "pt" + model_loc = CrossEncoderModel._trace_pytorch(model, features, mname) + elif framework == "onnx": + self._framework = "onnx" + model_loc = CrossEncoderModel._trace_onnx(model, features, mname) + else: + raise Exception( + f"Unrecognized framework {framework}. 
Accepted values are `pt`, `onnx`" + ) # save tokenizer file - tk_path = f"/tmp/{mname}-tokenizer" + tk_path = Path(f"/tmp/{mname}-tokenizer") tk.save_pretrained(tk_path) _fix_tokenizer(tk.model_max_length, tk_path) @@ -134,31 +126,46 @@ def _zip_model_pytorch(self) -> Path: r = requests.get( "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE" ) - with ZipFile(self._folder_path / "model.zip", "w") as f: - f.write(f"/tmp/{mname}.pt", arcname=f"{mname}.pt") - f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json") + self._model_zip = self._folder_path / zip_fname + with ZipFile(self._model_zip, "w") as f: + f.write(model_loc, arcname=model_loc.name) + f.write(tk_path / "tokenizer.json", arcname="tokenizer.json") f.writestr("LICENSE", r.content) # clean up temp files - shutil.rmtree(f"/tmp/{mname}-tokenizer") - os.remove(f"/tmp/{mname}.pt") - return self._folder_path / "model.zip" + shutil.rmtree(tk_path) + os.remove(model_loc) + return self._model_zip - def _zip_model_onnx(self): + @staticmethod + def _trace_pytorch(model, features, mname) -> Path: """ - Compiles the model to ONNX format. - """ - tk = AutoTokenizer.from_pretrained(self._hf_model_id) - model = AutoModelForSequenceClassification.from_pretrained(self._hf_model_id) - features = tk([["dummy sentence 1", "dummy sentence 2"]], return_tensors="pt") - mname = Path(self._hf_model_id).name + Compiles the model to TORCHSCRIPT format. - # bge models don't generate token type ids - if mname.startswith("bge"): - features["token_type_ids"] = torch.zeros_like(features["input_ids"]) + :param features: Model input features + :return: Path to the traced model + """ + # compile + compiled = torch.jit.trace( + model, + example_kwarg_inputs={ + "input_ids": features["input_ids"], + "attention_mask": features["attention_mask"], + "token_type_ids": features["token_type_ids"], + }, + strict=False, + ) + save_loc = Path(f"/tmp/{mname}.pt") + torch.jit.save(compiled, f"/tmp/{mname}.pt") + return save_loc + @staticmethod + def _trace_onnx(model, features, mname): + """ + Compiles the model to ONNX format. + """ # export to onnx - onnx_model_path = f"/tmp/{mname}.onnx" + save_loc = Path(f"/tmp/{mname}.onnx") torch.onnx.export( model=model, args=( @@ -166,7 +173,7 @@ def _zip_model_onnx(self): features["attention_mask"], features["token_type_ids"], ), - f=onnx_model_path, + f=str(save_loc), input_names=["input_ids", "attention_mask", "token_type_ids"], output_names=["output"], dynamic_axes={ @@ -177,28 +184,11 @@ def _zip_model_onnx(self): }, verbose=True, ) - - # save tokenizer file - tk_path = f"/tmp/{mname}-tokenizer" - tk.save_pretrained(tk_path) - _fix_tokenizer(tk.model_max_length, tk_path) - - # get apache license - r = requests.get( - "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE" - ) - with ZipFile(self._folder_path / "model.zip", "w") as f: - f.write(onnx_model_path, arcname=f"{mname}.pt") - f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json") - f.writestr("LICENSE", r.content) - - # clean up temp files - shutil.rmtree(f"/tmp/{mname}-tokenizer") - os.remove(onnx_model_path) - return self._folder_path / "model.zip" + return save_loc def make_model_config_json( self, + config_fname: str = "config.json", model_name: str = None, version_number: str = 1, description: str = None, @@ -210,6 +200,11 @@ def make_model_config_json( Parse from config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file. 
If all required fields are given by users, use the given parameters and will skip reading the config.json + :param config_fname: + Optional, File name of model json config file. Default is "config.json". + Controls where the config file generated by this function will appear - + "{self._folder_path}/{config_fname}" + :type config_fname: str :param model_name: Optional, The name of the model. If None, default is model id, for example, 'sentence-transformers/msmarco-distilbert-base-tas-b' @@ -234,11 +229,13 @@ def make_model_config_json( :return: model config file path. The file path where the model config file is being saved :rtype: string """ - if not (self._folder_path / "model.zip").exists(): - raise Exception("Generate the model zip before generating the config") - hash_value = _generate_model_content_hash_value( - str(self._folder_path / "model.zip") - ) + if self._model_zip is None: + raise Exception( + "No model zip file. Generate the model zip file before generating the config." + ) + if not self._model_zip.exists(): + raise Exception(f"Model zip file {self._model_zip} could not be found") + hash_value = _generate_model_content_hash_value(str(self._model_zip)) if model_name is None: model_name = Path(self._hf_model_id).name if description is None: @@ -269,11 +266,12 @@ def make_model_config_json( "all_config": all_config, }, } + self._model_config = self._folder_path / config_fname if verbose: print(json.dumps(model_config_content, indent=2)) - with open(self._folder_path / "config.json", "w") as f: + with open(self._model_config, "w") as f: json.dump(model_config_content, f) - return self._folder_path / "config.json" + return self._model_config def upload( self, @@ -294,15 +292,17 @@ def upload( :param verbose: log a bunch or not :type verbose: bool """ - config_path = self._folder_path / "config.json" - model_path = self._folder_path / "model.zip" gen_cfg = False - if not model_path.exists() or self._framework != framework: + if ( + self._model_zip is None + or not self._model_zip.exists() + or self._framework != framework + ): gen_cfg = True self.zip_model(framework) - if not config_path.exists() or gen_cfg: + if self._model_config is None or not self._model_config.exists() or gen_cfg: self.make_model_config_json() uploader = ModelUploader(client) uploader._register_model( - str(model_path), str(config_path), model_group_id, verbose + str(self._model_zip), str(self._model_config), model_group_id, verbose ) From c54b8bdc2c4c60bfa73b1767b0be2608b4d72f34 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 14 Feb 2024 15:12:50 -0800 Subject: [PATCH 05/14] add simple unit tests for model saving Signed-off-by: HenryL27 --- .../test_crossencodermodel_pytest.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/ml_models/test_crossencodermodel_pytest.py diff --git a/tests/ml_models/test_crossencodermodel_pytest.py b/tests/ml_models/test_crossencodermodel_pytest.py new file mode 100644 index 000000000..633c5ac11 --- /dev/null +++ b/tests/ml_models/test_crossencodermodel_pytest.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. 
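+
+# NOTE: these tests trace a real model (cross-encoder/ms-marco-TinyBERT-L-2-v2)
+# downloaded from the Hugging Face hub, so they require network access.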
+import shutil +from pathlib import Path + +import pytest + +from opensearch_py_ml.ml_models import CrossEncoderModel +from tests.ml_models.test_sentencetransformermodel_pytest import ( + compare_model_config, + compare_model_zip_file, +) + +TEST_FOLDER = Path(__file__) / "tests" / "test_model_files" + + +@pytest.fixture(scope="function") +def tinybert() -> CrossEncoderModel: + model = CrossEncoderModel("cross-encoder/ms-marco-TinyBERT-L-2-v2") + yield model + shutil.rmtree( + "/tmp/models/cross-encoder/ms-marco-TinyBert-L-2-v2", ignore_errors=True + ) + + +def test_pt_has_correct_files(tinybert): + zip_path = tinybert.zip_model() + config_path = tinybert.make_model_config_json() + compare_model_zip_file( + zip_file_path=zip_path, + expected_filenames=["ms-marco-TinyBERT-L-2-v2.pt", "tokenizer.json", "LICENSE"], + model_format="TORCH_SCRIPT", + ) + compare_model_config( + model_config_path=config_path, + model_id="cross-encoder/ms-marco-TinyBERT-L-2-v2", + model_format="TORCH_SCRIPT", + expected_model_description={ + "model_type": "bert", + "embedding_dimension": 1, + "framework_type": "huggingface_transformers", + }, + ) + + +def test_onnx_has_correct_files(tinybert): + zip_path = tinybert.zip_model(framework="onnx") + config_path = tinybert.make_model_config_json() + compare_model_zip_file( + zip_file_path=zip_path, + expected_filenames=[ + "ms-marco-TinyBERT-L-2-v2.onnx", + "tokenizer.json", + "LICENSE", + ], + model_format="ONNX", + ) + compare_model_config( + model_config_path=config_path, + model_id="cross-encoder/ms-marco-TinyBERT-L-2-v2", + model_format="ONNX", + expected_model_description={ + "model_type": "bert", + "embedding_dimension": 1, + "framework_type": "huggingface_transformers", + }, + ) + + +def test_can_pick_names_for_files(tinybert): + zip_path = tinybert.zip_model(framework="onnx", zip_fname="funky-model-filename.pt") + config_path = tinybert.make_model_config_json( + config_fname="funky-model-config.json" + ) + compare_model_zip_file( + zip_file_path=zip_path, + expected_filenames=["funky-model-filename.pt", "tokenizer.json", "LICENSE"], + model_format="TORCH_SCRIPT", + ) + compare_model_config( + model_config_path=config_path, + model_id="cross-encoder/ms-marco-TinyBERT-L-2-v2", + model_format="TORCH_SCRIPT", + expected_model_description={ + "model_type": "bert", + "embedding_dimension": 1, + "framework_type": "huggingface_transformers", + }, + ) + assert config_path.endswith("funky-model-config.json") From fdf94b6a5cfeeabcd550972b0e61bbc144b1674e Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 14 Feb 2024 17:46:58 -0800 Subject: [PATCH 06/14] add some more tokenizer max length checks Signed-off-by: HenryL27 --- opensearch_py_ml/ml_models/crossencodermodel.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index 0bd548103..801d4414e 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -120,6 +120,17 @@ def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path # save tokenizer file tk_path = Path(f"/tmp/{mname}-tokenizer") tk.save_pretrained(tk_path) + if tk.model_max_length > model.get_max_length(): + model_config = AutoConfig.from_pretrained(self._hf_model_id) + if hasattr(model_config, "max_position_embeddings"): + tk.model_max_length = model_config.max_position_embeddings + elif hasattr(model_config, "n_positions"): + tk.model_max_length = 
model_config.n_positions + else: + tk.model_max_length = 2**15 # =32768. Set to something big I guess + print( + f"The model_max_length is not properly defined in tokenizer_config.json. Setting it to be {tk.model_max_length}" + ) _fix_tokenizer(tk.model_max_length, tk_path) # get apache license From 69e75c2ec4f6f533367cf15820e21efab893f2c5 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 14 Feb 2024 17:51:59 -0800 Subject: [PATCH 07/14] compare file sets, not lists in tests Signed-off-by: HenryL27 --- tests/ml_models/test_crossencodermodel_pytest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/ml_models/test_crossencodermodel_pytest.py b/tests/ml_models/test_crossencodermodel_pytest.py index 633c5ac11..602df85d3 100644 --- a/tests/ml_models/test_crossencodermodel_pytest.py +++ b/tests/ml_models/test_crossencodermodel_pytest.py @@ -20,7 +20,7 @@ @pytest.fixture(scope="function") def tinybert() -> CrossEncoderModel: - model = CrossEncoderModel("cross-encoder/ms-marco-TinyBERT-L-2-v2") + model = CrossEncoderModel("cross-encoder/ms-marco-TinyBERT-L-2-v2", overwrite=True) yield model shutil.rmtree( "/tmp/models/cross-encoder/ms-marco-TinyBert-L-2-v2", ignore_errors=True @@ -32,7 +32,7 @@ def test_pt_has_correct_files(tinybert): config_path = tinybert.make_model_config_json() compare_model_zip_file( zip_file_path=zip_path, - expected_filenames=["ms-marco-TinyBERT-L-2-v2.pt", "tokenizer.json", "LICENSE"], + expected_filenames={"ms-marco-TinyBERT-L-2-v2.pt", "tokenizer.json", "LICENSE"}, model_format="TORCH_SCRIPT", ) compare_model_config( @@ -52,11 +52,11 @@ def test_onnx_has_correct_files(tinybert): config_path = tinybert.make_model_config_json() compare_model_zip_file( zip_file_path=zip_path, - expected_filenames=[ + expected_filenames={ "ms-marco-TinyBERT-L-2-v2.onnx", "tokenizer.json", "LICENSE", - ], + }, model_format="ONNX", ) compare_model_config( @@ -78,7 +78,7 @@ def test_can_pick_names_for_files(tinybert): ) compare_model_zip_file( zip_file_path=zip_path, - expected_filenames=["funky-model-filename.pt", "tokenizer.json", "LICENSE"], + expected_filenames={"funky-model-filename.pt", "tokenizer.json", "LICENSE"}, model_format="TORCH_SCRIPT", ) compare_model_config( From 9e587404795eb1b635de3637bf3c6b8a6524a3f1 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 15 Feb 2024 14:25:32 -0800 Subject: [PATCH 08/14] no model.get_max_length() function so set tokenizer max when it's None Signed-off-by: HenryL27 --- opensearch_py_ml/ml_models/crossencodermodel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index 801d4414e..da459275b 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -120,7 +120,7 @@ def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path # save tokenizer file tk_path = Path(f"/tmp/{mname}-tokenizer") tk.save_pretrained(tk_path) - if tk.model_max_length > model.get_max_length(): + if tk.model_max_length is None: model_config = AutoConfig.from_pretrained(self._hf_model_id) if hasattr(model_config, "max_position_embeddings"): tk.model_max_length = model_config.max_position_embeddings @@ -129,7 +129,7 @@ def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path else: tk.model_max_length = 2**15 # =32768. Set to something big I guess print( - f"The model_max_length is not properly defined in tokenizer_config.json. 
Setting it to be {tk.model_max_length}" + f"The model_max_length is not found in tokenizer_config.json. Setting it to be {tk.model_max_length}" ) _fix_tokenizer(tk.model_max_length, tk_path) From 6f08d008f241dc364d94891f55f369f4ef6d103d Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 21 Feb 2024 14:10:56 -0800 Subject: [PATCH 09/14] change framework name pt -> torch_script Signed-off-by: HenryL27 --- .../ml_models/crossencodermodel.py | 22 +++++++++++-------- .../test_crossencodermodel_pytest.py | 4 +++- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index da459275b..da8f229c3 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -82,12 +82,14 @@ def __init__( self._model_zip = None self._model_config = None - def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path: + def zip_model( + self, framework: str = "torch_script", zip_fname: str = "model.zip" + ) -> Path: """ Compiles and zips the model to {self._folder_path}/{zip_fname} - :param framework: one of "pt", "onnx". The framework to zip the model as. - default: "pt" + :param framework: one of "torch_script", "onnx". The framework to zip the model as. + default: "torch_script" :type framework: str :param zip_fname: path to place resulting zip file inside of self._folder_path. Example: if folder_path is "/tmp/models" and zip_path is "zipped_up.zip" then @@ -106,15 +108,15 @@ def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path if mname.startswith("bge"): features["token_type_ids"] = torch.zeros_like(features["input_ids"]) - if framework == "pt": - self._framework = "pt" + if framework == "torch_script": + self._framework = "torch_script" model_loc = CrossEncoderModel._trace_pytorch(model, features, mname) elif framework == "onnx": self._framework = "onnx" model_loc = CrossEncoderModel._trace_onnx(model, features, mname) else: raise Exception( - f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`" + f"Unrecognized framework {framework}. Accepted values are `torch_script`, `onnx`" ) # save tokenizer file @@ -258,7 +260,9 @@ def make_model_config_json( model_type = "bert" model_format = None if self._framework is not None: - model_format = {"pt": "TORCH_SCRIPT", "onnx": "ONNX"}.get(self._framework) + model_format = {"torch_script": "TORCH_SCRIPT", "onnx": "ONNX"}.get( + self._framework + ) if model_format is None: raise Exception( "Model format either not found or not supported. 
Zip the model before generating the config" @@ -287,7 +291,7 @@ def make_model_config_json( def upload( self, client: OpenSearch, - framework: str = "pt", + framework: str = "torch_script", model_group_id: str = "", verbose: bool = False, ): @@ -296,7 +300,7 @@ def upload( :param client: OpenSearch client :type client: OpenSearch - :param framework: either 'pt' or 'onnx' + :param framework: either 'torch_script' or 'onnx' :type framework: str :param model_group_id: model group id to upload this model to :type model_group_id: str diff --git a/tests/ml_models/test_crossencodermodel_pytest.py b/tests/ml_models/test_crossencodermodel_pytest.py index 602df85d3..8d757c2bf 100644 --- a/tests/ml_models/test_crossencodermodel_pytest.py +++ b/tests/ml_models/test_crossencodermodel_pytest.py @@ -72,7 +72,9 @@ def test_onnx_has_correct_files(tinybert): def test_can_pick_names_for_files(tinybert): - zip_path = tinybert.zip_model(framework="onnx", zip_fname="funky-model-filename.pt") + zip_path = tinybert.zip_model( + framework="torch_script", zip_fname="funky-model-filename.pt" + ) config_path = tinybert.make_model_config_json( config_fname="funky-model-config.json" ) From 588af41a0156fbee321616f2431ece58e9968bf3 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 21 Feb 2024 14:43:37 -0800 Subject: [PATCH 10/14] function_name -> model_mask_type Signed-off-by: HenryL27 --- opensearch_py_ml/ml_models/crossencodermodel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index da8f229c3..37e997fee 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -203,7 +203,7 @@ def make_model_config_json( self, config_fname: str = "config.json", model_name: str = None, - version_number: str = 1, + version_number: str = '1.0.0', description: str = None, all_config: str = None, model_type: str = None, @@ -269,10 +269,10 @@ def make_model_config_json( ) model_config_content = { "name": model_name, - "version": f"1.0.{version_number}", + "version": version_number, "description": description, "model_format": model_format, - "function_name": "TEXT_SIMILARITY", + "model_mask_type": "TEXT_SIMILARITY", "model_content_hash_value": hash_value, "model_config": { "model_type": model_type, From b52007664e65f7e2b57187f60c9b32af2feef812 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 21 Feb 2024 18:53:41 -0800 Subject: [PATCH 11/14] model_task_type with a t. 
also include function name Signed-off-by: HenryL27 --- opensearch_py_ml/ml_models/crossencodermodel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index 37e997fee..4cb2d602b 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -272,7 +272,8 @@ def make_model_config_json( "version": version_number, "description": description, "model_format": model_format, - "model_mask_type": "TEXT_SIMILARITY", + "function_name": "TEXT_SIMILARITY", + "model_task_type": "TEXT_SIMILARITY", "model_content_hash_value": hash_value, "model_config": { "model_type": model_type, From ce4a860b02c4020d6a2b8d69ed9c51e5a02dddcc Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 28 Feb 2024 08:56:18 -0800 Subject: [PATCH 12/14] add more deatiled description wiith option to pull from a readme Signed-off-by: HenryL27 --- .../ml_models/crossencodermodel.py | 86 ++++++++++++++++++- 1 file changed, 84 insertions(+), 2 deletions(-) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index 4cb2d602b..d22c6c24e 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -7,12 +7,14 @@ import json import os +import re import shutil from pathlib import Path from zipfile import ZipFile import requests import torch +from mdutils.fileutils import MarkDownFile from opensearchpy import OpenSearch from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer @@ -203,7 +205,7 @@ def make_model_config_json( self, config_fname: str = "config.json", model_name: str = None, - version_number: str = '1.0.0', + version_number: str = "1.0.0", description: str = None, all_config: str = None, model_type: str = None, @@ -252,7 +254,20 @@ def make_model_config_json( if model_name is None: model_name = Path(self._hf_model_id).name if description is None: - description = f"Cross Encoder Model {model_name}" + readme_file_path = os.path.join(self._folder_path, "README.md") + if os.path.exists(readme_file_path): + try: + if verbose: + print("reading README.md file") + description = self._get_model_description_from_readme_file( + readme_file_path + ) + except Exception as e: + print(f"Cannot scrape model description from README.md file: {e}") + description = self._generate_default_model_description() + else: + print("Cannot find README.md file to scrape model description") + description = self._generate_default_model_description() if all_config is None: cfg = AutoConfig.from_pretrained(self._hf_model_id) all_config = cfg.to_json_string() @@ -322,3 +337,70 @@ def upload( uploader._register_model( str(self._model_zip), str(self._model_config), model_group_id, verbose ) + + def _get_model_description_from_readme_file(self, readme_file_path) -> str: + """ + Get description of the model from README.md file in the model folder + after the model is saved in local directory + + See example here: + https://huggingface.co/sentence-transformers/msmarco-distilbert-base-tas-b/blob/main/README.md) + + This function assumes that the README.md has the following format: + + # sentence-transformers/msmarco-distilbert-base-tas-b + This is [ ... further description ... ] + + # [ ... Next section ...] + ... 
+ + :param readme_file_path: Path to README.md file + :type readme_file_path: string + :return: Description of the model + :rtype: string + """ + readme_data = MarkDownFile.read_file(readme_file_path) + + # Find the description section + start_str = f"\n# {self._hf_model_id}" + start = readme_data.find(start_str) + if start == -1: + model_name = self._hf_model_id.split("/")[1] + start_str = f"\n# {model_name}" + start = readme_data.find(start_str) + end = readme_data.find("\n#", start + len(start_str)) + + # If we cannot find the scope of description section, raise error. + if start == -1 or end == -1: + assert False, "Cannot find description in README.md file" + + # Parse out the description section + description = readme_data[start + len(start_str) + 1 : end].strip() + description = description.split("\n")[0] + + # Remove hyperlink and reformat text + description = re.sub(r"\(.*?\)", "", description) + description = re.sub(r"[\[\]]", "", description) + description = re.sub(r"\*", "", description) + + # Remove unnecessary part if exists (i.e. " For an introduction to ...") + # (Found in https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1/blob/main/README.md) + unnecessary_part = description.find(" For an introduction to") + if unnecessary_part != -1: + description = description[:unnecessary_part] + + return description + + def _generate_default_model_description(self) -> str: + """ + Generate default model description of the model based on embedding_dimension + + :return: Description of the model + :rtype: string + """ + print( + "Using default description instead (You can overwrite this by specifying description parameter in \ +make_model_config_json function)" + ) + description = f"This is a cross-encoder model: It maps (query, passage) pairs to real-valued relevance scores." + return description From 2fb551bd919ecb1b88c5da183be1e93331910c11 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 29 Feb 2024 09:10:55 -0800 Subject: [PATCH 13/14] fix test and lint issue Signed-off-by: HenryL27 --- opensearch_py_ml/ml_models/crossencodermodel.py | 2 +- tests/ml_models/test_crossencodermodel_pytest.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index d22c6c24e..6b7535385 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -402,5 +402,5 @@ def _generate_default_model_description(self) -> str: "Using default description instead (You can overwrite this by specifying description parameter in \ make_model_config_json function)" ) - description = f"This is a cross-encoder model: It maps (query, passage) pairs to real-valued relevance scores." + description = "This is a cross-encoder model: It maps (query, passage) pairs to real-valued relevance scores." 
return description diff --git a/tests/ml_models/test_crossencodermodel_pytest.py b/tests/ml_models/test_crossencodermodel_pytest.py index 8d757c2bf..e20a45b3b 100644 --- a/tests/ml_models/test_crossencodermodel_pytest.py +++ b/tests/ml_models/test_crossencodermodel_pytest.py @@ -73,14 +73,15 @@ def test_onnx_has_correct_files(tinybert): def test_can_pick_names_for_files(tinybert): zip_path = tinybert.zip_model( - framework="torch_script", zip_fname="funky-model-filename.pt" + framework="torch_script", zip_fname="funky-model-filename.zip" ) config_path = tinybert.make_model_config_json( config_fname="funky-model-config.json" ) + assert (tinybert._folder_path / "funky-model-filename.zip").exists() compare_model_zip_file( zip_file_path=zip_path, - expected_filenames={"funky-model-filename.pt", "tokenizer.json", "LICENSE"}, + expected_filenames={"ms-marco-TinyBERT-L-2-v2.pt", "tokenizer.json", "LICENSE"}, model_format="TORCH_SCRIPT", ) compare_model_config( From b063f3218cd181a7e03e7dcec740f4191fb96e22 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 May 2024 09:30:19 -0700 Subject: [PATCH 14/14] address pr comments Signed-off-by: Henry Lindeman --- .../ml_models/crossencodermodel.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/opensearch_py_ml/ml_models/crossencodermodel.py b/opensearch_py_ml/ml_models/crossencodermodel.py index 6b7535385..700a45dd5 100644 --- a/opensearch_py_ml/ml_models/crossencodermodel.py +++ b/opensearch_py_ml/ml_models/crossencodermodel.py @@ -62,7 +62,7 @@ def __init__( :param folder_path: folder path to save the model default is /tmp/models/hf_model_id :type folder_path: str - :param overwrite: whether to overwrite the existing model + :param overwrite: whether to overwrite the existing model at folder+path :type overwrite: bool :return: None """ @@ -104,25 +104,25 @@ def zip_model( tk = AutoTokenizer.from_pretrained(self._hf_model_id) model = AutoModelForSequenceClassification.from_pretrained(self._hf_model_id) features = tk([["dummy sentence 1", "dummy sentence 2"]], return_tensors="pt") - mname = Path(self._hf_model_id).name + model_name = Path(self._hf_model_id).name # bge models don't generate token type ids - if mname.startswith("bge"): + if model_name.startswith("bge"): features["token_type_ids"] = torch.zeros_like(features["input_ids"]) if framework == "torch_script": self._framework = "torch_script" - model_loc = CrossEncoderModel._trace_pytorch(model, features, mname) + model_loc = CrossEncoderModel._trace_pytorch(model, features, model_name) elif framework == "onnx": self._framework = "onnx" - model_loc = CrossEncoderModel._trace_onnx(model, features, mname) + model_loc = CrossEncoderModel._trace_onnx(model, features, model_name) else: raise Exception( f"Unrecognized framework {framework}. Accepted values are `torch_script`, `onnx`" ) # save tokenizer file - tk_path = Path(f"/tmp/{mname}-tokenizer") + tk_path = Path(f"/tmp/{model_name}-tokenizer") tk.save_pretrained(tk_path) if tk.model_max_length is None: model_config = AutoConfig.from_pretrained(self._hf_model_id) @@ -153,7 +153,7 @@ def zip_model( return self._model_zip @staticmethod - def _trace_pytorch(model, features, mname) -> Path: + def _trace_pytorch(model, features, model_name) -> Path: """ Compiles the model to TORCHSCRIPT format. 
@@ -170,17 +170,17 @@ def _trace_pytorch(model, features, mname) -> Path: }, strict=False, ) - save_loc = Path(f"/tmp/{mname}.pt") - torch.jit.save(compiled, f"/tmp/{mname}.pt") + save_loc = Path(f"/tmp/{model_name}.pt") + torch.jit.save(compiled, f"/tmp/{model_name}.pt") return save_loc @staticmethod - def _trace_onnx(model, features, mname): + def _trace_onnx(model, features, model_name): """ Compiles the model to ONNX format. """ # export to onnx - save_loc = Path(f"/tmp/{mname}.onnx") + save_loc = Path(f"/tmp/{model_name}.onnx") torch.onnx.export( model=model, args=(
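
Taken together, the series gives the following end-to-end flow. This is a minimal usage sketch against the API as of PATCH 14/14; the cluster host, credentials, and model group id are placeholders to adapt, and the model id matches the cross-encoder exercised in the unit tests.

from opensearchpy import OpenSearch

from opensearch_py_ml.ml_models import CrossEncoderModel

# Placeholder client setup -- point this at your own cluster.
client = OpenSearch(
    hosts=[{"host": "localhost", "port": 9200}],
    http_auth=("admin", "admin"),
    verify_certs=False,
)

# Trace the Hugging Face model, generate the ml-commons config, and register it.
model = CrossEncoderModel("cross-encoder/ms-marco-TinyBERT-L-2-v2", overwrite=True)
zip_path = model.zip_model(framework="torch_script")  # or framework="onnx"
config_path = model.make_model_config_json()
model.upload(client, framework="torch_script", model_group_id="<model-group-id>")

Because upload() re-zips the model and regenerates the config whenever they are missing or were built for a different framework, the two intermediate calls can be dropped; they are spelled out here to make the produced artifacts (model.zip, config.json) explicit.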