diff --git a/utils/model_uploader/model_autotracing.py b/utils/model_uploader/model_autotracing.py index 678a3600..374955fd 100644 --- a/utils/model_uploader/model_autotracing.py +++ b/utils/model_uploader/model_autotracing.py @@ -6,15 +6,17 @@ # GitHub history for details. import argparse +import warnings + from opensearchpy import OpenSearch + from opensearch_py_ml.ml_commons import MLCommonClient from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel from tests import OPENSEARCH_TEST_CLIENT -import warnings -TORCH_SCRIPT_FORMAT = 'TORCH_SCRIPT' -ONNX_FORMAT = 'ONNX' +TORCH_SCRIPT_FORMAT = "TORCH_SCRIPT" +ONNX_FORMAT = "ONNX" TORCHSCRIPT_FOLDER_PATH = "sentence-transformers-torchscript/" ONXX_FOLDER_PATH = "sentence-transformers-onxx/" MODEL_CONFIG_FILE_NAME = "ml-commons_model_config.json" @@ -23,6 +25,7 @@ ATOL_TEST = 1e-05 ML_BASE_URI = "/_plugins/_ml" + def main(args): print("ARGUMENTS") print( @@ -33,16 +36,17 @@ def main(args): args.pooling_mode, ) ml_client = MLCommonClient(OPENSEARCH_TEST_CLIENT) - if __name__ == "__main__": - warnings.filterwarnings('ignore', category=DeprecationWarning) - warnings.filterwarnings('ignore', category=FutureWarning) + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", message="Unverified HTTPS request") warnings.filterwarnings("ignore", message="TracerWarning: torch.tensor") - warnings.filterwarnings("ignore", message="using SSL with verify_certs=False is insecure.") - + warnings.filterwarnings( + "ignore", message="using SSL with verify_certs=False is insecure." + ) + parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "model_id",