Skip to content

Commit

Permalink
Set up cluster for auto-tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
thanawan-atc committed Jul 8, 2023
1 parent fb860af commit 77f7a97
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 31 deletions.
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
@nox.session(reuse_venv=True)
def format(session):
session.install("black", "isort", "flynt")
session.run("python", "utils/license-headers.py", "fix", *SOURCE_FILES)
session.run("python", "utils/lint/license-headers.py", "fix", *SOURCE_FILES)
session.run("flynt", *SOURCE_FILES)
session.run("black", "--target-version=py38", *SOURCE_FILES)
session.run("isort", "--profile=black", *SOURCE_FILES)
Expand All @@ -73,7 +73,7 @@ def lint(session):
# Install numpy to use its mypy plugin
# https://numpy.org/devdocs/reference/typing.html#mypy-plugin
session.install("black", "flake8", "mypy", "isort", "numpy")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("python", "utils/lint/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
Expand Down
File renamed without changes.
29 changes: 0 additions & 29 deletions utils/model_autotracing.py

This file was deleted.

52 changes: 52 additions & 0 deletions utils/model_uploader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import opensearchpy
from opensearchpy import OpenSearch

from opensearch_py_ml.common import os_version

OPENSEARCH_ADMIN_USER, OPENSEARCH_ADMIN_PASSWORD = "admin", "admin"


def _new_test_client(host):
    """Return an OpenSearch client for ``host`` using the admin credentials.

    Certificate verification is turned off for this client.
    """
    return OpenSearch(
        hosts=[host],
        http_auth=(OPENSEARCH_ADMIN_USER, OPENSEARCH_ADMIN_PASSWORD),
        verify_certs=False,
    )


# Host selection:
#   * in the GitHub automated workflow the host url is https://instance:9200
#   * in development the host url is usually https://localhost:9200
# Rather than editing this config every time, try the workflow host first
# and fall back to localhost when the connection fails.
OPENSEARCH_HOST = "https://instance:9200"
OPENSEARCH_TEST_CLIENT = _new_test_client(OPENSEARCH_HOST)
try:
    OS_VERSION = os_version(OPENSEARCH_TEST_CLIENT)
except opensearchpy.exceptions.ConnectionError:
    # Workflow host unreachable: rebuild the client against the local host.
    # A second connection failure is not caught and propagates to the importer.
    OPENSEARCH_HOST = "https://localhost:9200"
    OPENSEARCH_TEST_CLIENT = _new_test_client(OPENSEARCH_HOST)
    OS_VERSION = os_version(OPENSEARCH_TEST_CLIENT)
81 changes: 81 additions & 0 deletions utils/model_uploader/model_autotracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import argparse
from opensearchpy import OpenSearch
from opensearch_py_ml.ml_commons import MLCommonClient
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel
from tests import OPENSEARCH_TEST_CLIENT
import warnings

# Identifiers for the supported model formats.
TORCH_SCRIPT_FORMAT = "TORCH_SCRIPT"
ONNX_FORMAT = "ONNX"

# Per-format folder paths (relative, with trailing slash).
TORCHSCRIPT_FOLDER_PATH = "sentence-transformers-torchscript/"
ONNX_FOLDER_PATH = "sentence-transformers-onnx/"
# Backward-compatible alias for the original misspelled name ("ONXX").
# Prefer ONNX_FOLDER_PATH in new code.
ONXX_FOLDER_PATH = ONNX_FOLDER_PATH

MODEL_CONFIG_FILE_NAME = "ml-commons_model_config.json"
TEST_SENTENCES = ["First test sentence", "Second test sentence"]

# Presumably relative/absolute tolerances for numeric comparisons in the
# tracing checks -- TODO confirm once the verification code lands.
RTOL_TEST = 1e-03
ATOL_TEST = 1e-05

ML_BASE_URI = "/_plugins/_ml"

def main(args):
    """Echo the parsed command-line arguments and create the ML Commons client.

    :param args: parsed arguments; must expose ``model_id``,
        ``model_version``, ``tracing_format``, ``embedding_dimension`` and
        ``pooling_mode`` attributes
    :return: None
    """
    print("ARGUMENTS")
    reported_values = (
        args.model_id,
        args.model_version,
        args.tracing_format,
        args.embedding_dimension,
        args.pooling_mode,
    )
    # Printed on a single line, space-separated, in the order listed above.
    print(*reported_values)

    # The client is only constructed here; nothing in this function uses it yet.
    ml_client = MLCommonClient(OPENSEARCH_TEST_CLIENT)



def build_arg_parser():
    """Create the command-line parser for this script.

    Positional arguments: ``model_id``, ``model_version`` and
    ``tracing_format`` (one of ``BOTH``, ``TORCH_SCRIPT``, ``ONNX``).
    Optional arguments: ``-ed/--embedding_dimension`` (int) and
    ``-pm/--pooling_mode`` (one of ``CLS``, ``MEAN``, ``MAX``,
    ``MEAN_SQRT_LEN``). Both optional arguments are ``None`` when omitted
    or when the flag is given without a value.

    :return: the configured parser
    :rtype: argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "model_id",
        type=str,
        help="Model ID for auto-tracing and uploading (e.g. sentence-transformers/msmarco-distilbert-base-tas-b)",
    )
    parser.add_argument(
        "model_version", type=str, help="Model version number (e.g. 1.0.1)"
    )
    parser.add_argument(
        "tracing_format",
        choices=["BOTH", "TORCH_SCRIPT", "ONNX"],
        help="Model format for auto-tracing",
    )
    parser.add_argument(
        "-ed",
        "--embedding_dimension",
        type=int,
        nargs="?",
        default=None,
        const=None,
        help="Embedding dimension of the model to use if it does not exist in original config.json",
    )
    parser.add_argument(
        "-pm",
        "--pooling_mode",
        type=str,
        nargs="?",
        default=None,
        const=None,
        choices=["CLS", "MEAN", "MAX", "MEAN_SQRT_LEN"],
        help="Pooling mode if it does not exist in original config.json",
    )
    return parser


if __name__ == "__main__":
    # Silence warnings that would otherwise clutter the script output.
    # NOTE(review): these filters are installed only when the file runs as a
    # script, and only after the module-level imports have executed, so any
    # warning raised at import time is not affected by them.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", message="Unverified HTTPS request")
    # NOTE(review): `message` is matched against the start of the warning
    # text, not its category name -- confirm this pattern actually matches
    # the tracer warning it is meant to hide.
    warnings.filterwarnings("ignore", message="TracerWarning: torch.tensor")
    warnings.filterwarnings(
        "ignore", message="using SSL with verify_certs=False is insecure."
    )

    args = build_arg_parser().parse_args()

    main(args)

0 comments on commit 77f7a97

Please sign in to comment.