From 77f7a97652273eec575362ce174e6c2adc199eda Mon Sep 17 00:00:00 2001 From: Thanawan Atchariyachanvanit Date: Sat, 8 Jul 2023 06:30:24 +0000 Subject: [PATCH] Set up cluster for auto-tracing --- noxfile.py | 4 +- utils/{ => lint}/license-headers.py | 0 utils/model_autotracing.py | 29 -------- utils/model_uploader/__init__.py | 52 +++++++++++++++ utils/model_uploader/model_autotracing.py | 81 +++++++++++++++++++++++ 5 files changed, 135 insertions(+), 31 deletions(-) rename utils/{ => lint}/license-headers.py (100%) delete mode 100644 utils/model_autotracing.py create mode 100644 utils/model_uploader/__init__.py create mode 100644 utils/model_uploader/model_autotracing.py diff --git a/noxfile.py b/noxfile.py index 448c3990..c9a835d0 100644 --- a/noxfile.py +++ b/noxfile.py @@ -61,7 +61,7 @@ @nox.session(reuse_venv=True) def format(session): session.install("black", "isort", "flynt") - session.run("python", "utils/license-headers.py", "fix", *SOURCE_FILES) + session.run("python", "utils/lint/license-headers.py", "fix", *SOURCE_FILES) session.run("flynt", *SOURCE_FILES) session.run("black", "--target-version=py38", *SOURCE_FILES) session.run("isort", "--profile=black", *SOURCE_FILES) @@ -73,7 +73,7 @@ def lint(session): # Install numpy to use its mypy plugin # https://numpy.org/devdocs/reference/typing.html#mypy-plugin session.install("black", "flake8", "mypy", "isort", "numpy") - session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) + session.run("python", "utils/lint/license-headers.py", "check", *SOURCE_FILES) session.run("black", "--check", "--target-version=py38", *SOURCE_FILES) session.run("isort", "--check", "--profile=black", *SOURCE_FILES) session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES) diff --git a/utils/license-headers.py b/utils/lint/license-headers.py similarity index 100% rename from utils/license-headers.py rename to utils/lint/license-headers.py diff --git a/utils/model_autotracing.py 
b/utils/model_autotracing.py deleted file mode 100644 index 811255b5..00000000 --- a/utils/model_autotracing.py +++ /dev/null @@ -1,29 +0,0 @@ -import argparse - -def main(args): - print('ARGUMENTS') - print(args.model_id, args.model_version, args.tracing_format, args.embedding_dimension, args.pooling_mode - ) - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('model_id', type=str, - help="Model ID for auto-tracing and uploading (e.g. sentence-transformers/msmarco-distilbert-base-tas-b)") - parser.add_argument('model_version', type=str, - help="Model version number (e.g. 1.0.1)") - parser.add_argument('tracing_format', choices=['BOTH', 'TORCH_SCRIPT', 'ONNX'], - help="Model format for auto-tracing") - parser.add_argument('-ed', '--embedding_dimension', - type=int, nargs='?', default=None, const=None, - help="Embedding dimension of the model to use if it does not exist in original config.json") - parser.add_argument('-pm', '--pooling_mode', - type=str, nargs='?', default=None, const=None, - choices=['CLS', 'MEAN', 'MAX', 'MEAN_SQRT_LEN'], - help="Pooling mode if it does not exist in original config.json") - args = parser.parse_args() - - main(args) - - - - diff --git a/utils/model_uploader/__init__.py b/utils/model_uploader/__init__.py new file mode 100644 index 00000000..3d86e7c8 --- /dev/null +++ b/utils/model_uploader/__init__.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import opensearchpy +from opensearchpy import OpenSearch + +from opensearch_py_ml.common import os_version + +OPENSEARCH_HOST = "https://instance:9200" +OPENSEARCH_ADMIN_USER, OPENSEARCH_ADMIN_PASSWORD = "admin", "admin" + +# Define client to use in workflow +OPENSEARCH_TEST_CLIENT = OpenSearch( + hosts=[OPENSEARCH_HOST], + http_auth=(OPENSEARCH_ADMIN_USER, OPENSEARCH_ADMIN_PASSWORD), + verify_certs=False, +) +# in github automated workflow, host url is: https://instance:9200 +# in development, usually host url is: https://localhost:9200 +# it's hard to remember changing the host url. 
So we use a try/except so that we don't have to keep changing this config +try: + OS_VERSION = os_version(OPENSEARCH_TEST_CLIENT) +except opensearchpy.exceptions.ConnectionError: + OPENSEARCH_HOST = "https://localhost:9200" + # Define client to use in tests + OPENSEARCH_TEST_CLIENT = OpenSearch( + hosts=[OPENSEARCH_HOST], + http_auth=(OPENSEARCH_ADMIN_USER, OPENSEARCH_ADMIN_PASSWORD), + verify_certs=False, + ) + OS_VERSION = os_version(OPENSEARCH_TEST_CLIENT) diff --git a/utils/model_uploader/model_autotracing.py b/utils/model_uploader/model_autotracing.py new file mode 100644 index 00000000..678a3600 --- /dev/null +++ b/utils/model_uploader/model_autotracing.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +import argparse +from opensearchpy import OpenSearch +from opensearch_py_ml.ml_commons import MLCommonClient +from opensearch_py_ml.ml_commons.model_uploader import ModelUploader +from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel +from tests import OPENSEARCH_TEST_CLIENT +import warnings + +TORCH_SCRIPT_FORMAT = 'TORCH_SCRIPT' +ONNX_FORMAT = 'ONNX' +TORCHSCRIPT_FOLDER_PATH = "sentence-transformers-torchscript/" +ONXX_FOLDER_PATH = "sentence-transformers-onxx/" +MODEL_CONFIG_FILE_NAME = "ml-commons_model_config.json" +TEST_SENTENCES = ["First test sentence", "Second test sentence"] +RTOL_TEST = 1e-03 +ATOL_TEST = 1e-05 +ML_BASE_URI = "/_plugins/_ml" + +def main(args): + print("ARGUMENTS") + print( + args.model_id, + args.model_version, + args.tracing_format, + args.embedding_dimension, + args.pooling_mode, + ) + ml_client = MLCommonClient(OPENSEARCH_TEST_CLIENT) + + + +if __name__ == "__main__": + warnings.filterwarnings('ignore', category=DeprecationWarning) + 
warnings.filterwarnings('ignore', category=FutureWarning) + warnings.filterwarnings("ignore", message="Unverified HTTPS request") + warnings.filterwarnings("ignore", message="TracerWarning: torch.tensor") + warnings.filterwarnings("ignore", message="using SSL with verify_certs=False is insecure.") + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "model_id", + type=str, + help="Model ID for auto-tracing and uploading (e.g. sentence-transformers/msmarco-distilbert-base-tas-b)", + ) + parser.add_argument( + "model_version", type=str, help="Model version number (e.g. 1.0.1)" + ) + parser.add_argument( + "tracing_format", + choices=["BOTH", "TORCH_SCRIPT", "ONNX"], + help="Model format for auto-tracing", + ) + parser.add_argument( + "-ed", + "--embedding_dimension", + type=int, + nargs="?", + default=None, + const=None, + help="Embedding dimension of the model to use if it does not exist in original config.json", + ) + parser.add_argument( + "-pm", + "--pooling_mode", + type=str, + nargs="?", + default=None, + const=None, + choices=["CLS", "MEAN", "MAX", "MEAN_SQRT_LEN"], + help="Pooling mode if it does not exist in original config.json", + ) + args = parser.parse_args() + + main(args)