diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ff975b24..c7ead57c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,5 @@ --- -exclude: ^(.*egg.info.*|.*/parameters.py$|.*\.py_template|.*/experimental/.*|.*/fixtures/.*|docs/source/_themes/.*) -minimum_pre_commit_version: 3.4.0 +exclude: ^(.*egg.info.*|.*/parameters.py$|.*\.py_template|.*/experimental/.*|.*/fixtures/.*|docs/source/_themes/.*|.*\.patch) repos: - repo: https://github.com/asottile/pyupgrade rev: v2.31.1 @@ -65,7 +64,7 @@ repos: - id: markdownlint-fix language_version: 16.20.2 - repo: https://github.com/keith/pre-commit-buildifier - rev: 6.0.0 + rev: 7.3.1 hooks: - id: buildifier args: @@ -84,7 +83,7 @@ repos: exclude_types: - image - repo: https://github.com/lyz-code/yamlfix - rev: 1.13.0 + rev: 1.16.1 hooks: - id: yamlfix args: diff --git a/CHANGELOG.md b/CHANGELOG.md index 6afd884c..269e9e4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,25 @@ # Release History -## 1.6.2 (TBD) +## 1.6.3 + +- Model Registry (PrPr) has been removed. + +### Bug Fixes + +- Registry: Fix a bug that when package whose name does not follow PEP-508 is provided when logging the model, + an unexpected normalization is happening. +- Registry: Fix `not a valid remote uri` error when logging mlflow models. +- Registry: Fix a bug that `ModelVersion.run` is called in a nested way. +- Registry: Fix an issue that leads to `log_model` failure when local package version contains parts other than + base version. + +### New Features + +- Data: Improve `DataConnector.to_pandas()` performance when loading from Snowpark DataFrames. +- Model Registry: Allow users to set a model task while using `log_model`. +- Feature Store: FeatureView supports ON_CREATE or ON_SCHEDULE initialize mode. 
+ +## 1.6.2 (2024-09-04) ### Bug Fixes @@ -18,8 +37,6 @@ - Data: Add native batching support via `batch_size` and `drop_last_batch` arguments to `DataConnector.to_torch_dataset()` - Feature Store: update_feature_view() supports taking feature view object as argument. -### Behavior Changes - ## 1.6.1 (2024-08-12) ### Bug Fixes @@ -42,8 +59,6 @@ - Registry: Option to `enable_explainability` set to True by default for XGBoost, LightGBM and CatBoost as PuPr feature. - Registry: Option to `enable_explainability` when registering SHAP supported sklearn models. -### Behavior Changes - ## 1.6.0 (2024-07-29) ### Bug Fixes diff --git a/bazel/requirements/templates/meta.tpl.yaml b/bazel/requirements/templates/meta.tpl.yaml index d7f3fc96..36556103 100644 --- a/bazel/requirements/templates/meta.tpl.yaml +++ b/bazel/requirements/templates/meta.tpl.yaml @@ -11,11 +11,9 @@ build: requirements: build: - python - - bazel >=6.0.0 + - bazel==6.3.0 run: - python>=3.8,<3.12 - run_constrained: - - openjpeg !=2.4.0=*_1 # [win] about: home: https://github.com/snowflakedb/snowflake-ml-python diff --git a/ci/build_and_run_tests.sh b/ci/build_and_run_tests.sh index 29b80616..2ac9d1ec 100755 --- a/ci/build_and_run_tests.sh +++ b/ci/build_and_run_tests.sh @@ -1,7 +1,7 @@ #!/bin/bash # Usage -# build_and_run_tests.sh [-b ] [--env pip|conda] [--mode merge_gate|continuous_run] [--with-snowpark] [--report ] +# build_and_run_tests.sh [-b ] [--env pip|conda] [--mode merge_gate|continuous_run] [--with-snowpark] [--with-spcs-image] [--report ] # # Args # workspace: path to the workspace, SnowML code should be in snowml directory. @@ -14,6 +14,7 @@ # continuous_run (default): run all tests. (For nightly run. Alias: release) # quarantined: run all quarantined tests. # with-snowpark: Build and test with snowpark in snowpark-python directory in the workspace. +# with-spcs-image: Build and test with spcs-image in spcs-image directory in the workspace. 
# snowflake-env: The environment of the snowflake, use to determine the test quarantine list # report: Path to xml test report # @@ -29,7 +30,7 @@ PROG=$0 help() { local exit_code=$1 - echo "Usage: ${PROG} [-b ] [--env pip|conda] [--mode merge_gate|continuous_run|quarantined] [--with-snowpark] [--snowflake-env ] [--report ]" + echo "Usage: ${PROG} [-b ] [--env pip|conda] [--mode merge_gate|continuous_run|quarantined] [--with-snowpark] [--with-spcs-image] [--snowflake-env ] [--report ]" exit "${exit_code}" } @@ -37,6 +38,7 @@ WORKSPACE=$1 && shift || help 1 BAZEL="bazel" ENV="pip" WITH_SNOWPARK=false +WITH_SPCS_IMAGE=false MODE="continuous_run" PYTHON_VERSION=3.8 PYTHON_ENABLE_SCRIPT="bin/activate" @@ -86,6 +88,9 @@ while (($#)); do shift PYTHON_VERSION=$1 ;; + --with-spcs-image) + WITH_SPCS_IMAGE=true + ;; -h | --help) help 0 ;; @@ -260,11 +265,18 @@ else # Build SnowML pushd ${SNOWML_DIR} # Build conda package - conda build --prefix-length 50 --python=${PYTHON_VERSION} --croot "${WORKSPACE}/conda-bld" ci/conda_recipe + conda build -c conda-forge --override-channels --prefix-length 50 --python=${PYTHON_VERSION} --croot "${WORKSPACE}/conda-bld" ci/conda_recipe conda build purge popd fi +if [[ "${WITH_SPCS_IMAGE}" = true ]]; then + pushd ${SNOWML_DIR} + # Build SPCS Image + source model_container_services_deployment/ci/build_and_push_images.sh + popd +fi + # Start testing pushd "${TEMP_TEST_DIR}" @@ -281,6 +293,11 @@ if [[ -n "${JUNIT_REPORT_PATH}" ]]; then fi if [ "${ENV}" = "pip" ]; then + if [ "${WITH_SPCS_IMAGE}" = true ]; then + COMMON_PYTEST_FLAG+=(-m "spcs_deployment_image and not pip_incompatible") + else + COMMON_PYTEST_FLAG+=(-m "not pip_incompatible") + fi # Copy wheel package cp "${WORKSPACE}/snowflake_ml_python-${VERSION}-py3-none-any.whl" "${TEMP_TEST_DIR}" @@ -302,10 +319,15 @@ if [ "${ENV}" = "pip" ]; then # Run the tests set +e - TEST_SRCDIR="${TEMP_TEST_DIR}" python -m pytest "${COMMON_PYTEST_FLAG[@]}" -m "not pip_incompatible" tests/integ/ + 
TEST_SRCDIR="${TEMP_TEST_DIR}" python -m pytest "${COMMON_PYTEST_FLAG[@]}" tests/integ/ TEST_RETCODE=$? set -e else + if [ "${WITH_SPCS_IMAGE}" = true ]; then + COMMON_PYTEST_FLAG+=(-m "spcs_deployment_image and not conda_incompatible") + else + COMMON_PYTEST_FLAG+=(-m "not conda_incompatible") + fi # Create local conda channel conda index "${WORKSPACE}/conda-bld" @@ -319,7 +341,7 @@ else # Run integration tests set +e - TEST_SRCDIR="${TEMP_TEST_DIR}" conda run -p testenv --no-capture-output python -m pytest "${COMMON_PYTEST_FLAG[@]}" -m "not conda_incompatible" tests/integ/ + TEST_SRCDIR="${TEMP_TEST_DIR}" conda run -p testenv --no-capture-output python -m pytest "${COMMON_PYTEST_FLAG[@]}" tests/integ/ TEST_RETCODE=$? set -e diff --git a/ci/conda_recipe/README.md b/ci/conda_recipe/README.md index 54824f7d..df3843c1 100644 --- a/ci/conda_recipe/README.md +++ b/ci/conda_recipe/README.md @@ -6,7 +6,7 @@ Conda's guide on building a conda package from a wheel: To invoke conda build: ```sh -conda build --prefix-length=0 --python=[3.8|3.9|3.10|3.11] ci/conda_recipe +conda build -c conda-forge --override-channels --prefix-length=0 --python=[3.8|3.9|3.10|3.11] ci/conda_recipe ``` - `--prefix-length=0`: prevent the conda build environment from being created in diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index ff824bda..c03b1964 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,11 +17,11 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.6.2 + version: 1.6.3 requirements: build: - python - - bazel >=6.0.0 + - bazel==6.3.0 run: - absl-py>=0.15,<2 - aiohttp!=4.0.0a0, !=4.0.0a1 @@ -39,7 +39,7 @@ requirements: - requests - retrying>=1.3.3,<2 - s3fs>=2022.11,<2024 - - scikit-learn>=1.2.1,<1.4 + - scikit-learn>=1.2.1,<1.6 - scipy>=1.9,<2 - snowflake-connector-python>=3.5.0,<4 - snowflake-snowpark-python>=1.17.0,<2 @@ -54,11 +54,10 @@ requirements: - pytorch>=2.0.1,<2.3.0 - sentence-transformers>=2.2.2,<3 - 
sentencepiece>=0.1.95,<1 - - shap==0.42.1 + - shap>=0.42.0,<1 - tensorflow>=2.10,<3 - tokenizers>=0.10,<1 - torchdata>=0.4,<1 - transformers>=4.32.1,<5 - - openjpeg !=2.4.0=*_1 # [win] source: path: ../../ diff --git a/ci/targets/local_only.txt b/ci/targets/local_only.txt index ede3c16a..e69de29b 100644 --- a/ci/targets/local_only.txt +++ b/ci/targets/local_only.txt @@ -1,2 +0,0 @@ -//snowflake/ml/model/_deploy_client/image_builds/inference_server:gpu_test -//snowflake/ml/model/_deploy_client/image_builds/inference_server:main_vllm_test diff --git a/ci/targets/quarantine/prod3.txt b/ci/targets/quarantine/prod3.txt index a4b6155a..4288f47d 100644 --- a/ci/targets/quarantine/prod3.txt +++ b/ci/targets/quarantine/prod3.txt @@ -1,5 +1,7 @@ -//tests/integ/snowflake/ml/model:deployment_to_snowservice_integ_test -//tests/integ/snowflake/ml/registry:model_registry_snowservice_integ_test -//tests/integ/snowflake/ml/model:spcs_llm_model_integ_test +//snowflake/ml/model/_packager/model_handlers_test:mlflow_test //tests/integ/snowflake/ml/extra_tests:xgboost_external_memory_training_test -//tests/integ/snowflake/ml/registry:model_registry_snowservice_merge_gate_integ_test +//tests/integ/snowflake/ml/modeling/ensemble:isolation_forest_test +//tests/integ/snowflake/ml/modeling/linear_model:sgd_one_class_svm_test +//tests/integ/snowflake/ml/modeling/preprocessing:k_bins_discretizer_test +//tests/integ/snowflake/ml/registry/model:registry_mlflow_model_test +//tests/integ/snowflake/ml/registry/services/... 
diff --git a/ci/targets/slow.txt b/ci/targets/slow.txt index b51d2abf..e69de29b 100644 --- a/ci/targets/slow.txt +++ b/ci/targets/slow.txt @@ -1,3 +0,0 @@ -//tests/integ/snowflake/ml/model:deployment_to_snowservice_integ_test -//tests/integ/snowflake/ml/registry:model_registry_snowservice_integ_test -//tests/integ/snowflake/ml/model:spcs_llm_model_integ_test diff --git a/codegen/sklearn_wrapper_generator.py b/codegen/sklearn_wrapper_generator.py index 6528a2de..ecaee16a 100644 --- a/codegen/sklearn_wrapper_generator.py +++ b/codegen/sklearn_wrapper_generator.py @@ -1058,12 +1058,41 @@ def generate(self) -> "SklearnWrapperGenerator": ] self.test_estimator_input_args_list.append(f"dictionary={dictionary}") + if WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "Isomap"): + # Using higher n_neighbors for Isomap to balance accuracy and performance. + self.test_estimator_input_args_list.append("n_neighbors=30") + + if WrapperGeneratorFactory._is_class_of_type( + self.class_object[1], "KNeighborsClassifier" + ) or WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "RadiusNeighborsClassifier"): + # Use distance-based weighting to reduce ties and improve prediction accuracy. + self.test_estimator_input_args_list.append("weights='distance'") + + if WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "Nystroem"): + # Setting specific parameters for Nystroem to ensure a meaningful transformation. + # - `gamma`: Controls the shape of the RBF kernel. By setting gamma to a lower value + # like 0.1, you can help generate larger transformation values in the output, making the + # transformation less sensitive to small variations in the input data. This value also + # balances between underfitting and overfitting for most datasets. + # - `n_components`: Specifies a larger number of components for the approximation, + # which enhances the accuracy of the kernel approximation. 
This is especially useful + # in higher-dimensional data or when a more precise transformation is needed. + self.test_estimator_input_args_list.append("gamma=0.1") + self.test_estimator_input_args_list.append("n_components=200") + if WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "SelectKBest"): # Set the k of SelectKBest features transformer to half the number of columns in the dataset. self.test_estimator_input_args_list.append("k=int(len(cols)/2)") if "n_components" in self.original_init_signature.parameters.keys(): - if WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "SpectralBiclustering"): + if self.original_class_name == "KernelPCA": + # Explicitly set 'n_components' to the number of input columns (len(cols)) + # to ensure consistency between implementations. This is necessary because + # the default behavior might differ, with 'n_components' otherwise defaulting + # to the minimum of the number of features or samples, potentially leading to + # discrepancies between the implementations. + self.test_estimator_input_args_list.append("n_components=int(len(cols)/2)") + elif WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "SpectralBiclustering"): # For spectral bi clustering, set number of singular vectors to consider to number of input cols and # num best vector to select to half the number of input cols. 
self.test_estimator_input_args_list.append("n_components=len(cols)") diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index 67725e7e..8a432a92 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -389,6 +389,7 @@ class {transform.original_class_name}(BaseTransformer): """ self._infer_input_output_cols(dataset) super()._check_dataset_type(dataset) + model_trainer = ModelTrainerBuilder.build_fit_transform( estimator=self._sklearn_object, dataset=dataset, diff --git a/codegen/transformer_autogen_test_template.py_template b/codegen/transformer_autogen_test_template.py_template index 39620dac..48a4c646 100644 --- a/codegen/transformer_autogen_test_template.py_template +++ b/codegen/transformer_autogen_test_template.py_template @@ -182,7 +182,7 @@ class {transform.test_class_name}(TestCase): # TODO(snandamuri): HistGradientBoostingRegressor is returning different results in different envs. # Needs further debugging. if {transform._is_hist_gradient_boosting_regressor}: - num_diffs = (~np.isclose(actual_arr, sklearn_numpy_arr)).sum() + num_diffs = (~np.isclose(actual_arr, sklearn_numpy_arr, rtol=1.e-2, atol=1.e-2)).sum() num_example = sklearn_numpy_arr.shape[0] assert num_diffs < 0.1 * num_example elif (not {transform._is_deterministic}) or (not {transform._is_deterministic_cross_platform} and platform.system() == 'Windows'): diff --git a/requirements.yml b/requirements.yml index f2b2b820..b243d53a 100644 --- a/requirements.yml +++ b/requirements.yml @@ -61,8 +61,6 @@ # all extras requirements. All extras requirements will be labeled as `run_constrained` in conda's meta.yaml. # `tags`: Set tags to filter some of the requirements in specific cases. The current valid tags include: -# - `deployment_core`: Used by model deployment to indicate dependencies required to execute model deployment code -# on the server-side. 
(Obsolete) # - `model_packaging`: Used by model packaging and deployment to indicate the core requirements to save and load the # model. # - `snowml_inference_alternative`: Used by model packaging and deployment to indicate a subset of requirements to run @@ -75,7 +73,6 @@ version_requirements: '>=0.15,<2' tags: - build_essential - - deployment_core - snowml_inference_alternative # For fsspec[http] in conda - name_conda: aiohttp @@ -85,7 +82,6 @@ dev_version: 3.5.0 version_requirements: '>=3.5.0,<4' tags: - - deployment_core - snowml_inference_alternative - name: build dev_version: 0.10.0 @@ -106,7 +102,6 @@ dev_version: 2.2.1 version_requirements: '>=2.0.0' tags: - - deployment_core - model_packaging - name: cryptography dev_version: 39.0.1 @@ -147,7 +142,7 @@ - build_essential - name: shap dev_version: 0.42.1 - version_requirements: ==0.42.1 + version_requirements: '>=0.42.0,<1' requirements_extra_tags: - shap - name: mlflow @@ -165,21 +160,18 @@ dev_version: 1.23.5 version_requirements: '>=1.23,<2' tags: - - deployment_core - build_essential - snowml_inference_alternative - name: packaging dev_version: '23.0' version_requirements: '>=20.9,<24' tags: - - deployment_core - build_essential - snowml_inference_alternative - name: pandas dev_version: 1.5.3 version_requirements: '>=1.0.0,<3' tags: - - deployment_core - snowml_inference_alternative - name: protobuf dev_version: 3.20.3 @@ -212,7 +204,6 @@ dev_version: '6.0' version_requirements: '>=6.0,<7' tags: - - deployment_core - snowml_inference_alternative - name: retrying dev_version: 1.3.3 @@ -230,7 +221,7 @@ version_requirements: '>=2022.11,<2024' - name: scikit-learn dev_version: 1.3.0 - version_requirements: '>=1.2.1,<1.4' + version_requirements: '>=1.2.1,<1.6' tags: - build_essential - name: scipy @@ -254,7 +245,6 @@ dev_version: 1.17.0 version_requirements: '>=1.17.0,<2' tags: - - deployment_core - snowml_inference_alternative - name: sphinx dev_version: 5.0.2 @@ -304,7 +294,6 @@ dev_version: 4.6.3 
version_requirements: '>=4.1.0,<5' tags: - - deployment_core - snowml_inference_alternative - name: xgboost dev_version: 1.7.3 diff --git a/snowflake/cortex/BUILD.bazel b/snowflake/cortex/BUILD.bazel index 9a0c6c9d..dd6039f0 100644 --- a/snowflake/cortex/BUILD.bazel +++ b/snowflake/cortex/BUILD.bazel @@ -153,6 +153,44 @@ py_test( ], ) +py_library( + name = "embed_text_768", + srcs = ["_embed_text_768.py"], + deps = [ + ":util", + "//snowflake/ml/_internal:telemetry", + ], +) + +py_test( + name = "embed_text_768_test", + srcs = ["embed_text_768_test.py"], + deps = [ + ":embed_text_768", + ":test_util", + "//snowflake/ml/utils:connection_params", + ], +) + +py_library( + name = "embed_text_1024", + srcs = ["_embed_text_1024.py"], + deps = [ + ":util", + "//snowflake/ml/_internal:telemetry", + ], +) + +py_test( + name = "embed_text_1024_test", + srcs = ["embed_text_1024_test.py"], + deps = [ + ":embed_text_1024", + ":test_util", + "//snowflake/ml/utils:connection_params", + ], +) + py_library( name = "init", srcs = [ @@ -161,6 +199,8 @@ py_library( deps = [ ":classify_text", ":complete", + ":embed_text_1024", + ":embed_text_768", ":extract_answer", ":sentiment", ":summarize", diff --git a/snowflake/cortex/__init__.py b/snowflake/cortex/__init__.py index 1ee01368..92ab345f 100644 --- a/snowflake/cortex/__init__.py +++ b/snowflake/cortex/__init__.py @@ -1,5 +1,7 @@ from snowflake.cortex._classify_text import ClassifyText from snowflake.cortex._complete import Complete, CompleteOptions +from snowflake.cortex._embed_text_768 import EmbedText768 +from snowflake.cortex._embed_text_1024 import EmbedText1024 from snowflake.cortex._extract_answer import ExtractAnswer from snowflake.cortex._sentiment import Sentiment from snowflake.cortex._summarize import Summarize @@ -9,6 +11,8 @@ "ClassifyText", "Complete", "CompleteOptions", + "EmbedText768", + "EmbedText1024", "ExtractAnswer", "Sentiment", "Summarize", diff --git a/snowflake/cortex/_classify_text.py 
b/snowflake/cortex/_classify_text.py index 780d9254..1b1f1bf1 100644 --- a/snowflake/cortex/_classify_text.py +++ b/snowflake/cortex/_classify_text.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional, Union, cast from snowflake import snowpark from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function @@ -33,4 +33,4 @@ def _classify_text_impl( categories: Union[List[str], snowpark.Column], session: Optional[snowpark.Session] = None, ) -> Union[str, snowpark.Column]: - return call_sql_function(function, session, str_input, categories) + return cast(Union[str, snowpark.Column], call_sql_function(function, session, str_input, categories)) diff --git a/snowflake/cortex/_embed_text_1024.py b/snowflake/cortex/_embed_text_1024.py new file mode 100644 index 00000000..8743ed78 --- /dev/null +++ b/snowflake/cortex/_embed_text_1024.py @@ -0,0 +1,37 @@ +from typing import List, Optional, Union, cast + +from snowflake import snowpark +from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function +from snowflake.ml._internal import telemetry + + +@telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, +) +def EmbedText1024( + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[List[float], snowpark.Column]: + """TextEmbed calls into the LLM inference service to embed the text. + + Args: + model: A Column of strings representing the model to use for embedding. The value + of the strings must be within the SUPPORTED_MODELS list. + text: A Column of strings representing input text. + session: The snowpark session to use. Will be inferred by context if not specified. + + Returns: + A column of vectors containing embeddings. 
+ """ + + return _embed_text_1024_impl("snowflake.cortex.embed_text_1024", model, text, session=session) + + +def _embed_text_1024_impl( + function: str, + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[List[float], snowpark.Column]: + return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text)) diff --git a/snowflake/cortex/_embed_text_768.py b/snowflake/cortex/_embed_text_768.py new file mode 100644 index 00000000..e4adf9b6 --- /dev/null +++ b/snowflake/cortex/_embed_text_768.py @@ -0,0 +1,37 @@ +from typing import List, Optional, Union, cast + +from snowflake import snowpark +from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function +from snowflake.ml._internal import telemetry + + +@telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, +) +def EmbedText768( + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[List[float], snowpark.Column]: + """TextEmbed calls into the LLM inference service to embed the text. + + Args: + model: A Column of strings representing the model to use for embedding. The value + of the strings must be within the SUPPORTED_MODELS list. + text: A Column of strings representing input text. + session: The snowpark session to use. Will be inferred by context if not specified. + + Returns: + A column of vectors containing embeddings. 
+ """ + + return _embed_text_768_impl("snowflake.cortex.embed_text_768", model, text, session=session) + + +def _embed_text_768_impl( + function: str, + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[List[float], snowpark.Column]: + return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text)) diff --git a/snowflake/cortex/_extract_answer.py b/snowflake/cortex/_extract_answer.py index 8110d667..91d88239 100644 --- a/snowflake/cortex/_extract_answer.py +++ b/snowflake/cortex/_extract_answer.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, cast from snowflake import snowpark from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function @@ -33,4 +33,4 @@ def _extract_answer_impl( question: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None, ) -> Union[str, snowpark.Column]: - return call_sql_function(function, session, from_text, question) + return cast(Union[str, snowpark.Column], call_sql_function(function, session, from_text, question)) diff --git a/snowflake/cortex/_sentiment.py b/snowflake/cortex/_sentiment.py index dd723790..33a1c4b7 100644 --- a/snowflake/cortex/_sentiment.py +++ b/snowflake/cortex/_sentiment.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, cast from snowflake import snowpark from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function @@ -30,4 +30,4 @@ def _sentiment_impl( output = call_sql_function(function, session, text) if isinstance(output, snowpark.Column): return output - return float(output) + return float(cast(str, output)) diff --git a/snowflake/cortex/_summarize.py b/snowflake/cortex/_summarize.py index 7887477f..5883c931 100644 --- a/snowflake/cortex/_summarize.py +++ b/snowflake/cortex/_summarize.py @@ -1,4 +1,4 @@ -from typing import Optional, Union 
+from typing import Optional, Union, cast from snowflake import snowpark from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function @@ -30,4 +30,4 @@ def _summarize_impl( text: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None, ) -> Union[str, snowpark.Column]: - return call_sql_function(function, session, text) + return cast(Union[str, snowpark.Column], call_sql_function(function, session, text)) diff --git a/snowflake/cortex/_translate.py b/snowflake/cortex/_translate.py index c9dc41f7..9220e070 100644 --- a/snowflake/cortex/_translate.py +++ b/snowflake/cortex/_translate.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, cast from snowflake import snowpark from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function @@ -36,4 +36,4 @@ def _translate_impl( to_language: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None, ) -> Union[str, snowpark.Column]: - return call_sql_function(function, session, text, from_language, to_language) + return cast(Union[str, snowpark.Column], call_sql_function(function, session, text, from_language, to_language)) diff --git a/snowflake/cortex/_util.py b/snowflake/cortex/_util.py index f5d80aee..bdafe23f 100644 --- a/snowflake/cortex/_util.py +++ b/snowflake/cortex/_util.py @@ -24,7 +24,7 @@ def call_sql_function( function: str, session: Optional[snowpark.Session], *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]], -) -> Union[str, snowpark.Column]: +) -> Union[str, List[float], snowpark.Column]: handle_as_column = False for arg in args: @@ -32,9 +32,9 @@ def call_sql_function( handle_as_column = True if handle_as_column: - return cast(Union[str, snowpark.Column], _call_sql_function_column(function, *args)) + return cast(Union[str, List[float], snowpark.Column], _call_sql_function_column(function, *args)) return cast( - Union[str, snowpark.Column], + Union[str, 
List[float], snowpark.Column], _call_sql_function_immediate(function, session, *args), ) @@ -49,7 +49,7 @@ def _call_sql_function_immediate( function: str, session: Optional[snowpark.Session], *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]], -) -> str: +) -> Union[str, List[float]]: session = session or context.get_active_session() if session is None: raise SnowflakeAuthenticationException( diff --git a/snowflake/cortex/embed_text_1024_test.py b/snowflake/cortex/embed_text_1024_test.py new file mode 100644 index 00000000..2c9721b9 --- /dev/null +++ b/snowflake/cortex/embed_text_1024_test.py @@ -0,0 +1,61 @@ +from typing import List + +import _test_util +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.cortex import _embed_text_1024 +from snowflake.snowpark import functions, types + + +class EmbedTest1024Test(absltest.TestCase): + model = "snowflake-arctic-embed-m" + text = "|text|" + + @staticmethod + def embed_text_1024_for_test(model: str, text: str) -> List[float]: + return [0.0] * 1024 + + def setUp(self) -> None: + self._session = _test_util.create_test_session() + functions.udf( + self.embed_text_1024_for_test, + name="embed_text_1024", + session=self._session, + return_type=types.VectorType(float, 1024), + input_types=[types.StringType(), types.StringType()], + is_permanent=False, + ) + + def tearDown(self) -> None: + self._session.sql("drop function embed_text_1024(string,string)").collect() + self._session.close() + + def test_embed_text_1024_str(self) -> None: + res = _embed_text_1024._embed_text_1024_impl( + "embed_text_1024", + self.model, + self.text, + session=self._session, + ) + out = self.embed_text_1024_for_test(self.model, self.text) + self.assertEqual(out, res), f"Expected ({type(out)}) {out}, got ({type(res)}) {res}" + + def test_embed_text_1024_column(self) -> None: + df_in = self._session.create_dataframe([snowpark.Row(model=self.model, text=self.text)]) + df_out = df_in.select( + 
_embed_text_1024._embed_text_1024_impl( + "embed_text_1024", + functions.col("model"), + functions.col("text"), + session=self._session, + ) + ) + res = df_out.collect()[0][0] + out = self.embed_text_1024_for_test(self.model, self.text) + + self.assertEqual(out, res) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/cortex/embed_text_768_test.py b/snowflake/cortex/embed_text_768_test.py new file mode 100644 index 00000000..c399965d --- /dev/null +++ b/snowflake/cortex/embed_text_768_test.py @@ -0,0 +1,61 @@ +from typing import List + +import _test_util +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.cortex import _embed_text_768 +from snowflake.snowpark import functions, types + + +class EmbedTest768Test(absltest.TestCase): + model = "snowflake-arctic-embed-m" + text = "|text|" + + @staticmethod + def embed_text_768_for_test(model: str, text: str) -> List[float]: + return [0.0] * 768 + + def setUp(self) -> None: + self._session = _test_util.create_test_session() + functions.udf( + self.embed_text_768_for_test, + name="embed_text_768", + session=self._session, + return_type=types.VectorType(float, 768), + input_types=[types.StringType(), types.StringType()], + is_permanent=False, + ) + + def tearDown(self) -> None: + self._session.sql("drop function embed_text_768(string,string)").collect() + self._session.close() + + def test_embed_text_768_str(self) -> None: + res = _embed_text_768._embed_text_768_impl( + "embed_text_768", + self.model, + self.text, + session=self._session, + ) + out = self.embed_text_768_for_test(self.model, self.text) + self.assertEqual(out, res), f"Expected ({type(out)}) {out}, got ({type(res)}) {res}" + + def test_embed_text_768_column(self) -> None: + df_in = self._session.create_dataframe([snowpark.Row(model=self.model, text=self.text)]) + df_out = df_in.select( + _embed_text_768._embed_text_768_impl( + "embed_text_768", + functions.col("model"), + functions.col("text"), + 
session=self._session, + ) + ) + res = df_out.collect()[0][0] + out = self.embed_text_768_for_test(self.model, self.text) + + self.assertEqual(out, res) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/cortex/package_visibility_test.py b/snowflake/cortex/package_visibility_test.py index 1addaa09..98655da8 100644 --- a/snowflake/cortex/package_visibility_test.py +++ b/snowflake/cortex/package_visibility_test.py @@ -16,6 +16,12 @@ def test_complete_visible(self) -> None: def test_extract_answer_visible(self) -> None: self.assertTrue(callable(cortex.ExtractAnswer)) + def test_embed_text_768_visible(self) -> None: + self.assertTrue(callable(cortex.EmbedText768)) + + def test_embed_text_1024_visible(self) -> None: + self.assertTrue(callable(cortex.EmbedText1024)) + def test_sentiment_visible(self) -> None: self.assertTrue(callable(cortex.Sentiment)) diff --git a/snowflake/ml/_internal/container_services/image_registry/BUILD.bazel b/snowflake/ml/_internal/container_services/image_registry/BUILD.bazel deleted file mode 100644 index 2e875839..00000000 --- a/snowflake/ml/_internal/container_services/image_registry/BUILD.bazel +++ /dev/null @@ -1,56 +0,0 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "credential", - srcs = ["credential.py"], - deps = ["//snowflake/ml/_internal/utils:query_result_checker"], -) - -py_library( - name = "http_client", - srcs = ["http_client.py"], - deps = [ - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:retryable_http", - "//snowflake/ml/_internal/utils:session_token_manager", - ], -) - -py_test( - name = "http_client_test", - srcs = ["http_client_test.py"], - deps = [ - ":http_client", - "//snowflake/ml/test_utils:mock_session", - ], -) - -py_library( - name = "registry_client", - srcs = ["registry_client.py"], - deps = [ - ":http_client", - ":imagelib", - "//snowflake/ml/_internal/exceptions", - ], -) 
- -py_library( - name = "imagelib", - srcs = ["imagelib.py"], - deps = [ - ":http_client", - ], -) - -py_test( - name = "registry_client_test", - srcs = ["registry_client_test.py"], - deps = [ - ":registry_client", - "//snowflake/ml/test_utils:exception_utils", - "//snowflake/ml/test_utils:mock_session", - ], -) diff --git a/snowflake/ml/_internal/container_services/image_registry/credential.py b/snowflake/ml/_internal/container_services/image_registry/credential.py deleted file mode 100644 index 44875a37..00000000 --- a/snowflake/ml/_internal/container_services/image_registry/credential.py +++ /dev/null @@ -1,84 +0,0 @@ -# TODO[shchen]: Remove this file and use session_token_manager instead. -import base64 -import contextlib -import json -from typing import Generator, TypedDict - -from snowflake import snowpark -from snowflake.ml._internal.utils import query_result_checker - - -class SessionToken(TypedDict): - token: str - expires_in: str - - -@contextlib.contextmanager -def generate_image_registry_credential(session: snowpark.Session) -> Generator[str, None, None]: - """Construct basic auth credential that is specific to SPCS image registry. For image registry authentication, we - will use a session token obtained from the Snowpark session object. The token authentication mechanism is - automatically used when the username is set to "0sessiontoken" according to the registry implementation. - - As a workaround for SNOW-841699: Fail to authenticate to image registry with session token generated from - Snowpark. We need to temporarily set the json query format in order to process GS token response. Note that we - should set the format back only after registry authentication is complete, otherwise authentication will fail. - - Args: - session: snowpark session - - Yields: - base64-encoded credentials. 
- """ - - query_result = ( - query_result_checker.SqlResultValidator( - session, - query="SHOW PARAMETERS LIKE 'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT' IN SESSION", - ) - .has_dimensions(expected_rows=1) - .validate() - ) - prev_format = query_result[0].value - try: - session.sql("ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'json'").collect() - token = _get_session_token(session) - yield _get_base64_encoded_credentials(username="0sessiontoken", password=json.dumps(token)) - finally: - session.sql(f"ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = '{prev_format}'").collect() - - -def _get_session_token(session: snowpark.Session) -> SessionToken: - """ - This function retrieves the session token from a given Snowpark session object. - - Args: - session: snowpark session. - - Returns: - The session token string value. - """ - ctx = session._conn._conn - assert ctx._rest, "SnowflakeRestful is not set in session" - token_data = ctx._rest._token_request("ISSUE") - session_token = token_data["data"]["sessionToken"] - validity_in_seconds = token_data["data"]["validityInSecondsST"] - assert session_token, "session_token is not obtained successfully from the session object" - assert validity_in_seconds, "validityInSecondsST is not obtained successfully from the session object" - return {"token": session_token, "expires_in": validity_in_seconds} - - -def _get_base64_encoded_credentials(username: str, password: str) -> str: - """This function returns the base64 encoded username:password, which is compatible with registry, such as - SnowService image registry, that uses Docker credential helper. - - Args: - username: username for authentication. - password: password for authentication. - - Returns: - base64 encoded credential string. 
- - """ - credentials = f"{username}:{password}" - encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - return encoded_credentials diff --git a/snowflake/ml/_internal/container_services/image_registry/http_client.py b/snowflake/ml/_internal/container_services/image_registry/http_client.py deleted file mode 100644 index b52aaf3e..00000000 --- a/snowflake/ml/_internal/container_services/image_registry/http_client.py +++ /dev/null @@ -1,127 +0,0 @@ -import http -import json -import logging -import time -from typing import Any, Callable, Dict, FrozenSet, Optional -from urllib.parse import urlparse, urlunparse - -import requests - -from snowflake import snowpark -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.ml._internal.utils import retryable_http, session_token_manager - -logger = logging.getLogger(__name__) - -_MAX_RETRIES = 5 -_RETRY_DELAY_SECONDS = 1 -_RETRYABLE_HTTP_CODE = frozenset([http.HTTPStatus.UNAUTHORIZED]) - - -def retry_on_error( - http_call_function: Callable[..., requests.Response], - retryable_http_code: FrozenSet[http.HTTPStatus] = _RETRYABLE_HTTP_CODE, -) -> Callable[..., requests.Response]: - def wrapper(*args: Any, **kwargs: Any) -> Any: - retry_delay_seconds = _RETRY_DELAY_SECONDS - for attempt in range(1, _MAX_RETRIES + 1): - resp = http_call_function(*args, **kwargs) - if resp.status_code in retryable_http_code: - logger.warning( - f"Received {resp.status_code} status code. Retrying " f"(attempt {attempt}/{_MAX_RETRIES})..." 
- ) - time.sleep(retry_delay_seconds) - retry_delay_seconds *= 2 # Increase the retry delay exponentially - if attempt < _MAX_RETRIES: - assert isinstance(args[0], ImageRegistryHttpClient) - args[0]._fetch_bearer_token() - else: - return resp - - if attempt == _MAX_RETRIES: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWFLAKE_IMAGE_REGISTRY_ERROR, - original_exception=RuntimeError( - f"Failed to authenticate to registry after max retries {attempt} \n" - f"Status {resp.status_code}," - f"{str(resp.text)}" - ), - ) - - return wrapper - - -class ImageRegistryHttpClient: - """ - An image registry HTTP client utilizes a retryable HTTP client underneath. Its primary function is to facilitate - re-authentication with the image registry by obtaining a new GS token, which is then used to acquire a new bearer - token for subsequent HTTP request authentication. - - Ideally you should not use this client directly. Please use ImageRegistryClient for image registry-specific - operations. For general use of a retryable HTTP client, consider using the "retryable_http" module. 
- """ - - def __init__(self, *, repo_url: str, session: Optional[snowpark.Session] = None, no_cred: bool = False) -> None: - self._repo_url = repo_url - self._retryable_http = retryable_http.get_http_client() - self._no_cred = no_cred - - if not self._no_cred: - self._bearer_token = "" - assert session is not None - self._session_token_manager = session_token_manager.SessionTokenManager(session) - - def _with_bearer_token_header(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: - if self._no_cred: - return {} if not headers else headers.copy() - - if not self._bearer_token: - self._fetch_bearer_token() - assert self._bearer_token - new_headers = {} if not headers else headers.copy() - new_headers["Authorization"] = f"Bearer {self._bearer_token}" - return new_headers - - def _fetch_bearer_token(self) -> None: - resp = self._login() - self._bearer_token = str(json.loads(resp.text)["token"]) - - def _login(self) -> requests.Response: - """Log in to image registry. repo_url is expected to set when _login function is invoked. - - Returns: - Bearer token when login succeeded. 
- """ - parsed_url = urlparse(self._repo_url) - scheme = parsed_url.scheme - host = parsed_url.netloc - - login_path = "/login" # Construct the login path - url_tuple = (scheme, host, login_path, "", "", "") - login_url = urlunparse(url_tuple) - - base64_encoded_token = self._session_token_manager.get_base64_encoded_token() - return self._retryable_http.get(login_url, headers={"Authorization": f"Basic {base64_encoded_token}"}) - - @retry_on_error - def head(self, api_url: str, *, headers: Optional[Dict[str, str]] = None) -> requests.Response: - return self._retryable_http.head(api_url, headers=self._with_bearer_token_header(headers)) - - @retry_on_error - def get(self, api_url: str, *, headers: Optional[Dict[str, str]] = None) -> requests.Response: - return self._retryable_http.get(api_url, headers=self._with_bearer_token_header(headers)) - - @retry_on_error - def put(self, api_url: str, *, headers: Optional[Dict[str, str]] = None, **kwargs: Any) -> requests.Response: - return self._retryable_http.put(api_url, headers=self._with_bearer_token_header(headers), **kwargs) - - @retry_on_error - def post(self, api_url: str, *, headers: Optional[Dict[str, str]] = None, **kwargs: Any) -> requests.Response: - return self._retryable_http.post(api_url, headers=self._with_bearer_token_header(headers), **kwargs) - - @retry_on_error - def patch(self, api_url: str, *, headers: Optional[Dict[str, str]] = None, **kwargs: Any) -> requests.Response: - return self._retryable_http.patch(api_url, headers=self._with_bearer_token_header(headers), **kwargs) diff --git a/snowflake/ml/_internal/container_services/image_registry/http_client_test.py b/snowflake/ml/_internal/container_services/image_registry/http_client_test.py deleted file mode 100644 index b5a55c8e..00000000 --- a/snowflake/ml/_internal/container_services/image_registry/http_client_test.py +++ /dev/null @@ -1,139 +0,0 @@ -import json -from typing import cast - -import requests -from absl.testing import absltest, parameterized 
-from absl.testing.absltest import mock - -from snowflake.ml._internal.container_services.image_registry import ( - http_client as image_registry_http_client, -) -from snowflake.ml._internal.exceptions import exceptions as snowml_exceptions -from snowflake.ml.test_utils import mock_session -from snowflake.snowpark import session - - -class ImageRegistryHttpClientTest(parameterized.TestCase): - def setUp(self) -> None: - super().setUp() - self.m_session = mock_session.MockSession(conn=None, test_case=self) - self.m_repo_url = "https://org-account.registry.snowflakecomputing.com" - - def _get_mock_response(self, *, status_code: int, text: str) -> mock.Mock: - mock_response = mock.Mock(spec=requests.Response) - mock_response.status_code = status_code - mock_response.text = text - return mock_response - - @parameterized.parameters(("head",), ("get",), ("put",), ("post",), ("patch",)) # type: ignore[misc] - def test_http_method_succeed_in_one_request(self, http_method: str) -> None: - http_client = image_registry_http_client.ImageRegistryHttpClient( - session=cast(session.Session, self.m_session), repo_url=self.m_repo_url - ) - api_url = "https://org-account.registry.snowflakecomputing.com/v2/" - - dummy_token = "fake_token" - mock_token_response = self._get_mock_response(status_code=200, text=json.dumps({"token": dummy_token})) - mock_response = self._get_mock_response(status_code=200, text="succeed") - - with mock.patch.object(http_client, "_login", return_value=mock_token_response), mock.patch.object( - http_client._retryable_http, http_method, return_value=mock_response - ): - res = getattr(http_client, http_method)(api_url, headers={}) - self.assertEqual(res, mock_response) - getattr(http_client._retryable_http, http_method).assert_called_once_with( - api_url, headers={"Authorization": f"Bearer {dummy_token}"} - ) - - @parameterized.parameters(("head",), ("get",), ("put",), ("post",), ("patch",)) # type: ignore[misc] - def test_http_method_retry_on_401(self, 
http_method: str) -> None: - http_client = image_registry_http_client.ImageRegistryHttpClient( - session=cast(session.Session, self.m_session), repo_url=self.m_repo_url - ) - api_url = "https://org-account.registry.snowflakecomputing.com/v2/" - - dummy_token_1 = "fake_token_1" - dummy_token_2 = "fake_token_2" - mock_token_response_1 = self._get_mock_response(status_code=200, text=json.dumps({"token": dummy_token_1})) - mock_token_response_2 = self._get_mock_response(status_code=200, text=json.dumps({"token": dummy_token_2})) - - mock_response_1 = self._get_mock_response(status_code=401, text="401 FAILED") - mock_response_2 = self._get_mock_response(status_code=200, text="Succeed") - - mock_token_responses = [mock_token_response_1, mock_token_response_2] - mock_responses = [mock_response_1, mock_response_2] - - with mock.patch.object(http_client, "_login", side_effect=mock_token_responses), mock.patch.object( - http_client._retryable_http, http_method, side_effect=mock_responses - ): - res = getattr(http_client, http_method)(api_url, headers={}) - self.assertEqual(res, mock_response_2) - getattr(http_client._retryable_http, http_method).assert_has_calls( - [ - mock.call(api_url, headers={"Authorization": f"Bearer {dummy_token_1}"}), - mock.call(api_url, headers={"Authorization": f"Bearer {dummy_token_2}"}), - ], - any_order=False, - ) - - # Running only 1 method to reduce test time; in general other test case already guarantees that each method will - # have decorator @retry_on_401 set. 
- @parameterized.parameters(("head",)) # type: ignore[misc] - def test_http_method_fail_after_max_retries(self, http_method: str) -> None: - http_client = image_registry_http_client.ImageRegistryHttpClient( - session=cast(session.Session, self.m_session), repo_url=self.m_repo_url - ) - api_url = "https://org-account.registry.snowflakecomputing.com/v2/" - dummy_token = "fake_token" - mock_token_responses = [ - self._get_mock_response(status_code=200, text=json.dumps({"token": f"dummy_token{i}"})) - for i in range(image_registry_http_client._MAX_RETRIES) - ] - mock_responses = [ - self._get_mock_response(status_code=401, text="401 FAILED") - for _ in range(image_registry_http_client._MAX_RETRIES) - ] - - with self.assertRaises(snowml_exceptions.SnowflakeMLException) as context: - with mock.patch.object(http_client, "_login", side_effect=mock_token_responses), mock.patch.object( - http_client._retryable_http, http_method, side_effect=mock_responses - ): - getattr(http_client, http_method)(api_url, headers={}) - - getattr(http_client._retryable_http, http_method).assert_has_calls( - [ - mock.call(api_url, headers={"Authorization": f"Bearer {dummy_token}{i}"}) - for i in range(image_registry_http_client._MAX_RETRIES) - ], - any_order=False, - ) - - expected_error_message = "Failed to authenticate to registry after max retries" - self.assertIn(expected_error_message, str(context.exception)) - - @parameterized.parameters(("head",)) # type: ignore[misc] - def test_should_not_retry_on_non_401(self, http_method: str) -> None: - http_client = image_registry_http_client.ImageRegistryHttpClient( - session=cast(session.Session, self.m_session), repo_url=self.m_repo_url - ) - api_url = "https://org-account.registry.snowflakecomputing.com/v2/" - - dummy_token_1 = "fake_token_1" - mock_token_response = self._get_mock_response(status_code=200, text=json.dumps({"token": dummy_token_1})) - mock_response = self._get_mock_response(status_code=403, text="403 FAILED") - - with 
mock.patch.object(http_client, "_login", return_value=mock_token_response), mock.patch.object( - http_client._retryable_http, http_method, return_value=mock_response - ): - getattr(http_client, http_method)(api_url, headers={}) - # There should only be a single call for non-401 http code - getattr(http_client._retryable_http, http_method).assert_has_calls( - [ - mock.call(api_url, headers={"Authorization": f"Bearer {dummy_token_1}"}), - ], - any_order=False, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/_internal/container_services/image_registry/imagelib.py b/snowflake/ml/_internal/container_services/image_registry/imagelib.py deleted file mode 100644 index 2732ceee..00000000 --- a/snowflake/ml/_internal/container_services/image_registry/imagelib.py +++ /dev/null @@ -1,400 +0,0 @@ -""" -A minimal pure python library to copy images between two remote registries. - -This library only supports a limited set of features: -- Works only with docker and OCI manifests and manifest lists for multiarch images (most newer images) - - Supported OCI manifest type: application/vnd.oci.image.manifest.v1+json - - Supported Docker manifest type: application/vnd.docker.distribution.manifest.v2+json -- Supports only pulling a single architecture from a multiarch image. Does not support pulling all architectures. -- Supports only schemaVersion 2. -- Streams images from source to destination without any intermediate disk storage in chunks. -- Does not support copying in parallel. - -It's recommended to use this library to copy previously tested images using sha256 to avoid surprises -with respect to compatibility. 
-""" - -import dataclasses -import hashlib -import io -import json -import logging -from collections import namedtuple -from typing import Dict, List, Optional, Tuple - -import requests - -from snowflake.ml._internal.container_services.image_registry import ( - http_client as image_registry_http_client, -) - -# Common HTTP headers -_CONTENT_LENGTH_HEADER = "content-length" -_CONTENT_TYPE_HEADER = "content-type" -_CONTENT_RANGE_HEADER = "content-range" -_LOCATION_HEADER = "location" -_AUTHORIZATION_HEADER = "Authorization" -_ACCEPT_HEADER = "accept" - -_OCI_MANIFEST_LIST_TYPE = "application/vnd.oci.image.index.v1+json" -_DOCKER_MANIFEST_LIST_TYPE = "application/vnd.docker.distribution.manifest.list.v2+json" - -_OCI_MANIFEST_TYPE = "application/vnd.oci.image.manifest.v1+json" -_DOCKER_MANIFEST_TYPE = "application/vnd.docker.distribution.manifest.v2+json" - -ALL_SUPPORTED_MEDIA_TYPES = [ - _OCI_MANIFEST_LIST_TYPE, - _DOCKER_MANIFEST_LIST_TYPE, - _OCI_MANIFEST_TYPE, - _DOCKER_MANIFEST_TYPE, -] -_MANIFEST_SUPPORTED_KEYS = {"schemaVersion", "mediaType", "config", "layers"} - -# Architecture descriptor as a named tuple -_Arch = namedtuple("_Arch", ["arch_name", "os"]) - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class ImageDescriptor: - """ - Create an image descriptor. - - registry_name: the name of the registry like gcr.io - repository_name: the name of the repository like kaniko-project/executor - tag: the tag of the image like v1.6.0 - digest: the sha256 digest of the image like sha256:b8c0... - protocol: the protocol to use, defaults to https - - Only a tag or a digest must be specified, not both. 
- """ - - registry_name: str - repository_name: str - tag: Optional[str] = None - digest: Optional[str] = None - protocol: str = "https" - - def __baseurl(self) -> str: - return f"{self.protocol}://{self.registry_name}/v2/" - - def manifest_link(self) -> str: - return f"{self.__baseurl()}{self.repository_name}/manifests/{self.tag or self.digest}" - - def blob_link(self, digest: str) -> str: - return f"{self.__baseurl()}{self.repository_name}/blobs/{digest}" - - def blob_upload_link(self) -> str: - return f"{self.__baseurl()}{self.repository_name}/blobs/uploads/" - - def manifest_upload_link(self, tag: str) -> str: - return f"{self.__baseurl()}{self.repository_name}/manifests/{tag}" - - def __str__(self) -> str: - return f"{self.registry_name}/{self.repository_name}@{self.tag or self.digest}" - - -class Manifest: - def __init__(self, manifest_bytes: bytes, manifest_digest: str) -> None: - """Create a manifest object from the manifest JSON dict. - - Args: - manifest_bytes: manifest content in bytes. - manifest_digest: SHA256 digest. - """ - self.manifest_bytes = manifest_bytes - self.manifest = json.loads(manifest_bytes.decode("utf-8")) - self.__validate(self.manifest) - - self.manifest_digest = manifest_digest - self.media_type = self.manifest["mediaType"] - - def get_blob_digests(self) -> List[str]: - """ - Get the list of blob digests from the manifest including config and layers. - """ - blobs = [] - blobs.extend([x["digest"] for x in self.manifest["layers"]]) - blobs.append(self.manifest["config"]["digest"]) - - return blobs - - def __validate(self, manifest: Dict[str, str]) -> None: - """ - Validate the manifest. 
- """ - assert ( - manifest.keys() == _MANIFEST_SUPPORTED_KEYS - ), f"Manifest must contain all keys and no more {_MANIFEST_SUPPORTED_KEYS}" - assert int(manifest["schemaVersion"]) == 2, "Only manifest schemaVersion 2 is supported" - assert manifest["mediaType"] in [ - _OCI_MANIFEST_TYPE, - _DOCKER_MANIFEST_TYPE, - ], f'Unsupported mediaType {manifest["mediaType"]}' - - def __str__(self) -> str: - """ - Return the manifest as a string. - """ - return json.dumps(self.manifest, indent=4) - - -@dataclasses.dataclass -class BlobTransfer: - """ - Helper class to transfer a blob from one registry to another - in small chunks using in-memory buffering. - """ - - # Uploads in chunks of 1MB - chunk_size_bytes = 1024 * 1024 - - src_image: ImageDescriptor - dest_image: ImageDescriptor - manifest: Manifest - src_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient - dest_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient - - def upload_all_blobs(self) -> None: - blob_digests = self.manifest.get_blob_digests() - logger.debug(f"Found {len(blob_digests)} blobs for {self.src_image}") - - for blob_digest in blob_digests: - logger.debug(f"Transferring blob {blob_digest} from {self.src_image} to {self.dest_image}") - if self._should_upload(blob_digest): - self._transfer(blob_digest) - else: - logger.debug(f"Blob {blob_digest} already exists in {self.dest_image}") - - def _should_upload(self, blob_digest: str) -> bool: - """ - Check if the blob already exists in the destination registry. - """ - resp = self.dest_image_registry_http_client.head(self.dest_image.blob_link(blob_digest), headers={}) - return resp.status_code != 200 - - def _fetch_blob(self, blob_digest: str) -> Tuple[io.BytesIO, int]: - """ - Fetch a stream to the blob from the source registry. 
- """ - src_blob_link = self.src_image.blob_link(blob_digest) - headers = {_CONTENT_LENGTH_HEADER: "0"} - resp = self.src_image_registry_http_client.get(src_blob_link, headers=headers) - - assert resp.status_code == 200, f"Blob GET failed with code {resp.status_code}" - assert _CONTENT_LENGTH_HEADER in resp.headers, f"Blob does not contain {_CONTENT_LENGTH_HEADER}" - - return io.BytesIO(resp.content), int(resp.headers[_CONTENT_LENGTH_HEADER]) - - def _get_upload_url(self) -> str: - """ - Obtain the upload URL from the destination registry. - """ - response = self.dest_image_registry_http_client.post(self.dest_image.blob_upload_link()) - assert ( - response.status_code == 202 - ), f"Failed to get the upload URL to destination. Status {response.status_code}. {str(response.content)}" - return str(response.headers[_LOCATION_HEADER]) - - def _upload_blob(self, blob_digest: str, blob_data: io.BytesIO, content_length: int) -> None: - """ - Upload a blob to the destination registry. - """ - upload_url = self._get_upload_url() - headers = { - _CONTENT_TYPE_HEADER: "application/octet-stream", - } - - # Use chunked transfer - # This can be optimized to use a single PUT request for small blobs - next_loc = upload_url - start_byte = 0 - while start_byte < content_length: - chunk = blob_data.read(self.chunk_size_bytes) - chunk_length = len(chunk) - end_byte = start_byte + chunk_length - 1 - - headers[_CONTENT_RANGE_HEADER] = f"{start_byte}-{end_byte}" - headers[_CONTENT_LENGTH_HEADER] = str(chunk_length) - - resp = self.dest_image_registry_http_client.patch(next_loc, headers=headers, data=chunk) - assert resp.status_code == 202, f"Blob PATCH failed with code {resp.status_code}" - - next_loc = resp.headers[_LOCATION_HEADER] - start_byte += chunk_length - - # Finalize the upload - resp = self.dest_image_registry_http_client.put(f"{next_loc}&digest={blob_digest}") - assert resp.status_code == 201, f"Blob PUT failed with code {resp.status_code}" - - def _transfer(self, blob_digest: 
str) -> None: - """ - Transfer a blob from the source registry to the destination registry. - """ - blob_data, content_length = self._fetch_blob(blob_digest) - self._upload_blob(blob_digest, blob_data, content_length) - - -def get_bytes_with_sha_verification(resp: requests.Response, sha256_digest: str) -> Tuple[bytes, str]: - """Get the bytes of a response and verify the sha256 digest. - - Args: - resp: the response object - sha256_digest: the expected sha256 digest in format "sha256:b8c0..." - - Returns: - (res, sha256_digest) - - """ - digest = hashlib.sha256() - chunks = [] - for chunk in resp.iter_content(chunk_size=8192): - digest.update(chunk) - chunks.append(chunk) - - calculated_digest = digest.hexdigest() - assert not sha256_digest or sha256_digest.endswith(calculated_digest), "SHA256 digest does not match" - - content = b"".join(chunks) # Minimize allocations by joining chunks - return content, calculated_digest - - -def get_manifest( - image_descriptor: ImageDescriptor, arch: _Arch, retryable_http: image_registry_http_client.ImageRegistryHttpClient -) -> Manifest: - """Get the manifest of an image from the remote registry. - - Args: - image_descriptor: the image descriptor - arch: the architecture to filter for if it's a multi-arch image - retryable_http: a retryable http client. - - Returns: - Manifest object. 
- - """ - logger.debug(f"Getting manifest from {image_descriptor.manifest_link()}") - - headers = {_ACCEPT_HEADER: ",".join(ALL_SUPPORTED_MEDIA_TYPES)} - - response = retryable_http.get(image_descriptor.manifest_link(), headers=headers) - assert response.status_code == 200, f"Manifest GET failed with code {response.status_code}, {response.text}" - - assert image_descriptor.digest - manifest_bytes, manifest_digest = get_bytes_with_sha_verification(response, image_descriptor.digest) - manifest_json = json.loads(manifest_bytes.decode("utf-8")) - - # If this is a manifest list, find the manifest for the specified architecture - # and recurse till we find the real manifest - if manifest_json["mediaType"] in [ - _OCI_MANIFEST_LIST_TYPE, - _DOCKER_MANIFEST_LIST_TYPE, - ]: - logger.debug("Found a multiarch image. Following manifest reference.") - - assert "manifests" in manifest_json, "Manifest list does not contain manifests" - qualified_manifests = [ - x - for x in manifest_json["manifests"] - if x["platform"]["architecture"] == arch.arch_name and x["platform"]["os"] == arch.os - ] - assert ( - len(qualified_manifests) == 1 - ), "Manifest list does not contain exactly one qualified manifest for this arch" - - manifest_object = qualified_manifests[0] - manifest_digest = manifest_object["digest"] - - logger.debug(f"Found manifest reference for arch {arch}: {manifest_digest}") - - # Copy the image descriptor to fetch the arch-specific manifest - descriptor_copy = ImageDescriptor( - registry_name=image_descriptor.registry_name, - repository_name=image_descriptor.repository_name, - digest=manifest_digest, - tag=None, - ) - - # Supports only one level of manifest list nesting to avoid infinite recursion - return get_manifest(descriptor_copy, arch, retryable_http) - - return Manifest(manifest_bytes, manifest_digest) - - -def put_manifest( - image_descriptor: ImageDescriptor, - manifest: Manifest, - retryable_http: image_registry_http_client.ImageRegistryHttpClient, -) -> None: 
- """ - Upload the given manifest to the destination registry. - """ - assert image_descriptor.tag is not None, "Tag must be specified for manifest upload" - headers = {_CONTENT_TYPE_HEADER: manifest.media_type} - url = image_descriptor.manifest_upload_link(image_descriptor.tag) - logger.debug(f"Uploading manifest to {url}") - response = retryable_http.put(url, headers=headers, data=manifest.manifest_bytes) - assert response.status_code == 201, f"Manifest PUT failed with code {response.status_code}" - - -def copy_image( - src_image: ImageDescriptor, - dest_image: ImageDescriptor, - arch: _Arch, - src_retryable_http: image_registry_http_client.ImageRegistryHttpClient, - dest_retryable_http: image_registry_http_client.ImageRegistryHttpClient, -) -> None: - logger.debug(f"Pulling image manifest for {src_image}") - - # 1. Get the manifest - manifest = get_manifest(src_image, arch, src_retryable_http) - logger.debug(f"Manifest pulled for {src_image} with digest {manifest.manifest_digest}") - - # 2: Retrieve all blob digests from manifest; fetch blob based on blob digest, then upload blob. - blob_transfer = BlobTransfer( - src_image, - dest_image, - manifest, - src_image_registry_http_client=src_retryable_http, - dest_image_registry_http_client=dest_retryable_http, - ) - blob_transfer.upload_all_blobs() - - # 3. Upload the manifest - logger.debug(f"All blobs copied successfully. Copying manifest for {src_image} to {dest_image}") - put_manifest( - dest_image, - manifest, - dest_retryable_http, - ) - - logger.debug(f"Image {src_image} copied to {dest_image}") - - -def convert_to_image_descriptor( - image_name: str, - with_digest: bool = False, - with_tag: bool = False, -) -> ImageDescriptor: - """Convert a full image name to a ImageDescriptor object. - - Args: - image_name: name of image. - with_digest: boolean to specify whether a digest is included in the image name - with_tag: boolean to specify whether a tag is included in the image name. 
- - Returns: - An ImageDescriptor instance - """ - assert with_digest or with_tag, "image should contain either digest or tag" - sep = "@" if with_digest else ":" - parts = image_name.split("/") - assert len(parts[-1].split(sep)) == 2, f"Image {image_name} missing digest/tag" - tag_digest = parts[-1].split(sep)[1] - return ImageDescriptor( - registry_name=parts[0], - repository_name="/".join(parts[1:-1] + [parts[-1].split(sep)[0]]), - digest=tag_digest if with_digest else None, - tag=tag_digest if with_tag else None, - ) diff --git a/snowflake/ml/_internal/container_services/image_registry/registry_client.py b/snowflake/ml/_internal/container_services/image_registry/registry_client.py deleted file mode 100644 index f9acdf08..00000000 --- a/snowflake/ml/_internal/container_services/image_registry/registry_client.py +++ /dev/null @@ -1,212 +0,0 @@ -import http -import logging -from typing import Dict, Optional, cast -from urllib.parse import urlunparse - -from snowflake.ml._internal.container_services.image_registry import ( - http_client as image_registry_http_client, - imagelib, -) -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.snowpark import Session -from snowflake.snowpark._internal import utils as snowpark_utils - -_MANIFEST_V1_HEADER = "application/vnd.oci.image.manifest.v1+json" -_MANIFEST_V2_HEADER = "application/vnd.docker.distribution.manifest.v2+json" -_SUPPORTED_MANIFEST_HEADERS = [_MANIFEST_V1_HEADER, _MANIFEST_V2_HEADER] - -logger = logging.getLogger(__name__) - - -class ImageRegistryClient: - """ - A partial implementation of an SPCS image registry client. The client utilizes the ImageRegistryHttpClient under - the hood, incorporating a retry mechanism to handle intermittent 401 errors from the SPCS image registry. 
- """ - - def __init__(self, session: Session, full_dest_image_name: str) -> None: - """Initialization - - Args: - session: Snowpark session - full_dest_image_name: Based on dest image name, repo url can be inferred. - """ - self.image_registry_http_client = image_registry_http_client.ImageRegistryHttpClient( - session=session, - repo_url=self._convert_to_v2_manifests_url(full_image_name=full_dest_image_name), - ) - - def _convert_to_v2_manifests_url(self, full_image_name: str) -> str: - """Converts a full image name to a Docker Registry HTTP API V2 URL: - https://docs.docker.com/registry/spec/api/#existing-manifests - - org-account.registry-dev.snowflakecomputing.com/db/schema/repo/image_name:tag becomes - https://org-account.registry-dev.snowflakecomputing.com/v2/db/schema/repo/image_name/manifests/tag - - Args: - full_image_name: a string consists of image name and image tag. - - Returns: - Docker HTTP V2 URL for checking manifest existence. - """ - scheme = "https" - full_image_name_parts = full_image_name.split(":") - assert len(full_image_name_parts) == 2, "full image name should include both image name and tag" - - image_name = full_image_name_parts[0] - tag = full_image_name_parts[1] - image_name_parts = image_name.split("/") - domain = image_name_parts[0] - rest = "/".join(image_name_parts[1:]) - path = f"/v2/{rest}/manifests/{tag}" - url_tuple = (scheme, domain, path, "", "", "") - return urlunparse(url_tuple) - - def _get_accept_headers(self) -> Dict[str, str]: - # Depending on the built image, the media type of the image manifest might be either - # application/vnd.oci.image.manifest.v1+json or application/vnd.docker.distribution.manifest.v2+json - # Hence we need to check for both, otherwise it could result in false negative. - return {"Accept": ",".join(_SUPPORTED_MANIFEST_HEADERS)} - - def image_exists(self, full_image_name: str) -> bool: - """Check whether image already exists in the registry. 
- - Args: - full_image_name: Full image name consists of image name and image tag. - - Returns: - Boolean value. True when image already exists, else False. - - """ - # When running in SPROC, the Sproc session connection will not have _rest object associated, which makes it - # unable to fetch session token needed to authenticate to SPCS image registry. - if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] - return False - v2_api_url = self._convert_to_v2_manifests_url(full_image_name) - headers = self._get_accept_headers() - status = self.image_registry_http_client.head(v2_api_url, headers=headers).status_code - return status == http.HTTPStatus.OK - - def _get_manifest(self, full_image_name: str) -> Dict[str, str]: - """Retrieve image manifest file. Given Docker manifest comes with two versions, and for each version the - corresponding request header is required for a successful HTTP response. - - Args: - full_image_name: Full image name. - - Returns: - Full manifest content as a python dict. - - Raises: - SnowflakeMLException: when failed to retrieve manifest. - """ - - v2_api_url = self._convert_to_v2_manifests_url(full_image_name) - res = self.image_registry_http_client.get(v2_api_url, headers=self._get_accept_headers()) - if res.status_code == http.HTTPStatus.OK: - return cast(Dict[str, str], res.json()) - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWFLAKE_IMAGE_REGISTRY_ERROR, - original_exception=ValueError( - f"Failed to retrieve manifest for {full_image_name}. \n" - f"HTTP status code: {res.status_code}. Full response: {res.text}." - ), - ) - - def add_tag_to_remote_image(self, original_full_image_name: str, new_tag: str) -> None: - """Add a tag to an image in the registry. - - Args: - original_full_image_name: The full image name is required to fetch manifest. - new_tag: New tag to be added to the image. 
- - Returns: - None - - Raises: - SnowflakeMLException: when failed to push the newly updated manifest. - """ - - if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] - return None - - full_image_name_parts = original_full_image_name.split(":") - assert len(full_image_name_parts) == 2, "full image name should include both image name and tag" - new_full_image_name = ":".join([full_image_name_parts[0], new_tag]) - if self.image_exists(new_full_image_name): - # Early return if image with the associated tag already exists. - return - api_url = self._convert_to_v2_manifests_url(new_full_image_name) - manifest = self._get_manifest(full_image_name=original_full_image_name) - manifest_copy = manifest.copy() - manifest_copy["tag"] = new_tag - headers = self._get_accept_headers() - # Http Content-Type does not support multi-value, hence need to construct separate header. - put_header_v1 = { - **headers, - "Content-Type": _MANIFEST_V1_HEADER, - } - put_header_v2 = { - **headers, - "Content-Type": _MANIFEST_V2_HEADER, - } - - res1 = self.image_registry_http_client.put(api_url, headers=put_header_v1, json=manifest_copy) - if res1.status_code != http.HTTPStatus.CREATED: - res2 = self.image_registry_http_client.put(api_url, headers=put_header_v2, json=manifest_copy) - if res2.status_code != http.HTTPStatus.CREATED: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWFLAKE_IMAGE_REGISTRY_ERROR, - original_exception=ValueError( - f"Failed to push manifest for {new_full_image_name}. Two requests filed: \n" - f"HTTP status code 1: {res1.status_code}. Full response 1: {res1.text}. \n" - f"HTTP status code 2: {res2.status_code}. 
Full response 2: {res2.text}" - ), - ) - assert self.image_exists( - new_full_image_name - ), f"{new_full_image_name} should exist in image repo after a successful manifest update" - - def copy_image( - self, - source_image_with_digest: str, - dest_image_with_tag: str, - arch: Optional[imagelib._Arch] = None, - ) -> None: - """Util function to copy image across registry. Currently supported pulling from public docker image repo to - SPCS image registry. - - Args: - source_image_with_digest: source image with digest, e.g. gcr.io/kaniko-project/executor@sha256:b8c0977 - dest_image_with_tag: destination image with tag. - arch: architecture of source image. - - Returns: - None - """ - logger.info(f"Copying image from {source_image_with_digest} to {dest_image_with_tag}") - if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] - logger.warning(f"Running inside Sproc. Please ensure image already exists at {dest_image_with_tag}") - return None - - arch = arch or imagelib._Arch("amd64", "linux") - - src_image = imagelib.convert_to_image_descriptor(source_image_with_digest, with_digest=True) - dest_image = imagelib.convert_to_image_descriptor( - dest_image_with_tag, - with_tag=True, - ) - # TODO[shchen]: Remove the imagelib, instead rely on the copy image system function later. 
- imagelib.copy_image( - src_image=src_image, - dest_image=dest_image, - arch=arch, - src_retryable_http=image_registry_http_client.ImageRegistryHttpClient( - repo_url=src_image.registry_name, no_cred=True - ), - dest_retryable_http=self.image_registry_http_client, - ) - logger.info("Image copy completed successfully") diff --git a/snowflake/ml/_internal/container_services/image_registry/registry_client_test.py b/snowflake/ml/_internal/container_services/image_registry/registry_client_test.py deleted file mode 100644 index 7117e1eb..00000000 --- a/snowflake/ml/_internal/container_services/image_registry/registry_client_test.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import cast - -from absl.testing import absltest -from absl.testing.absltest import mock - -from snowflake.ml._internal.container_services.image_registry import ( - registry_client as image_registry_client, -) -from snowflake.ml.test_utils import mock_session -from snowflake.snowpark import session - - -class ImageRegistryClientTest(absltest.TestCase): - def setUp(self) -> None: - super().setUp() - self.m_session = mock_session.MockSession(conn=None, test_case=self) - - def test_convert_to_v2_head_manifests_url(self) -> None: - full_image_name = "org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest" - client = image_registry_client.ImageRegistryClient(cast(session.Session, self.m_session), full_image_name) - actual = client._convert_to_v2_manifests_url(full_image_name=full_image_name) - expected = "https://org-account.registry.snowflakecomputing.com/v2/db/schema/repo/image/manifests/latest" - self.assertEqual(actual, expected) - - def test_convert_to_v2_head_manifests_url_with_invalid_full_image_name(self) -> None: - image_name_without_tag = "org-account.registry.snowflakecomputing.com/db/schema/repo/image" - with self.assertRaises(AssertionError): - image_registry_client.ImageRegistryClient(cast(session.Session, self.m_session), image_name_without_tag) - - def 
test_image_exists(self) -> None: - full_image_name = "org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest" - client = image_registry_client.ImageRegistryClient(cast(session.Session, self.m_session), full_image_name) - url = client._convert_to_v2_manifests_url(full_image_name) - - with mock.patch.object(client.image_registry_http_client, "head", return_value=mock.MagicMock(status_code=200)): - self.assertEqual(client.image_exists(full_image_name=full_image_name), True) - - client.image_registry_http_client.head.assert_called_once_with( # type: ignore[attr-defined] - url, headers=client._get_accept_headers() - ) - - def test_add_tag_to_remote_image(self) -> None: - # Test case for updating the tag on an image that initially doesn't exist. - # Retrieves the manifest, updates it with a new tag, and pushes the manifest to add the tag. - # Covers the scenario where it takes 2 put requests to update the tag. - full_image_name = "org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest" - client = image_registry_client.ImageRegistryClient(cast(session.Session, self.m_session), full_image_name) - test_manifest = { - "schemaVersion": 2, - "mediaType": "application/vnd.docker.distribution.manifest.v2+json", - "config": { - "mediaType": "application/vnd.docker.container.image.v1+json", - "size": 4753, - "digest": "sha256:fb56ac2b330e2adaefd71f89e5a7e09a415bb55929d6b14b7a0bca5096479ad1", - }, - "layers": [ - { - "mediaType": "application/vnd.docker.image.rootfs.diff.tar.gzip", - "size": 143, - "digest": "sha256:24b9c0f433244f171ec20a922de94fc83d401d5d471ec13ca75f5a5fa7867426", - }, - ], - "tag": "tag-v1", - } - with mock.patch.object(client, "image_exists", side_effect=[False, True]), mock.patch.object( - client, "_get_manifest", return_value=test_manifest - ), mock.patch.object( - client.image_registry_http_client, - "put", - side_effect=[mock.Mock(status_code=400), mock.Mock(status_code=201)], - ): - new_tag = "new_tag" - new_image_name = 
f"org-account.registry.snowflakecomputing.com/db/schema/repo/image:{new_tag}" - client.add_tag_to_remote_image(original_full_image_name=full_image_name, new_tag=new_tag) - headers = client._get_accept_headers() - put_header_v1 = { - **headers, - "Content-Type": image_registry_client._MANIFEST_V1_HEADER, - } - put_header_v2 = { - **headers, - "Content-Type": image_registry_client._MANIFEST_V2_HEADER, - } - url = client._convert_to_v2_manifests_url(new_image_name) - - test_manifest_copy = test_manifest.copy() - test_manifest_copy["tag"] = new_tag - client.image_registry_http_client.put.assert_has_calls( # type: ignore[attr-defined] - [ - mock.call(url, headers=put_header_v1, json=test_manifest_copy), - mock.call(url, headers=put_header_v2, json=test_manifest_copy), - ], - any_order=False, - ) - - # First call to check existence before adding tag; second call to validate tag indeed added. - client.image_exists.assert_has_calls( # type: ignore[attr-defined] - [ - mock.call(new_image_name), - mock.call(new_image_name), - ] - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/_internal/env_utils.py b/snowflake/ml/_internal/env_utils.py index 7b3edf32..5b03c046 100644 --- a/snowflake/ml/_internal/env_utils.py +++ b/snowflake/ml/_internal/env_utils.py @@ -9,7 +9,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Tuple import yaml -from packaging import requirements, specifiers, utils as packaging_utils, version +from packaging import requirements, specifiers, version import snowflake.connector from snowflake.ml._internal import env as snowml_env @@ -54,15 +54,12 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement: """ try: r = requirements.Requirement(req_str) - r.name = packaging_utils.canonicalize_name(r.name) if r.name == "python": raise ValueError("Don't specify python as a dependency, use python version argument instead.") except requirements.InvalidRequirement: raise ValueError(f"Invalid package 
requirement {req_str} found.") - if r.marker: - raise ValueError("Markers is not supported in conda dependency.") return r @@ -84,6 +81,8 @@ def _validate_conda_dependency_string(dep_str: str) -> Tuple[str, requirements.R channel_str, _, requirement_str = dep_str.rpartition("::") r = _validate_pip_requirement_string(requirement_str) if channel_str != "pip": + if r.marker: + raise ValueError("Markers is not supported in conda dependency.") if r.extras: raise ValueError("Extras is not supported in conda dependency.") if r.url: @@ -221,7 +220,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement else: return pip_req new_pip_req = copy.deepcopy(pip_req) - new_pip_req.specifier = specifiers.SpecifierSet(specifiers=f"=={local_dist_version}") + new_pip_req.specifier = specifiers.SpecifierSet(specifiers=f"=={version.parse(local_dist_version).base_version}") if not pip_req.specifier.contains(local_dist_version): warnings.warn( f"Package requirement {str(pip_req)} specified, while version {local_dist_version} is installed. 
" @@ -513,6 +512,7 @@ def save_conda_env_file( ) with open(path, "w", encoding="utf-8") as f: + yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign] yaml.safe_dump(env, stream=f, default_flow_style=False) diff --git a/snowflake/ml/_internal/env_utils_test.py b/snowflake/ml/_internal/env_utils_test.py index 2ac7b698..4b7401e8 100644 --- a/snowflake/ml/_internal/env_utils_test.py +++ b/snowflake/ml/_internal/env_utils_test.py @@ -24,10 +24,10 @@ def test_validate_pip_requirement_string(self) -> None: self.assertEqual(r.name, "python-package") r = env_utils._validate_pip_requirement_string("python.package==1.0.1") - self.assertEqual(r.name, "python-package") + self.assertEqual(r.name, "python.package") r = env_utils._validate_pip_requirement_string("Python-package==1.0.1") - self.assertEqual(r.name, "python-package") + self.assertEqual(r.name, "Python-package") r = env_utils._validate_pip_requirement_string("python-package>=1.0.1,<2,~=1.1,!=1.0.3") self.assertEqual(r.specifier, specifiers.SpecifierSet(">=1.0.1, <2, ~=1.1, !=1.0.3")) r = env_utils._validate_pip_requirement_string("requests [security,tests] >= 2.8.1, == 2.8.*") @@ -42,8 +42,6 @@ def test_validate_pip_requirement_string(self) -> None: env_utils._validate_pip_requirement_string("python-package=1.0.1") with self.assertRaises(ValueError): env_utils._validate_pip_requirement_string("_python-package==1.0.1") - with self.assertRaises(ValueError): - env_utils._validate_pip_requirement_string('requests; python_version < "2.7"') def test_validate_conda_dependency_string(self) -> None: c, r = env_utils._validate_conda_dependency_string("python-package==1.0.1") @@ -182,9 +180,9 @@ def test_validate_pip_requirement_string_list(self) -> None: ] env_utils.validate_pip_requirement_string_list(rl) - with self.assertRaises(env_utils.DuplicateDependencyError): - rl = ["python-package", "python_package"] - env_utils.validate_pip_requirement_string_list(rl) + rl = ["python-package", 
"python_package"] + trl = [requirements.Requirement("python-package"), requirements.Requirement("python_package")] + self.assertListEqual(env_utils.validate_pip_requirement_string_list(rl), trl) rl = ["python-package", "another-python-package"] trl = [requirements.Requirement("python-package"), requirements.Requirement("another-python-package")] @@ -265,6 +263,15 @@ def test_get_local_installed_version_of_pip_package(self) -> None: requirements.Requirement(f"pip!={importlib_metadata.version('pip')}") ) + mock_distribution = mock.MagicMock() + mock_distribution.version = "1.0.0.post100" + with mock.patch.object(importlib_metadata, "distribution", return_value=mock_distribution): + r = requirements.Requirement("pip") + self.assertEqual( + requirements.Requirement("pip==1.0.0"), + env_utils.get_local_installed_version_of_pip_package(r), + ) + def test_get_package_spec_with_supported_ops_only(self) -> None: r = requirements.Requirement("python-package==1.0.1") self.assertEqual(env_utils.get_package_spec_with_supported_ops_only(r), r) diff --git a/snowflake/ml/_internal/exceptions/error_codes.py b/snowflake/ml/_internal/exceptions/error_codes.py index a5507015..65a7abff 100644 --- a/snowflake/ml/_internal/exceptions/error_codes.py +++ b/snowflake/ml/_internal/exceptions/error_codes.py @@ -58,6 +58,8 @@ METHOD_NOT_ALLOWED = "2102" # Not implemented. NOT_IMPLEMENTED = "2103" +# User needs to opt in to use a feature. +OPT_IN_REQUIRED = "2104" # Calling an API with unsupported keywords/values. 
INVALID_ARGUMENT = "2110" diff --git a/snowflake/ml/_internal/utils/BUILD.bazel b/snowflake/ml/_internal/utils/BUILD.bazel index e3cf9367..42b4f667 100644 --- a/snowflake/ml/_internal/utils/BUILD.bazel +++ b/snowflake/ml/_internal/utils/BUILD.bazel @@ -219,40 +219,8 @@ py_test( ) py_library( - name = "log_stream_processor", - srcs = ["log_stream_processor.py"], -) - -py_test( - name = "log_stream_processor_test", - srcs = ["log_stream_processor_test.py"], - deps = [ - ":log_stream_processor", - ], -) - -py_library( - name = "session_token_manager", - srcs = ["session_token_manager.py"], -) - -py_library( - name = "spcs_attribution_utils", - srcs = ["spcs_attribution_utils.py"], - deps = [ - ":query_result_checker", - "//snowflake/ml/_internal:telemetry", - ], -) - -py_test( - name = "spcs_attribution_utils_test", - srcs = ["spcs_attribution_utils_test.py"], - deps = [ - ":spcs_attribution_utils", - "//snowflake/ml/test_utils:mock_data_frame", - "//snowflake/ml/test_utils:mock_session", - ], + name = "service_logger", + srcs = ["service_logger.py"], ) py_library( diff --git a/snowflake/ml/_internal/utils/log_stream_processor.py b/snowflake/ml/_internal/utils/log_stream_processor.py deleted file mode 100644 index 9c3f8547..00000000 --- a/snowflake/ml/_internal/utils/log_stream_processor.py +++ /dev/null @@ -1,30 +0,0 @@ -import logging -from typing import Optional - -logger = logging.getLogger(__name__) - - -class LogStreamProcessor: - def __init__(self) -> None: - self.last_line_seen = 0 - - def process_new_logs(self, job_logs: Optional[str], *, log_level: int = logging.INFO) -> None: - if not job_logs: - return - log_entries = job_logs.split("\n") - start_index = self.last_line_seen - log_length = len(log_entries) - for i in range(start_index, log_length): - log_entry = log_entries[i] - if log_level == logging.DEBUG: - logger.debug(log_entry) - elif log_level == logging.INFO: - logger.info(log_entry) - elif log_level == logging.WARNING: - logger.warning(log_entry) 
- elif log_level == logging.ERROR: - logger.error(log_entry) - elif log_level == logging.CRITICAL: - logger.critical(log_entry) - - self.last_line_seen = log_length diff --git a/snowflake/ml/_internal/utils/log_stream_processor_test.py b/snowflake/ml/_internal/utils/log_stream_processor_test.py deleted file mode 100644 index 5a4a984f..00000000 --- a/snowflake/ml/_internal/utils/log_stream_processor_test.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging -from io import StringIO - -from absl.testing import absltest - -from snowflake.ml._internal.utils import log_stream_processor - - -class LogStreamProcessorTest(absltest.TestCase): - def setUp(self) -> None: - self.log_stream = StringIO() - self.log_handler = logging.StreamHandler(self.log_stream) - self.log_handler.setLevel(logging.INFO) - self.log_handler.setFormatter(logging.Formatter("%(message)s")) - logging.getLogger().addHandler(self.log_handler) - - def tearDown(self) -> None: - logging.getLogger().removeHandler(self.log_handler) - self.log_stream.close() - logging.shutdown() - - def reset_log_stream(self) -> None: - # Clear the log stream - self.log_stream.truncate(0) - self.log_stream.seek(0) - - def test_only_new_log_is_shown(self) -> None: - lsp = log_stream_processor.LogStreamProcessor() - log1 = "TIMESTAMP1: HI 1" - log2 = "TIMESTAMP1: HI 1 \n TIMESTAMP2: HI 2" - log3 = "TIMESTAMP1: HI 1 \n TIMESTAMP2: HI 2 \n TIMESTAMP3: HI 3" - log4 = "TIMESTAMP1: HI 1 \n TIMESTAMP2: HI 2 \n TIMESTAMP3: HI 3" - - lsp.process_new_logs(log1) - self.assertEqual("TIMESTAMP1: HI 1", self.log_stream.getvalue().strip()) - - self.reset_log_stream() - - lsp.process_new_logs(log2) - self.assertEqual("TIMESTAMP2: HI 2", self.log_stream.getvalue().strip()) - - self.reset_log_stream() - - lsp.process_new_logs(log3) - self.assertEqual("TIMESTAMP3: HI 3", self.log_stream.getvalue().strip()) - - self.reset_log_stream() - - # No new log returned - lsp.process_new_logs(log4) - self.assertEqual("", self.log_stream.getvalue().strip()) - 
- self.reset_log_stream() - - # Process empty log - lsp.process_new_logs(None) - self.assertEqual("", self.log_stream.getvalue().strip()) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/_internal/utils/service_logger.py b/snowflake/ml/_internal/utils/service_logger.py new file mode 100644 index 00000000..7b788962 --- /dev/null +++ b/snowflake/ml/_internal/utils/service_logger.py @@ -0,0 +1,63 @@ +import enum +import logging +import sys + + +class LogColor(enum.Enum): + GREY = "\x1b[38;20m" + RED = "\x1b[31;20m" + BOLD_RED = "\x1b[31;1m" + YELLOW = "\x1b[33;20m" + BLUE = "\x1b[34;20m" + GREEN = "\x1b[32;20m" + + +class CustomFormatter(logging.Formatter): + + reset = "\x1b[0m" + log_format = "%(name)s [%(asctime)s] [%(levelname)s] %(message)s" + + def __init__(self, info_color: LogColor) -> None: + super().__init__() + self.level_colors = { + logging.DEBUG: LogColor.GREY.value, + logging.INFO: info_color.value, + logging.WARNING: LogColor.YELLOW.value, + logging.ERROR: LogColor.RED.value, + logging.CRITICAL: LogColor.BOLD_RED.value, + } + + def format(self, record: logging.LogRecord) -> str: + # default to DEBUG color + fmt = self.level_colors.get(record.levelno, self.level_colors[logging.DEBUG]) + self.log_format + self.reset + formatter = logging.Formatter(fmt) + + # split the log message by lines and format each line individually + original_message = record.getMessage() + message_lines = original_message.splitlines() + formatted_lines = [ + formatter.format( + logging.LogRecord( + name=record.name, + level=record.levelno, + pathname=record.pathname, + lineno=record.lineno, + msg=line, + args=None, + exc_info=None, + ) + ) + for line in message_lines + ] + + return "\n".join(formatted_lines) + + +def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger: + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.INFO) + 
handler.setFormatter(CustomFormatter(info_color)) + logger.addHandler(handler) + return logger diff --git a/snowflake/ml/_internal/utils/session_token_manager.py b/snowflake/ml/_internal/utils/session_token_manager.py deleted file mode 100644 index cdec61fa..00000000 --- a/snowflake/ml/_internal/utils/session_token_manager.py +++ /dev/null @@ -1,46 +0,0 @@ -import base64 -import json -from typing import TypedDict - -from snowflake import snowpark - - -class SessionToken(TypedDict): - token: str - expires_in: str - - -class SessionTokenManager: - def __init__(self, session: snowpark.Session) -> None: - self._session = session - - def get_session_token(self) -> SessionToken: - """ - This function retrieves the session token from Snowpark session object. - - Returns: - The session token string value. - """ - ctx = self._session._conn._conn - assert ctx._rest, "SnowflakeRestful is not set in session" - token_data = ctx._rest._token_request("ISSUE") - session_token = token_data["data"]["sessionToken"] - validity_in_seconds = token_data["data"]["validityInSecondsST"] - assert session_token, "session_token is not obtained successfully from the session object" - assert validity_in_seconds, "validityInSecondsST is not obtained successfully from the session object" - return {"token": session_token, "expires_in": validity_in_seconds} - - def get_base64_encoded_token(self, username: str = "0sessiontoken") -> str: - """This function returns the base64 encoded username:password, which is compatible with registry, such as - SnowService image registry, that uses Docker credential helper. In this case, password will be session token. - - Args: - username: username for authentication. - - Returns: - base64 encoded credential string. 
- - """ - credentials = f"{username}:{json.dumps(self.get_session_token())}" - encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - return encoded_credentials diff --git a/snowflake/ml/_internal/utils/spcs_attribution_utils.py b/snowflake/ml/_internal/utils/spcs_attribution_utils.py deleted file mode 100644 index 7e7018c4..00000000 --- a/snowflake/ml/_internal/utils/spcs_attribution_utils.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -from datetime import datetime -from typing import Any, Dict, Optional - -from snowflake import snowpark -from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import query_result_checker - -logger = logging.getLogger(__name__) - -_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z" -_COMPUTE_POOL = "compute_pool" -_CREATED_ON = "created_on" -_INSTANCE_FAMILY = "instance_family" -_NAME = "name" -_TELEMETRY_PROJECT = "MLOps" -_TELEMETRY_SUBPROJECT = "SpcsDeployment" -_SERVICE_START = "SPCS_SERVICE_START" -_SERVICE_END = "SPCS_SERVICE_END" - - -def _desc_compute_pool(session: snowpark.Session, compute_pool_name: str) -> Dict[str, Any]: - sql = f"DESC COMPUTE POOL {compute_pool_name}" - result = ( - query_result_checker.SqlResultValidator( - session=session, - query=sql, - ) - .has_column(_INSTANCE_FAMILY) - .has_column(_NAME) - .has_dimensions(expected_rows=1) - .validate() - ) - return result[0].as_dict() - - -def _desc_service(session: snowpark.Session, fully_qualified_name: str) -> Dict[str, Any]: - sql = f"DESC SERVICE {fully_qualified_name}" - result = ( - query_result_checker.SqlResultValidator( - session=session, - query=sql, - ) - .has_column(_COMPUTE_POOL) - .has_dimensions(expected_rows=1) - .validate() - ) - return result[0].as_dict() - - -def _get_current_time() -> datetime: - """ - This method exists to make it easier to mock datetime in test. 
- - Returns: - current datetime - """ - return datetime.now() - - -def _send_service_telemetry( - fully_qualified_name: Optional[str] = None, - compute_pool_name: Optional[str] = None, - service_details: Optional[Dict[str, Any]] = None, - compute_pool_details: Optional[Dict[str, Any]] = None, - duration_in_seconds: Optional[int] = None, - kwargs: Optional[Dict[str, Any]] = None, -) -> None: - try: - telemetry.send_custom_usage( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - telemetry_type=telemetry.TelemetryField.TYPE_SNOWML_SPCS_USAGE.value, - data={ - "service_name": fully_qualified_name, - "compute_pool_name": compute_pool_name, - "service_details": service_details, - "compute_pool_details": compute_pool_details, - "duration_in_seconds": duration_in_seconds, - }, - kwargs=kwargs, - ) - except Exception as e: - logger.error(f"Failed to send service telemetry: {e}") - - -def record_service_start(session: snowpark.Session, fully_qualified_name: str) -> None: - service_details = _desc_service(session, fully_qualified_name) - compute_pool_name = service_details[_COMPUTE_POOL] - compute_pool_details = _desc_compute_pool(session, compute_pool_name) - - _send_service_telemetry( - fully_qualified_name=fully_qualified_name, - compute_pool_name=compute_pool_name, - service_details=service_details, - compute_pool_details=compute_pool_details, - kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_START}, - ) - - logger.info(f"Service {fully_qualified_name} created with compute pool {compute_pool_name}.") - - -def record_service_end(session: snowpark.Session, fully_qualified_name: str) -> None: - service_details = _desc_service(session, fully_qualified_name) - compute_pool_details = _desc_compute_pool(session, service_details[_COMPUTE_POOL]) - compute_pool_name = service_details[_COMPUTE_POOL] - - created_on_datetime: datetime = service_details[_CREATED_ON] - current_time: datetime = _get_current_time() - current_time = 
current_time.replace(tzinfo=created_on_datetime.tzinfo) - duration_in_seconds = int((current_time - created_on_datetime).total_seconds()) - - _send_service_telemetry( - fully_qualified_name=fully_qualified_name, - compute_pool_name=compute_pool_name, - service_details=service_details, - compute_pool_details=compute_pool_details, - duration_in_seconds=duration_in_seconds, - kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_END}, - ) - - logger.info(f"Service {fully_qualified_name} deleted from compute pool {compute_pool_name}") diff --git a/snowflake/ml/_internal/utils/spcs_attribution_utils_test.py b/snowflake/ml/_internal/utils/spcs_attribution_utils_test.py deleted file mode 100644 index 482c5d0b..00000000 --- a/snowflake/ml/_internal/utils/spcs_attribution_utils_test.py +++ /dev/null @@ -1,135 +0,0 @@ -import datetime -from typing import Any, Dict, cast -from unittest import mock - -from absl.testing import absltest - -from snowflake import snowpark -from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import spcs_attribution_utils -from snowflake.ml.test_utils import mock_data_frame, mock_session -from snowflake.snowpark import session - - -class SpcsAttributionUtilsTest(absltest.TestCase): - def setUp(self) -> None: - super().setUp() - self._m_session = mock_session.MockSession(conn=None, test_case=self) - self._fully_qualified_service_name = "db.schema.my_service" - self._m_compute_pool_name = "my_pool" - self._service_created_on = datetime.datetime.strptime( - "2023-11-16 13:01:00.062 -0800", spcs_attribution_utils._DATETIME_FORMAT - ) - - mock_service_detail = self._get_mock_service_details() - self._m_session.add_mock_sql( - query=f"DESC SERVICE {self._fully_qualified_service_name}", - result=mock_data_frame.MockDataFrame(collect_result=[snowpark.Row(**mock_service_detail)]), - ) - - mock_compute_pool_detail = self._get_mock_compute_pool_details() - self._m_session.add_mock_sql( - query=f"DESC COMPUTE POOL 
{self._m_compute_pool_name}", - result=mock_data_frame.MockDataFrame(collect_result=[snowpark.Row(**mock_compute_pool_detail)]), - ) - - def _get_mock_service_details(self) -> Dict[str, Any]: - return { - "name": "my_service", - "database_name": "my_db", - "schema_name": "my_schema", - "owner": "Engineer", - "compute_pool": self._m_compute_pool_name, - "spec": "--- spec:", - "dns_name": "service-dummy.my-schema.my-db.snowflakecomputing.internal", - "public_endpoints": {"predict": "dummy.snowflakecomputing.app"}, - "min_instances": 1, - "max_instances": 1, - "created_on": self._service_created_on, - "updated_on": "2023-11-16 13:01:00.595 -0800", - "comment": None, - } - - def _get_mock_compute_pool_details(self) -> Dict[str, Any]: - return { - "name": self._m_compute_pool_name, - "state": "Active", - "min_nodes": 1, - "max_nodes": 1, - "instance_family": "STANDARD_2", - "num_services": 1, - "num_jobs": 2, - "active_nodes": 1, - "idle_nodes": 1, - "created_on": "2023-09-21 09:17:39.627 -0700", - "resumed_on": "2023-09-21 09:17:39.628 -0700", - "updated_on": "2023-11-27 15:08:55.725 -0800", - "owner": "ACCOUNTADMIN", - "comment": None, - } - - def test_record_service_start(self) -> None: - with mock.patch.object(spcs_attribution_utils, "_send_service_telemetry", return_value=None) as m_telemetry: - with self.assertLogs(level="INFO") as cm: - spcs_attribution_utils.record_service_start( - cast(session.Session, self._m_session), self._fully_qualified_service_name - ) - - assert len(cm.output) == 1, "there should only be 1 log" - log = cm.output[0] - - service_details = self._get_mock_service_details() - compute_pool_details = self._get_mock_compute_pool_details() - - self.assertEqual( - log, - f"INFO:snowflake.ml._internal.utils.spcs_attribution_utils:Service " - f"{self._fully_qualified_service_name} created with compute pool {self._m_compute_pool_name}.", - ) - m_telemetry.assert_called_once_with( - fully_qualified_name=self._fully_qualified_service_name, - 
compute_pool_name=self._m_compute_pool_name, - service_details=service_details, - compute_pool_details=compute_pool_details, - kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: spcs_attribution_utils._SERVICE_START}, - ) - - def test_record_service_end(self) -> None: - current_datetime = self._service_created_on + datetime.timedelta(days=2, hours=1, minutes=30, seconds=20) - expected_duration = 178220 # 2 days 1 hour 30 minutes and 20 seconds. - - with mock.patch( - "snowflake.ml._internal.utils.spcs_attribution_utils._get_current_time" - ) as mock_datetime_now, mock.patch.object( - spcs_attribution_utils, "_send_service_telemetry", return_value=None - ) as m_telemetry: - with self.assertLogs(level="INFO") as cm: - mock_datetime_now.return_value = current_datetime - - spcs_attribution_utils.record_service_end( - cast(session.Session, self._m_session), self._fully_qualified_service_name - ) - assert len(cm.output) == 1, "there should only be 1 log" - log = cm.output[0] - - service_details = self._get_mock_service_details() - compute_pool_details = self._get_mock_compute_pool_details() - - self.assertEqual( - log, - f"INFO:snowflake.ml._internal.utils.spcs_attribution_utils:Service " - f"{self._fully_qualified_service_name} deleted from compute pool {self._m_compute_pool_name}", - ) - - m_telemetry.assert_called_once_with( - fully_qualified_name=self._fully_qualified_service_name, - compute_pool_name=self._m_compute_pool_name, - service_details=service_details, - compute_pool_details=compute_pool_details, - duration_in_seconds=expected_duration, - kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: spcs_attribution_utils._SERVICE_END}, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/_internal/utils/sql_identifier.py b/snowflake/ml/_internal/utils/sql_identifier.py index 9f9a6ff4..cee55122 100644 --- a/snowflake/ml/_internal/utils/sql_identifier.py +++ b/snowflake/ml/_internal/utils/sql_identifier.py @@ -1,4 +1,4 @@ -from 
typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from snowflake.ml._internal.utils import identifier @@ -92,3 +92,27 @@ def parse_fully_qualified_name( SqlIdentifier(schema) if schema else None, SqlIdentifier(object), ) + + +def get_fully_qualified_name( + db: Union[SqlIdentifier, str, None], + schema: Union[SqlIdentifier, str, None], + object: Union[SqlIdentifier, str], + session_db: Optional[str] = None, + session_schema: Optional[str] = None, +) -> str: + db_name: Optional[SqlIdentifier] = None + schema_name: Optional[SqlIdentifier] = None + if not db and session_db: + db_name = SqlIdentifier(session_db) + elif isinstance(db, str): + db_name = SqlIdentifier(db) + if not schema and session_schema: + schema_name = SqlIdentifier(session_schema) + elif isinstance(schema, str): + schema_name = SqlIdentifier(schema) + return identifier.get_schema_level_object_identifier( + db=db_name.identifier() if db_name else None, + schema=schema_name.identifier() if schema_name else None, + object_name=object.identifier() if isinstance(object, SqlIdentifier) else SqlIdentifier(object).identifier(), + ) diff --git a/snowflake/ml/_internal/utils/sql_identifier_test.py b/snowflake/ml/_internal/utils/sql_identifier_test.py index 6dd3000d..cca580d0 100644 --- a/snowflake/ml/_internal/utils/sql_identifier_test.py +++ b/snowflake/ml/_internal/utils/sql_identifier_test.py @@ -65,6 +65,42 @@ def test_parse_fully_qualified_name(self) -> None: with self.assertRaises(ValueError): sql_identifier.parse_fully_qualified_name("abc-def") + def test_get_fully_qualified_name(self) -> None: + self.assertEqual( + "MYDB.MYSCHEMA.ABC", + sql_identifier.get_fully_qualified_name( + None, None, sql_identifier.SqlIdentifier("abc"), "mydb", "myschema" + ), + ) + self.assertEqual( + "MYDB.MYSCHEMA.ABC", + sql_identifier.get_fully_qualified_name( + "mydb", "myschema", sql_identifier.SqlIdentifier("abc"), None, None + ), + ) + self.assertEqual( + "ABC", + 
sql_identifier.get_fully_qualified_name(None, None, sql_identifier.SqlIdentifier("abc"), None, None), + ) + self.assertEqual( + 'MYDB.MYSCHEMA."abc"', + sql_identifier.get_fully_qualified_name( + "mydb", "myschema", sql_identifier.SqlIdentifier('"abc"'), None, None + ), + ) + self.assertEqual( + '"mydb"."myschema".ABC', + sql_identifier.get_fully_qualified_name( + '"mydb"', '"myschema"', sql_identifier.SqlIdentifier("abc"), None, None + ), + ) + self.assertEqual( + '"mydb"."myschema".ABC', + sql_identifier.get_fully_qualified_name( + None, None, sql_identifier.SqlIdentifier("abc"), '"mydb"', '"myschema"' + ), + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/data/_internal/arrow_ingestor.py b/snowflake/ml/data/_internal/arrow_ingestor.py index a6edebe8..c659faeb 100644 --- a/snowflake/ml/data/_internal/arrow_ingestor.py +++ b/snowflake/ml/data/_internal/arrow_ingestor.py @@ -11,7 +11,6 @@ import pyarrow.dataset as pds from snowflake import snowpark -from snowflake.connector import result_batch from snowflake.ml.data import data_ingestor, data_source, ingestor_utils _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], []) @@ -140,16 +139,7 @@ def _get_dataset(self, shuffle: bool) -> pds.Dataset: # We may be able to optimize this by splitting the result batches into # in-memory (first batch) and file URLs (subsequent batches) and creating a # union dataset. 
- result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source) - sources.extend( - b.to_arrow(self._session.connection) - if isinstance(b, result_batch.ArrowResultBatch) - else b.to_arrow() - for b in result_batches - ) - # HACK: Mitigate typing inconsistencies in Snowpark results - if len(sources) > 0: - sources = [_cast_if_needed(s, sources[-1].schema) for s in sources] + sources.append(_cast_if_needed(ingestor_utils.get_dataframe_arrow_table(self._session, source))) source_format = None # Arrow Dataset expects "None" for in-memory datasets else: raise RuntimeError(f"Unsupported data source type: {type(source)}") diff --git a/snowflake/ml/data/ingestor_utils.py b/snowflake/ml/data/ingestor_utils.py index 907a73c3..b7512685 100644 --- a/snowflake/ml/data/ingestor_utils.py +++ b/snowflake/ml/data/ingestor_utils.py @@ -1,19 +1,17 @@ from typing import List, Optional import fsspec +import pyarrow as pa from snowflake import snowpark -from snowflake.connector import result_batch +from snowflake.connector import cursor as sf_cursor, result_batch from snowflake.ml.data import data_source from snowflake.ml.fileset import snowfs _TARGET_FILE_SIZE = 32 * 2**20 # The max file size for data loading. -def get_dataframe_result_batches( - session: snowpark.Session, df_info: data_source.DataFrameInfo -) -> List[result_batch.ResultBatch]: - """Retrieve the ResultBatches for a given query""" +def _get_dataframe_cursor(session: snowpark.Session, df_info: data_source.DataFrameInfo) -> sf_cursor.SnowflakeCursor: cursor = session._conn._cursor if df_info.query_id: @@ -29,12 +27,24 @@ def get_dataframe_result_batches( if cursor._prefetch_hook is None: raise RuntimeError("Loading data from result query failed unexpectedly. 
Please contact Snowflake support.") cursor._prefetch_hook() + + return cursor + + +def get_dataframe_result_batches( + session: snowpark.Session, df_info: data_source.DataFrameInfo +) -> List[result_batch.ResultBatch]: + """Retrieve the ResultBatches for a given query""" + cursor = _get_dataframe_cursor(session, df_info) batches = cursor.get_result_batches() - if batches is None: - raise ValueError( - "Failed to retrieve training data. Query status:" f" {session._conn._conn.get_query_status(query_id)}" - ) - return batches + return batches or [] + + +def get_dataframe_arrow_table(session: snowpark.Session, df_info: data_source.DataFrameInfo) -> pa.Table: + """Retrieve the full in-memory result for a given query""" + cursor = _get_dataframe_cursor(session, df_info) + table = cursor.fetch_arrow_all() # type: ignore[call-overload] + return table def get_dataset_filesystem( diff --git a/snowflake/ml/feature_store/access_manager.py b/snowflake/ml/feature_store/access_manager.py index 65d5df4d..0b484a2f 100644 --- a/snowflake/ml/feature_store/access_manager.py +++ b/snowflake/ml/feature_store/access_manager.py @@ -30,6 +30,7 @@ class _Privilege: object_name: str privileges: List[str] scope: Optional[str] = None + optional: bool = False @dataclass(frozen=True) @@ -72,8 +73,7 @@ class _SessionInfo: _Privilege("VIEW", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"), _Privilege("TABLE", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"), _Privilege("DATASET", _ALL_OBJECTS, ["USAGE"], "SCHEMA {database}.{schema}"), - # User should decide whether they want to grant warehouse usage to CONSUMER - # _Privilege("WAREHOUSE", "{warehouse}", ["USAGE"]), + _Privilege("WAREHOUSE", "{warehouse}", ["USAGE"], optional=True), ], _FeatureStoreRole.NONE: [], } @@ -109,7 +109,7 @@ def _grant_privileges( query += f" TO ROLE {role_name}" session.sql(query).collect() except exceptions.SnowparkSQLException as e: - if any( + if p.optional or any( s in 
e.message for s in ( "Ask your account admin", diff --git a/snowflake/ml/feature_store/feature_store.py b/snowflake/ml/feature_store/feature_store.py index f9018c9c..23da0913 100644 --- a/snowflake/ml/feature_store/feature_store.py +++ b/snowflake/ml/feature_store/feature_store.py @@ -122,6 +122,14 @@ def parse(cls, val: str) -> _FeatureStoreObjTypes: flags=re.DOTALL | re.IGNORECASE | re.X, ) +_DT_INITIALIZE_PATTERN = re.compile( + r"""CREATE\ DYNAMIC\ TABLE\ .* + initialize\ =\ '(?P.*)'\ .*? + AS\ .* + """, + flags=re.DOTALL | re.IGNORECASE | re.X, +) + _LIST_FEATURE_VIEW_SCHEMA = StructType( [ StructField("name", StringType()), @@ -565,11 +573,15 @@ def register_feature_view( tagging_clause_str = ",\n".join(tagging_clause) def create_col_desc(col: StructField) -> str: - desc = feature_view.feature_descs.get(SqlIdentifier(col.name), None) + desc = feature_view.feature_descs.get(SqlIdentifier(col.name), None) # type: ignore[union-attr] desc = "" if desc is None else f"COMMENT '{desc}'" return f"{col.name} {desc}" - column_descs = ", ".join([f"{create_col_desc(col)}" for col in feature_view.output_schema.fields]) + column_descs = ( + ", ".join([f"{create_col_desc(col)}" for col in feature_view.output_schema.fields]) + if feature_view.feature_descs is not None + else "" + ) if refresh_freq is not None: schedule_task = refresh_freq != "DOWNSTREAM" and timeparse(refresh_freq) is None @@ -1819,6 +1831,7 @@ def _create_dynamic_table( ) WAREHOUSE = {warehouse} REFRESH_MODE = {feature_view.refresh_mode} + INITIALIZE = {feature_view.initialize} AS {feature_view.query} """ self._session.sql(query).collect(block=block, statement_params=self._telemetry_stmp) @@ -2293,6 +2306,8 @@ def find_and_compose_entity(name: str) -> Entity: entities = [find_and_compose_entity(n) for n in fv_metadata.entities] ts_col = fv_metadata.timestamp_col timestamp_col = ts_col if ts_col not in _LEGACY_TIMESTAMP_COL_PLACEHOLDER_VALS else None + re_initialize = re.match(_DT_INITIALIZE_PATTERN, 
row["text"]) + initialize = re_initialize.group("initialize") if re_initialize is not None else "ON_CREATE" fv = FeatureView._construct_feature_view( name=name, @@ -2317,6 +2332,7 @@ def find_and_compose_entity(name: str) -> Entity: ), refresh_mode=row["refresh_mode"], refresh_mode_reason=row["refresh_mode_reason"], + initialize=initialize, owner=row["owner"], infer_schema_df=infer_schema_df, session=self._session, @@ -2343,6 +2359,7 @@ def find_and_compose_entity(name: str) -> Entity: warehouse=None, refresh_mode=None, refresh_mode_reason=None, + initialize="ON_CREATE", owner=row["owner"], infer_schema_df=infer_schema_df, session=self._session, diff --git a/snowflake/ml/feature_store/feature_view.py b/snowflake/ml/feature_store/feature_view.py index 021ecd8b..b6de5b51 100644 --- a/snowflake/ml/feature_store/feature_view.py +++ b/snowflake/ml/feature_store/feature_view.py @@ -22,6 +22,7 @@ from snowflake.ml.feature_store.entity import Entity from snowflake.ml.lineage import lineage_node from snowflake.snowpark import DataFrame, Session +from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.types import ( DateType, StructType, @@ -167,6 +168,7 @@ def __init__( refresh_freq: Optional[str] = None, desc: str = "", warehouse: Optional[str] = None, + initialize: str = "ON_CREATE", **_kwargs: Any, ) -> None: """ @@ -190,6 +192,10 @@ def __init__( warehouse: warehouse to refresh feature view. Not needed for static feature view (refresh_freq is None). For managed feature view, this warehouse will overwrite the default warehouse of Feature Store if it is specified, otherwise the default warehouse will be used. + initialize: Specifies the behavior of the initial refresh of feature view. This property cannot be altered + after you register the feature view. It supports ON_CREATE (default) or ON_SCHEDULE. ON_CREATE refreshes + the feature view synchronously at creation. ON_SCHEDULE refreshes the feature view at the next scheduled + refresh. 
It is only effective when refresh_freq is not None. _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE. Example:: @@ -227,10 +233,14 @@ def __init__( self._query: str = self._get_query() self._version: Optional[FeatureViewVersion] = None self._status: FeatureViewStatus = FeatureViewStatus.DRAFT - self._feature_desc: OrderedDict[SqlIdentifier, str] = OrderedDict((f, "") for f in self._get_feature_names()) + feature_names = self._get_feature_names() + self._feature_desc: Optional[OrderedDict[SqlIdentifier, str]] = ( + OrderedDict((f, "") for f in feature_names) if feature_names is not None else None + ) self._refresh_freq: Optional[str] = refresh_freq self._database: Optional[SqlIdentifier] = None self._schema: Optional[SqlIdentifier] = None + self._initialize: str = initialize self._warehouse: Optional[SqlIdentifier] = SqlIdentifier(warehouse) if warehouse is not None else None self._refresh_mode: Optional[str] = _kwargs.get("refresh_mode", "AUTO") self._refresh_mode_reason: Optional[str] = None @@ -345,6 +355,15 @@ def attach_feature_desc(self, descs: Dict[str, str]) -> FeatureView: ('START_STATION_LATITUDE', 'Latitude of the start station.')]) """ + if self._feature_desc is None: + warnings.warn( + "Failed to read feature view schema. Probably feature view is not refreshed yet. 
" + "Schema will be available after initial refresh.", + stacklevel=2, + category=UserWarning, + ) + return self + for f, d in descs.items(): f = SqlIdentifier(f) if f not in self._feature_desc: @@ -424,10 +443,10 @@ def status(self) -> FeatureViewStatus: @property def feature_names(self) -> List[SqlIdentifier]: - return list(self._feature_desc.keys()) + return list(self._feature_desc.keys()) if self._feature_desc is not None else [] @property - def feature_descs(self) -> Dict[SqlIdentifier, str]: + def feature_descs(self) -> Optional[Dict[SqlIdentifier, str]]: return self._feature_desc def list_columns(self) -> DataFrame: @@ -463,7 +482,17 @@ def list_columns(self) -> DataFrame: """ session = self._feature_df.session - rows = [] + rows = [] # type: ignore[var-annotated] + + if self.feature_descs is None: + warnings.warn( + "Failed to read feature view schema. Probably feature view is not refreshed yet. " + "Schema will be available after initial refresh.", + stacklevel=2, + category=UserWarning, + ) + return session.create_dataframe(rows, schema=["name", "category", "dtype", "desc"]) + for name, type in self._feature_df.dtypes: if SqlIdentifier(name) in self.feature_descs: desc = self.feature_descs[SqlIdentifier(name)] @@ -565,6 +594,10 @@ def warehouse(self, new_value: str) -> None: ) self._warehouse = SqlIdentifier(new_value) + @property + def initialize(self) -> str: + return self._initialize + @property def output_schema(self) -> StructType: return self._infer_schema_df.schema @@ -601,33 +634,49 @@ def _validate(self) -> None: f"FeatureView name `{self._name}` contains invalid character `{_FEATURE_VIEW_NAME_DELIMITER}`." 
) - unescaped_df_cols = to_sql_identifiers(self._infer_schema_df.columns) - for e in self._entities: - for k in e.join_keys: - if k not in unescaped_df_cols: - raise ValueError( - f"join_key {k} in Entity {e.name} is not found in input dataframe: {unescaped_df_cols}" - ) - - if self._timestamp_col is not None: - ts_col = self._timestamp_col - if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER): - raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.") - if ts_col not in to_sql_identifiers(self._infer_schema_df.columns): - raise ValueError(f"timestamp_col {ts_col} is not found in input dataframe.") - - col_type = self._infer_schema_df.schema[ts_col].datatype - if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)): - raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.") + df_cols = self._get_column_names() + if df_cols is not None: + for e in self._entities: + for k in e.join_keys: + if k not in df_cols: + raise ValueError(f"join_key {k} in Entity {e.name} is not found in input dataframe: {df_cols}") + + if self._timestamp_col is not None: + ts_col = self._timestamp_col + if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER): + raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.") + if ts_col not in df_cols: + raise ValueError(f"timestamp_col {ts_col} is not found in input dataframe.") + + col_type = self._infer_schema_df.schema[ts_col].datatype + if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)): + raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.") if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None: raise ValueError(f"feature_df should not be reading from RESULT_SCAN. 
Invalid query: {self._query}") - def _get_feature_names(self) -> List[SqlIdentifier]: + if self._initialize not in ["ON_CREATE", "ON_SCHEDULE"]: + raise ValueError("'initialize' only supports ON_CREATE or ON_SCHEDULE.") + + def _get_column_names(self) -> Optional[List[SqlIdentifier]]: + try: + return to_sql_identifiers(self._infer_schema_df.columns) + except SnowparkSQLException as e: + warnings.warn( + "Failed to read feature view schema. Probably feature view is not refreshed yet. " + f"Schema will be available after initial refresh. Original exception: {e}", + stacklevel=2, + category=UserWarning, + ) + return None + + def _get_feature_names(self) -> Optional[List[SqlIdentifier]]: join_keys = [k for e in self._entities for k in e.join_keys] ts_col = [self._timestamp_col] if self._timestamp_col is not None else [] - feature_names = to_sql_identifiers(self._infer_schema_df.columns, case_sensitive=False) - return [c for c in feature_names if c not in join_keys + ts_col] + feature_names = self._get_column_names() + if feature_names is not None: + return [c for c in feature_names if c not in join_keys + ts_col] + return None def __repr__(self) -> str: states = (f"{k}={v}" for k, v in vars(self).items()) @@ -670,11 +719,13 @@ def _to_dict(self) -> Dict[str, str]: fv_dict["_schema"] = str(self._schema) if self._schema is not None else None fv_dict["_warehouse"] = str(self._warehouse) if self._warehouse is not None else None fv_dict["_timestamp_col"] = str(self._timestamp_col) if self._timestamp_col is not None else None + fv_dict["_initialize"] = str(self._initialize) feature_desc_dict = {} - for k, v in self._feature_desc.items(): - feature_desc_dict[k.identifier()] = v - fv_dict["_feature_desc"] = feature_desc_dict + if self._feature_desc is not None: + for k, v in self._feature_desc.items(): + feature_desc_dict[k.identifier()] = v + fv_dict["_feature_desc"] = feature_desc_dict lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == 
"_session"] @@ -760,6 +811,7 @@ def from_json(cls, json_str: str, session: Session) -> FeatureView: warehouse=json_dict["_warehouse"], refresh_mode=json_dict["_refresh_mode"], refresh_mode_reason=json_dict["_refresh_mode_reason"], + initialize=json_dict["_initialize"], owner=json_dict["_owner"], infer_schema_df=session.sql(json_dict.get("_infer_schema_query", None)), session=session, @@ -830,6 +882,7 @@ def _construct_feature_view( warehouse: Optional[str], refresh_mode: Optional[str], refresh_mode_reason: Optional[str], + initialize: str, owner: Optional[str], infer_schema_df: Optional[DataFrame], session: Session, @@ -850,6 +903,7 @@ def _construct_feature_view( fv._warehouse = SqlIdentifier(warehouse) if warehouse is not None else None fv._refresh_mode = refresh_mode fv._refresh_mode_reason = refresh_mode_reason + fv._initialize = initialize fv._owner = owner fv.attach_feature_desc(feature_descs) diff --git a/snowflake/ml/fileset/stage_fs.py b/snowflake/ml/fileset/stage_fs.py index ed6403c0..bbf30ea2 100644 --- a/snowflake/ml/fileset/stage_fs.py +++ b/snowflake/ml/fileset/stage_fs.py @@ -431,4 +431,5 @@ def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]: error_code=error_codes.SNOWML_NOT_FOUND, original_exception=fileset_errors.StageNotFoundError("Query failed."), ) from e - raise + assert e.msg is not None + raise snowpark_exceptions.SnowparkSQLException(e.msg, conn_error=e) from e diff --git a/snowflake/ml/lineage/lineage_node.py b/snowflake/ml/lineage/lineage_node.py index feec84d0..754ceee4 100644 --- a/snowflake/ml/lineage/lineage_node.py +++ b/snowflake/ml/lineage/lineage_node.py @@ -118,16 +118,21 @@ def lineage( ) domain = lineage_object["domain"].lower() if domain_filter is None or domain in domain_filter: + obj_name = ".".join( + identifier.rename_to_valid_snowflake_identifier(s) + for s in identifier.parse_schema_level_object_identifier(lineage_object["name"]) + if s is not None + ) if domain in DOMAIN_LINEAGE_REGISTRY and 
lineage_object["status"] == "ACTIVE": lineage_nodes.append( DOMAIN_LINEAGE_REGISTRY[domain]._load_from_lineage_node( - self._session, lineage_object["name"], lineage_object.get("version") + self._session, obj_name, lineage_object.get("version") ) ) else: lineage_nodes.append( LineageNode( - name=lineage_object["name"], + name=obj_name, version=lineage_object.get("version"), domain=domain, status=lineage_object["status"], diff --git a/snowflake/ml/model/BUILD.bazel b/snowflake/ml/model/BUILD.bazel index 5b0847cb..3f861e36 100644 --- a/snowflake/ml/model/BUILD.bazel +++ b/snowflake/ml/model/BUILD.bazel @@ -6,17 +6,11 @@ py_library( name = "type_hints", srcs = ["type_hints.py"], deps = [ - ":deploy_platforms", "//snowflake/ml/model/_signatures:core", "//snowflake/ml/modeling/framework", ], ) -py_library( - name = "deploy_platforms", - srcs = ["deploy_platforms.py"], -) - py_library( name = "model_signature", srcs = ["model_signature.py"], @@ -27,7 +21,6 @@ py_library( "//snowflake/ml/_internal/utils:formatting", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:sql_identifier", - "//snowflake/ml/model/_deploy_client/warehouse:infer_template", "//snowflake/ml/model/_signatures:base_handler", "//snowflake/ml/model/_signatures:builtins_handler", "//snowflake/ml/model/_signatures:core", @@ -48,22 +41,6 @@ py_library( ], ) -py_library( - name = "_api", - srcs = ["_api.py"], - deps = [ - ":deploy_platforms", - ":model_signature", - ":type_hints", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/model/_deploy_client/snowservice:deploy", - "//snowflake/ml/model/_deploy_client/warehouse:deploy", - "//snowflake/ml/model/_deploy_client/warehouse:infer_template", - "//snowflake/ml/model/_model_composer:model_composer", - "//snowflake/ml/model/_signatures:snowpark_handler", - ], -) - py_library( name = "model", srcs = ["__init__.py"], @@ -71,7 +48,6 @@ py_library( "//snowflake/ml/model/_client/model:model_impl", 
"//snowflake/ml/model/_client/model:model_version_impl", "//snowflake/ml/model/models:huggingface_pipeline", - "//snowflake/ml/model/models:llm_model", ], ) @@ -79,9 +55,7 @@ py_test( name = "package_visibility_test", srcs = ["package_visibility_test.py"], deps = [ - ":_api", ":custom_model", - ":deploy_platforms", ":model", ":model_signature", ":type_hints", diff --git a/snowflake/ml/model/__init__.py b/snowflake/ml/model/__init__.py index d14485f5..54ff8122 100644 --- a/snowflake/ml/model/__init__.py +++ b/snowflake/ml/model/__init__.py @@ -1,6 +1,5 @@ from snowflake.ml.model._client.model.model_impl import Model from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel -from snowflake.ml.model.models.llm import LLM, LLMOptions -__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "LLM", "LLMOptions"] +__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"] diff --git a/snowflake/ml/model/_api.py b/snowflake/ml/model/_api.py deleted file mode 100644 index f9f316b4..00000000 --- a/snowflake/ml/model/_api.py +++ /dev/null @@ -1,568 +0,0 @@ -from types import ModuleType -from typing import Any, Dict, List, Literal, Optional, Union, cast, overload - -import pandas as pd -from typing_extensions import deprecated - -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.ml.model import ( - deploy_platforms, - model_signature, - type_hints as model_types, -) -from snowflake.ml.model._deploy_client.snowservice import deploy as snowservice_deploy -from snowflake.ml.model._deploy_client.utils import constants as snowservice_constants -from snowflake.ml.model._deploy_client.warehouse import ( - deploy as warehouse_deploy, - infer_template, -) -from snowflake.ml.model._model_composer import model_composer -from snowflake.ml.model._signatures import 
snowpark_handler -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session, functions as F - - -@deprecated("Only used by PrPr model registry.") -@overload -def save_model( - *, - name: str, - model: model_types.SupportedNoSignatureRequirementsModelType, - session: Session, - stage_path: str, - metadata: Optional[Dict[str, str]] = None, - conda_dependencies: Optional[List[str]] = None, - pip_requirements: Optional[List[str]] = None, - python_version: Optional[str] = None, - ext_modules: Optional[List[ModuleType]] = None, - code_paths: Optional[List[str]] = None, - options: Optional[model_types.ModelSaveOption] = None, -) -> model_composer.ModelComposer: - """Save a model that does not require a signature as model to a stage path. - - Args: - name: Name of the model. - model: Model object. - session: Snowpark connection session. - stage_path: Path to the stage where model will be saved. - metadata: Model metadata. - conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to specify - a dependency. It is a recommended way to specify your dependencies using conda. When channel is not - specified, defaults channel will be used. When deploying to Snowflake Warehouse, defaults channel would be - replaced with the Snowflake Anaconda channel. - pip_requirements: List of PIP package specs. Model will not be able to deploy to the warehouse if there is pip - requirements. - python_version: A string of python version where model is run. Used for user override. If specified as None, - current version would be captured. Defaults to None. - code_paths: Directory of code to import. - ext_modules: External modules that user might want to get pickled with model object. Defaults to None. - options: Model specific kwargs. - """ - ... 
- - -@deprecated("Only used by PrPr model registry.") -@overload -def save_model( - *, - name: str, - model: model_types.SupportedRequireSignatureModelType, - session: Session, - stage_path: str, - signatures: Dict[str, model_signature.ModelSignature], - metadata: Optional[Dict[str, str]] = None, - conda_dependencies: Optional[List[str]] = None, - pip_requirements: Optional[List[str]] = None, - python_version: Optional[str] = None, - ext_modules: Optional[List[ModuleType]] = None, - code_paths: Optional[List[str]] = None, - options: Optional[model_types.ModelSaveOption] = None, -) -> model_composer.ModelComposer: - """Save a model that requires a external signature with user provided signatures as model to a stage path. - - Args: - name: Name of the model. - model: Model object. - session: Snowpark connection session. - stage_path: Path to the stage where model will be saved. - signatures: Model data signatures for inputs and output for every target methods. - metadata: Model metadata. - conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to specify - a dependency. It is a recommended way to specify your dependencies using conda. When channel is not - specified, defaults channel will be used. When deploying to Snowflake Warehouse, defaults channel would be - replaced with the Snowflake Anaconda channel. - pip_requirements: List of PIP package specs. Model will not be able to deploy to the warehouse if there is pip - requirements. - python_version: A string of python version where model is run. Used for user override. If specified as None, - current version would be captured. Defaults to None. - code_paths: Directory of code to import. - ext_modules: External modules that user might want to get pickled with model object. Defaults to None. - options: Model specific kwargs. - """ - ... 
- - -@deprecated("Only used by PrPr model registry.") -@overload -def save_model( - *, - name: str, - model: model_types.SupportedRequireSignatureModelType, - session: Session, - stage_path: str, - sample_input_data: model_types.SupportedDataType, - metadata: Optional[Dict[str, str]] = None, - conda_dependencies: Optional[List[str]] = None, - pip_requirements: Optional[List[str]] = None, - python_version: Optional[str] = None, - ext_modules: Optional[List[ModuleType]] = None, - code_paths: Optional[List[str]] = None, - options: Optional[model_types.ModelSaveOption] = None, -) -> model_composer.ModelComposer: - """Save a model that requires a external signature as model to a stage path with signature inferred from a - sample_input_data. - - Args: - name: Name of the model. - model: Model object. - session: Snowpark connection session. - stage_path: Path to the stage where model will be saved. - sample_input_data: Sample input data to infer the model signatures from. - metadata: Model metadata. - conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to specify - a dependency. It is a recommended way to specify your dependencies using conda. When channel is not - specified, defaults channel will be used. When deploying to Snowflake Warehouse, defaults channel would be - replaced with the Snowflake Anaconda channel. - pip_requirements: List of PIP package specs. Model will not be able to deploy to the warehouse if there is pip - requirements. - python_version: A string of python version where model is run. Used for user override. If specified as None, - current version would be captured. Defaults to None. - code_paths: Directory of code to import. - ext_modules: External modules that user might want to get pickled with model object. Defaults to None. - options: Model specific kwargs. - """ - ... 
- - -@deprecated("Only used by PrPr model registry.") -def save_model( - *, - name: str, - model: model_types.SupportedModelType, - session: Session, - stage_path: str, - signatures: Optional[Dict[str, model_signature.ModelSignature]] = None, - sample_input_data: Optional[model_types.SupportedDataType] = None, - metadata: Optional[Dict[str, str]] = None, - conda_dependencies: Optional[List[str]] = None, - pip_requirements: Optional[List[str]] = None, - python_version: Optional[str] = None, - ext_modules: Optional[List[ModuleType]] = None, - code_paths: Optional[List[str]] = None, - options: Optional[model_types.ModelSaveOption] = None, -) -> model_composer.ModelComposer: - """Save the model. - - Args: - name: Name of the model. - model: Model object. - session: Snowpark connection session. - stage_path: Path to the stage where model will be saved. - signatures: Model data signatures for inputs and output for every target methods. If it is None, - sample_input_data would be used to infer the signatures if it is a local (non-SnowML modeling model). - If not None, sample_input_data should not be specified. Defaults to None. - sample_input_data: Sample input data to infer the model signatures from. If it is None, signatures must be - specified if it is a local (non-SnowML modeling model). If not None, signatures should not be specified. - Defaults to None. - metadata: Model metadata. - conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to specify - a dependency. It is a recommended way to specify your dependencies using conda. When channel is not - specified, defaults channel will be used. When deploying to Snowflake Warehouse, defaults channel would be - replaced with the Snowflake Anaconda channel. - pip_requirements: List of PIP package specs. Model will not be able to deploy to the warehouse if there is pip - requirements. - python_version: A string of python version where model is run. Used for user override. 
If specified as None, - current version would be captured. Defaults to None. - code_paths: Directory of code to import. - ext_modules: External modules that user might want to get pickled with model object. Defaults to None. - options: Model specific kwargs. - - Returns: - Model - """ - if options is None: - options = {} - options["_legacy_save"] = True - - m = model_composer.ModelComposer(session=session, stage_path=stage_path) - m.save( - name=name, - model=model, - signatures=signatures, - sample_input_data=sample_input_data, - metadata=metadata, - conda_dependencies=conda_dependencies, - pip_requirements=pip_requirements, - python_version=python_version, - ext_modules=ext_modules, - code_paths=code_paths, - options=options, - ) - return m - - -@deprecated("Only used by PrPr model registry.") -@overload -def load_model(*, session: Session, stage_path: str) -> model_composer.ModelComposer: - """Load the model into memory from a zip file in the stage. - - Args: - session: Snowflake connection session. - stage_path: Path to the stage where model will be loaded from. - """ - ... - - -@deprecated("Only used by PrPr model registry.") -@overload -def load_model(*, session: Session, stage_path: str, meta_only: Literal[False]) -> model_composer.ModelComposer: - """Load the model into memory from a zip file in the stage. - - Args: - session: Snowflake connection session. - stage_path: Path to the stage where model will be loaded from. - meta_only: Flag to indicate that if only load metadata. - """ - ... - - -@deprecated("Only used by PrPr model registry.") -@overload -def load_model(*, session: Session, stage_path: str, meta_only: Literal[True]) -> model_composer.ModelComposer: - """Load the model into memory from a zip file in the stage with metadata only. - - Args: - session: Snowflake connection session. - stage_path: Path to the stage where model will be loaded from. - meta_only: Flag to indicate that if only load metadata. - """ - ... 
- - -@deprecated("Only used by PrPr model registry.") -def load_model( - *, - session: Session, - stage_path: str, - meta_only: bool = False, -) -> model_composer.ModelComposer: - """Load the model into memory from directory or a zip file in the stage. - - Args: - session: Snowflake connection session. Must be specified when specifying model_stage_file_path. - Exclusive with model_dir_path. - stage_path: Path to the stage where model will be loaded from. - meta_only: Flag to indicate that if only load metadata. - - Returns: - Loaded model. - """ - m = model_composer.ModelComposer(session=session, stage_path=stage_path) - m.legacy_load(meta_only=meta_only) - return m - - -@deprecated("Only used by PrPr model registry.") -@overload -def deploy( - session: Session, - *, - name: str, - platform: deploy_platforms.TargetPlatform, - target_method: Optional[str], - stage_path: str, - options: Optional[model_types.DeployOptions], -) -> Optional[model_types.Deployment]: - """Create a deployment from a model in a zip file in a stage and deploy it to remote platform. - - Args: - session: Snowpark Connection Session. - name: Name of the deployment for the model. - platform: Target platform to deploy the model. - target_method: The name of the target method to be deployed. Can be omitted if there is only 1 target method in - the model. - stage_path: Path to the stage where model will be deployed. - options: Additional options when deploying the model. - Each target platform will have their own specifications of options. - """ - ... - - -@deprecated("Only used by PrPr model registry.") -@overload -def deploy( - session: Session, - *, - model_id: str, - name: str, - platform: deploy_platforms.TargetPlatform, - target_method: Optional[str], - stage_path: str, - deployment_stage_path: str, - options: Optional[model_types.DeployOptions], -) -> Optional[model_types.Deployment]: - """Create a deployment from a model in a local directory and deploy it to remote platform. 
- - Args: - session: Snowpark Connection Session. - model_id: Internal model ID string. - name: Name of the deployment for the model. - platform: Target platform to deploy the model. - target_method: The name of the target method to be deployed. Can be omitted if there is only 1 target method in - the model. - stage_path: Path to the stage where model will be deployed. - deployment_stage_path: Path to stage containing snowpark container service deployment artifacts. - options: Additional options when deploying the model. - Each target platform will have their own specifications of options. - """ - ... - - -@deprecated("Only used by PrPr model registry.") -def deploy( - session: Session, - *, - name: str, - platform: deploy_platforms.TargetPlatform, - stage_path: str, - target_method: Optional[str] = None, - deployment_stage_path: Optional[str] = None, - model_id: Optional[str] = None, - options: Optional[model_types.DeployOptions], -) -> Optional[model_types.Deployment]: - """Create a deployment from a model and deploy it to remote platform. - - Args: - session: Snowpark Connection Session. - model_id: Internal model ID string. - name: Name of the deployment for the model. - platform: Target platform to deploy the model. - target_method: The name of the target method to be deployed. Can be omitted if there is only 1 target method in - the model. - stage_path: Path to the stage where model will be deployed. - deployment_stage_path: Path to stage containing deployment artifacts. - options: Additional options when deploying the model. - Each target platform will have their own specifications of options. - - Raises: - SnowflakeMLException: Raised when target platform is unsupported. - SnowflakeMLException: Raised when target method does not exist in model. - - Returns: - The deployment information. 
- """ - - info = None - - if not options: - options = {} - - m = load_model(session=session, stage_path=stage_path, meta_only=True) - assert m.packager.meta - - if target_method is None: - if len(m.packager.meta.signatures.keys()) == 1: - target_method = list(m.packager.meta.signatures.keys())[0] - else: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - "Only when the model has 1 target methods can target_method be omitted when deploying." - ), - ) - - details: model_types.DeployDetails = {} - if platform == deploy_platforms.TargetPlatform.WAREHOUSE: - warehouse_deploy._deploy_to_warehouse( - session=session, - model_stage_file_path=m.model_stage_path, - model_meta=m.packager.meta, - udf_name=name, - target_method=target_method, - **options, - ) - - elif platform == deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES: - options = cast(model_types.SnowparkContainerServiceDeployOptions, options) - assert model_id, "Require 'model_id' for Snowpark container service deployment" - assert m.model_stage_path, "Require 'model_stage_file_path' for Snowpark container service deployment" - assert deployment_stage_path, "Require 'deployment_stage_path' for Snowpark container service deployment" - if snowservice_constants.COMPUTE_POOL not in options: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - "Missing 'compute_pool' in options field for Snowpark container service deployment" - ), - ) - - details = snowservice_deploy._deploy( - session=session, - model_id=model_id, - model_meta=m.packager.meta, - service_func_name=name, - model_zip_stage_path=m.model_stage_path, - deployment_stage_path=deployment_stage_path, - target_method=target_method, - **options, - ) - - else: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_TYPE, - original_exception=ValueError(f"Unsupported target Platform: 
{platform}"), - ) - signature = m.packager.meta.signatures.get(target_method, None) - if not signature: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError(f"Target method {target_method} does not exist in model."), - ) - info = model_types.Deployment( - name=name, platform=platform, target_method=target_method, signature=signature, options=options, details=details - ) - return info - - -@deprecated("Only used by PrPr model registry.") -@overload -def predict( - session: Session, - *, - deployment: model_types.Deployment, - X: model_types.SupportedLocalDataType, - statement_params: Optional[Dict[str, Any]] = None, -) -> pd.DataFrame: - """Execute batch inference of a model remotely on local data. Can be any supported data type. Return a local - Pandas Dataframe. - - Args: - session: Snowpark Connection Session. - deployment: The deployment info to use for predict. - X: The input data. - statement_params: Statement Parameters for telemetry. - """ - ... - - -@deprecated("Only used by PrPr model registry.") -@overload -def predict( - session: Session, - *, - deployment: model_types.Deployment, - X: SnowparkDataFrame, - statement_params: Optional[Dict[str, Any]] = None, -) -> SnowparkDataFrame: - """Execute batch inference of a model remotely on a Snowpark DataFrame. Return a Snowpark DataFrame. - - Args: - session: Snowpark Connection Session. - deployment: The deployment info to use for predict. - X: The input Snowpark dataframe. - statement_params: Statement Parameters for telemetry. - """ - ... - - -@deprecated("Only used by PrPr model registry.") -def predict( - session: Session, - *, - deployment: model_types.Deployment, - X: Union[model_types.SupportedDataType, SnowparkDataFrame], - statement_params: Optional[Dict[str, Any]] = None, -) -> Union[pd.DataFrame, SnowparkDataFrame]: - """Execute batch inference of a model remotely. - - Args: - session: Snowpark Connection Session. 
- deployment: The deployment info to use for predict. - X: The input dataframe. - statement_params: Statement Parameters for telemetry. - - Returns: - The output dataframe. - """ - - # Get options - INTERMEDIATE_OBJ_NAME = "tmp_result" - sig = deployment["signature"] - identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED - - # Validate and prepare input - if not isinstance(X, SnowparkDataFrame): - keep_order = True - output_with_input_features = False - df = model_signature._convert_and_validate_local_data(X, sig.inputs) - s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( - session, df, keep_order=keep_order, features=sig.inputs - ) - else: - keep_order = False - output_with_input_features = True - identifier_rule = model_signature._validate_snowpark_data(X, sig.inputs) - s_df = X - - if statement_params: - if s_df._statement_params is not None: - s_df._statement_params.update(statement_params) - else: - s_df._statement_params = statement_params # type: ignore[assignment] - - original_cols = s_df.columns - - # Infer and get intermediate result - input_cols = [] - for input_feature in sig.inputs: - literal_col_name = input_feature.name - col_name = identifier_rule.get_identifier_from_feature(input_feature.name) - - input_cols.extend( - [ - F.lit(literal_col_name), - F.col(col_name), - ] - ) - - udf_name = deployment["name"] - output_obj = F.call_udf(udf_name, F.object_construct_keep_null(*input_cols)) - df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj) - - if keep_order: - df_res = df_res.order_by( - F.col(infer_template._KEEP_ORDER_COL_NAME), - ascending=True, - ) - - if not output_with_input_features: - df_res = df_res.drop(*original_cols) - - # Prepare the output - output_cols = [] - output_col_names = [] - for output_feature in sig.outputs: - output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature.name].astype(output_feature.as_snowpark_type())) - 
output_col_names.append(identifier_rule.get_identifier_from_feature(output_feature.name)) - - df_res = df_res.with_columns( - output_col_names, - output_cols, - ).drop(INTERMEDIATE_OBJ_NAME) - - # Get final result - if not isinstance(X, SnowparkDataFrame): - return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=sig.outputs) - else: - return df_res diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index 97cc93e0..bec04914 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -310,12 +310,12 @@ def _get_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]: project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def get_model_objective(self) -> model_types.ModelObjective: + def get_model_task(self) -> model_types.Task: statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - return self._model_ops.get_model_objective( + return self._model_ops.get_model_task( database_name=None, schema_name=None, model_name=self._model_name, @@ -423,6 +423,7 @@ def run( partition_column = sql_identifier.SqlIdentifier(partition_column) functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions + if function_name: req_method_name = sql_identifier.SqlIdentifier(function_name).identifier() find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = ( @@ -625,6 +626,7 @@ def _load_from_lineage_node(session: Session, name: str, version: str) -> "Model "image_repo", "gpu_requests", "num_workers", + "max_batch_rows", ], ) def create_service( @@ -638,6 +640,7 @@ def create_service( max_instances: int = 1, gpu_requests: Optional[str] = None, num_workers: Optional[int] = None, + max_batch_rows: Optional[int] = None, force_rebuild: bool = False, build_external_access_integration: str, ) -> str: @@ -646,22 +649,27 
@@ def create_service( Args: service_name: The name of the service, can be fully qualified. If not fully qualified, the database or schema of the model will be used. - image_build_compute_pool: The name of the compute pool used to build the model inference image. Use + image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses the service compute pool if None. service_compute_pool: The name of the compute pool used to run the inference service. image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database or schema of the model will be used. - ingress_enabled: Whether to enable ingress. - max_instances: The maximum number of inference service instances to run. + ingress_enabled: If true, creates a service endpoint associated with the service. User must have + BIND SERVICE ENDPOINT privilege on the account. + max_instances: The maximum number of inference service instances to run. The same value is set to + MIN_INSTANCES property of the service. gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU if None. - num_workers: The number of workers (replicas of models) to run the inference service. - Auto determined if None. + num_workers: The number of workers to run the inference service for handling requests in parallel within an + instance of the service. By default, it is set to 2*vCPU+1 of the node for CPU based inference and 1 for + GPU based inference. For GPU based inference, please see best practices before playing with this value. + max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32. force_rebuild: Whether to force a model inference image rebuild. - build_external_access_integration: The external access integration for image build. + build_external_access_integration: The external access integration for image build.
This is usually + permitting access to conda & PyPI repositories. Returns: - The service name. + Result information about service creation from server. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, @@ -690,10 +698,71 @@ def create_service( max_instances=max_instances, gpu_requests=gpu_requests, num_workers=num_workers, + max_batch_rows=max_batch_rows, force_rebuild=force_rebuild, build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration), statement_params=statement_params, ) + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def list_services( + self, + ) -> List[str]: + """List all the service names using this model version. + + Returns: + List of service_names: The name of the service, can be fully qualified. If not fully qualified, the database + or schema of the model will be used. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + + return self._model_ops.list_inference_services( + database_name=None, + schema_name=None, + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def delete_service( + self, + service_name: str, + ) -> None: + """Drops the given service. + + Args: + service_name: The name of the service, can be fully qualified. If not fully qualified, the database or + schema of the model will be used. + + Raises: + ValueError: If the service does not exist or operation is not permitted by user or service does not belong + to this model. 
+ """ + if not service_name: + raise ValueError("service_name cannot be empty.") + + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + self._model_ops.delete_service( + database_name=None, + schema_name=None, + model_name=self._model_name, + version_name=self._version_name, + service_name=service_name, + statement_params=statement_params, + ) + lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index e17c588f..c11e2606 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -223,14 +223,14 @@ def test_get_functions(self) -> None: statement_params=mock.ANY, ) - def test_get_model_objective(self) -> None: + def test_get_model_task(self) -> None: with mock.patch.object( self.m_mv._model_ops, - attribute="get_model_objective", - return_value=model_types.ModelObjective.REGRESSION, - ) as mock_get_model_objective: - self.assertEqual(model_types.ModelObjective.REGRESSION, self.m_mv.get_model_objective()) - mock_get_model_objective.assert_called_once_with( + attribute="get_model_task", + return_value=model_types.Task.TABULAR_REGRESSION, + ) as mock_get_model_task: + self.assertEqual(model_types.Task.TABULAR_REGRESSION, self.m_mv.get_model_task()) + mock_get_model_task.assert_called_once_with( database_name=None, schema_name=None, model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -732,6 +732,7 @@ def test_create_service(self) -> None: max_instances=3, gpu_requests="GPU", num_workers=1, + max_batch_rows=1024, force_rebuild=True, build_external_access_integration="EAI", ) @@ -752,6 +753,7 @@ def test_create_service(self) -> None: max_instances=3, gpu_requests="GPU", num_workers=1, + max_batch_rows=1024, force_rebuild=True, 
build_external_access_integration=sql_identifier.SqlIdentifier("EAI"), statement_params=mock.ANY, @@ -766,6 +768,7 @@ def test_create_service_same_pool(self) -> None: max_instances=3, gpu_requests="GPU", num_workers=1, + max_batch_rows=1024, force_rebuild=True, build_external_access_integration="EAI", ) @@ -786,11 +789,42 @@ def test_create_service_same_pool(self) -> None: max_instances=3, gpu_requests="GPU", num_workers=1, + max_batch_rows=1024, force_rebuild=True, build_external_access_integration=sql_identifier.SqlIdentifier("EAI"), statement_params=mock.ANY, ) + def test_list_services(self) -> None: + with mock.patch.object( + self.m_mv._model_ops, attribute="list_inference_services", return_value=["a.b.c", "d.e.f"] + ) as mock_get_functions: + self.assertListEqual(["a.b.c", "d.e.f"], self.m_mv.list_services()) + mock_get_functions.assert_called_once_with( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_delete_service_empty(self) -> None: + with self.assertRaisesRegex(ValueError, "service_name cannot be empty."): + self.m_mv.delete_service("") + + def test_delete_service(self) -> None: + with mock.patch.object(self.m_mv._model_ops, attribute="delete_service") as mock_delete_service: + self.m_mv.delete_service("c") + + mock_delete_service.assert_called_with( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_name="c", + statement_params=mock.ANY, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/ops/BUILD.bazel b/snowflake/ml/model/_client/ops/BUILD.bazel index fef6e76c..3efdb2b9 100644 --- a/snowflake/ml/model/_client/ops/BUILD.bazel +++ b/snowflake/ml/model/_client/ops/BUILD.bazel @@ -11,6 +11,7 @@ py_library( srcs = ["model_ops.py"], 
deps = [ ":metadata_ops", + "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:model_signature", @@ -35,6 +36,7 @@ py_test( srcs = ["model_ops_test.py"], deps = [ ":model_ops", + "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_packager/model_meta", @@ -69,6 +71,8 @@ py_library( name = "service_ops", srcs = ["service_ops.py"], deps = [ + "//snowflake/ml/_internal/utils:service_logger", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model/_client/service:model_deployment_spec", "//snowflake/ml/model/_client/sql:service", diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index a5c0aee8..70b8bbf1 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -1,3 +1,4 @@ +import json import os import pathlib import tempfile @@ -6,6 +7,7 @@ import yaml +from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml._internal.utils import formatting, identifier, sql_identifier from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._client.ops import metadata_ops @@ -512,6 +514,71 @@ def unset_tag( statement_params=statement_params, ) + def list_inference_services( + self, + *, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[str]: + res = self._model_client.show_versions( + database_name=database_name, + schema_name=schema_name, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + col_name = 
self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME + if col_name not in res[0]: + # User need to opt into BCR 2024_08 + raise exceptions.SnowflakeMLException( + error_code=error_codes.OPT_IN_REQUIRED, + original_exception=RuntimeError( + "Please opt in to BCR Bundle 2024_08 (" + "https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle)." + ), + ) + json_array = json.loads(res[0][col_name]) + # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side. + return [str(service) for service in json_array if "MODEL_BUILD_" not in service] + + def delete_service( + self, + *, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + service_name: str, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + services = self.list_inference_services( + database_name=database_name, + schema_name=schema_name, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + db, schema, service_name = sql_identifier.parse_fully_qualified_name(service_name) + fully_qualified_service_name = sql_identifier.get_fully_qualified_name( + db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema() + ) + + for service in services: + if service == fully_qualified_service_name: + self._service_client.drop_service( + database_name=db, + schema_name=schema, + service_name=service_name, + statement_params=statement_params, + ) + return + raise ValueError( + f"Service '{service_name}' does not exist or unauthorized or not associated with this model version." 
+ ) + def get_model_version_manifest( self, *, @@ -538,7 +605,8 @@ def get_model_version_manifest( def _match_model_spec_with_sql_functions( sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str] ) -> Dict[sql_identifier.SqlIdentifier, str]: - res = {} + res: Dict[sql_identifier.SqlIdentifier, str] = {} + for target_method in target_methods: # Here we need to find the SQL function corresponding to the Python function. # If the python function name is `abc`, then SQL function name can be `ABC` or `"abc"`. @@ -574,7 +642,7 @@ def _fetch_model_spec( model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict) return model_spec - def get_model_objective( + def get_model_task( self, *, database_name: Optional[sql_identifier.SqlIdentifier], @@ -582,7 +650,7 @@ def get_model_objective( model_name: sql_identifier.SqlIdentifier, version_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, - ) -> type_hints.ModelObjective: + ) -> type_hints.Task: model_spec = self._fetch_model_spec( database_name=database_name, schema_name=schema_name, @@ -590,8 +658,8 @@ def get_model_objective( version_name=version_name, statement_params=statement_params, ) - model_objective_val = model_spec.get("model_objective", type_hints.ModelObjective.UNKNOWN.value) - return type_hints.ModelObjective(model_objective_val) + task_val = model_spec.get("task", type_hints.Task.UNKNOWN.value) + return type_hints.Task(task_val) def get_functions( self, @@ -633,6 +701,20 @@ def get_functions( function_names_and_types.append((function_name, function_type)) + if not function_names_and_types: + # If function_names_and_types is not populated, there are currently + # no warehouse functions for the model version. In order to do inference + # we must populate the functions so the mapping can be constructed. 
+ model_manifest = self.get_model_version_manifest( + database_name=database_name, + schema_name=schema_name, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + for method in model_manifest["methods"]: + function_names_and_types.append((sql_identifier.SqlIdentifier(method["name"]), method["type"])) + signatures = model_spec["signatures"] function_names = [name for name, _ in function_names_and_types] function_name_mapping = ModelOperator._match_model_spec_with_sql_functions( @@ -799,7 +881,7 @@ def invoke_method( if keep_order: # if it's a partitioned table function, _ID will be null and we won't be able to sort. - if df_res.select("_ID").limit(1).collect()[0][0] is None: + if df_res.select(snowpark_handler._KEEP_ORDER_COL_NAME).limit(1).collect()[0][0] is None: warnings.warn( formatting.unwrap( """ @@ -812,7 +894,7 @@ def invoke_method( ) else: df_res = df_res.sort( - "_ID", + snowpark_handler._KEEP_ORDER_COL_NAME, ascending=True, ) diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index cfb9b0de..a0926f41 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -7,6 +7,7 @@ import yaml from absl.testing import absltest +from snowflake.ml._internal.exceptions import exceptions from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._client.ops import model_ops @@ -459,6 +460,173 @@ def test_unset_tag(self) -> None: statement_params=self.m_statement_params, ) + def test_list_inference_services(self) -> None: + m_list_res = [Row(inference_services='["a.b.c", "d.e.f"]')] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.list_inference_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + 
schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertListEqual(res, ["a.b.c", "d.e.f"]) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_list_inference_services_pre_bcr(self) -> None: + m_list_res = [Row(comment="mycomment")] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + with self.assertRaises(exceptions.SnowflakeMLException) as context: + self.m_ops.list_inference_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertEqual( + str(context.exception), + "RuntimeError('(2104) Please opt in to BCR Bundle 2024_08 " + "(https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle).')", + ) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_list_inference_services_skip_build(self) -> None: + m_list_res = [Row(inference_services='["A.B.MODEL_BUILD_34d35ew", "A.B.SERVICE"]')] + with mock.patch.object( + 
self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.list_inference_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertListEqual(res, ["A.B.SERVICE"]) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_delete_service_non_existent(self) -> None: + m_list_res = [Row(inference_services='["A.B.C", "D.E.F"]')] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions, mock.patch.object( + self.m_session, attribute="get_current_database", return_value="a" + ) as mock_get_database, mock.patch.object( + self.m_session, attribute="get_current_schema", return_value="b" + ) as mock_get_schema: + with self.assertRaisesRegex( + ValueError, "Service 'A' does not exist or unauthorized or not associated with this model version." + ): + self.m_ops.delete_service( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_name="a", + ) + with self.assertRaisesRegex( + ValueError, "Service 'B' does not exist or unauthorized or not associated with this model version." 
+ ): + self.m_ops.delete_service( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_name="a.b", + ) + with self.assertRaisesRegex( + ValueError, "Service 'D' does not exist or unauthorized or not associated with this model version." + ): + self.m_ops.delete_service( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_name="b.c.d", + ) + + mock_show_versions.assert_called_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_get_database.assert_called() + mock_get_schema.assert_called() + + def test_delete_service_exists(self) -> None: + m_list_res = [Row(inference_services='["A.B.C", "D.E.F"]')] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "drop_service" + ) as mock_drop_service, mock.patch.object( + self.m_session, attribute="get_current_database", return_value="a" + ) as mock_get_database, mock.patch.object( + self.m_session, attribute="get_current_schema", return_value="b" + ) as mock_get_schema: + self.m_ops.delete_service( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + 
service_name="c", + ) + self.m_ops.delete_service( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_name="b.c", + ) + self.m_ops.delete_service( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_name="a.b.c", + ) + + mock_show_versions.assert_called_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_get_database.assert_called() + mock_get_schema.assert_called() + mock_drop_service.assert_called_with( + database_name="A", + schema_name="B", + service_name="C", + statement_params=mock.ANY, + ) + def test_create_from_stage_1(self) -> None: mock_composer = mock.MagicMock() mock_composer.stage_path = '@TEMP."test".MODEL/V1' @@ -1280,6 +1448,12 @@ def test_match_model_spec_with_sql_functions(self) -> None: [sql_identifier.SqlIdentifier("ABC")], ["predict"] ) + with self.assertRaises(AssertionError): + self.assertDictEqual( + {}, + model_ops.ModelOperator._match_model_spec_with_sql_functions([], ["predict"]), + ) + self.assertDictEqual( {sql_identifier.SqlIdentifier("PREDICT"): "predict"}, model_ops.ModelOperator._match_model_spec_with_sql_functions( @@ -1349,13 +1523,13 @@ def test_get_functions(self) -> None: ) mock_validate_model_metadata.assert_called_once_with(m_spec) - def test_get_model_objective(self) -> None: + def test_get_model_task(self) -> None: m_spec = { "signatures": { "predict": 
_DUMMY_SIG["predict"].to_dict(), "predict_table": _DUMMY_SIG["predict_table"].to_dict(), }, - "model_objective": "binary_classification", + "task": "TABULAR_BINARY_CLASSIFICATION", } m_show_versions_result = [Row(model_spec=yaml.safe_dump(m_spec))] with mock.patch.object( @@ -1367,7 +1541,7 @@ def test_get_model_objective(self) -> None: "_validate_model_metadata", return_value=cast(model_meta_schema.ModelMetadataDict, m_spec), ) as mock_validate_model_metadata: - res = self.m_ops.get_model_objective( + res = self.m_ops.get_model_task( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -1383,9 +1557,9 @@ def test_get_model_objective(self) -> None: statement_params={**self.m_statement_params, "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True}, ) mock_validate_model_metadata.assert_called_once_with(m_spec) - self.assertEqual(res, type_hints.ModelObjective.BINARY_CLASSIFICATION) + self.assertEqual(res, type_hints.Task.TABULAR_BINARY_CLASSIFICATION) - def test_get_model_objective_empty(self) -> None: + def test_get_model_task_empty(self) -> None: m_spec = { "signatures": { "predict": _DUMMY_SIG["predict"].to_dict(), @@ -1402,7 +1576,7 @@ def test_get_model_objective_empty(self) -> None: "_validate_model_metadata", return_value=cast(model_meta_schema.ModelMetadataDict, m_spec), ) as mock_validate_model_metadata: - res = self.m_ops.get_model_objective( + res = self.m_ops.get_model_task( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -1418,7 +1592,7 @@ def test_get_model_objective_empty(self) -> None: statement_params={**self.m_statement_params, "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True}, ) mock_validate_model_metadata.assert_called_once_with(m_spec) - self.assertEqual(res, type_hints.ModelObjective.UNKNOWN) + 
self.assertEqual(res, type_hints.Task.UNKNOWN) def test_download_files_minimal(self) -> None: m_list_files_res = [ diff --git a/snowflake/ml/model/_client/ops/service_ops.py b/snowflake/ml/model/_client/ops/service_ops.py index 84f80152..50392d0f 100644 --- a/snowflake/ml/model/_client/ops/service_ops.py +++ b/snowflake/ml/model/_client/ops/service_ops.py @@ -2,43 +2,49 @@ import hashlib import logging import pathlib -import queue -import sys +import re import tempfile import threading import time -import uuid from typing import Any, Dict, List, Optional, Tuple, cast +from packaging import version + from snowflake import snowpark from snowflake.ml._internal import file_utils -from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier from snowflake.ml.model._client.service import model_deployment_spec from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql from snowflake.snowpark import exceptions, row, session from snowflake.snowpark._internal import utils as snowpark_utils - -def get_logger(logger_name: str) -> logging.Logger: - logger = logging.getLogger(logger_name) - logger.setLevel(logging.INFO) - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(logging.INFO) - handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s")) - logger.addHandler(handler) - return logger - - -logger = get_logger(__name__) -logger.propagate = False +module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY) +module_logger.propagate = False @dataclasses.dataclass class ServiceLogInfo: - service_name: str + database_name: Optional[sql_identifier.SqlIdentifier] + schema_name: Optional[sql_identifier.SqlIdentifier] + service_name: sql_identifier.SqlIdentifier container_name: str instance_id: str = "0" + def __post_init__(self) -> None: + # service name used in logs for display + self.display_service_name 
= sql_identifier.get_fully_qualified_name( + self.database_name, self.schema_name, self.service_name + ) + + +@dataclasses.dataclass +class ServiceLogMetadata: + service_logger: logging.Logger + service: ServiceLogInfo + service_status: Optional[service_sql.ServiceStatus] + is_model_build_service_done: bool + log_offset: int + class ServiceOperator: """Service operator for container services logic.""" @@ -96,6 +102,7 @@ def create_service( max_instances: int, gpu_requests: Optional[str], num_workers: Optional[int], + max_batch_rows: Optional[int], force_rebuild: bool, build_external_access_integration: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, @@ -129,6 +136,7 @@ def create_service( max_instances=max_instances, gpu=gpu_requests, num_workers=num_workers, + max_batch_rows=max_batch_rows, force_rebuild=force_rebuild, external_access_integration=build_external_access_integration, ) @@ -140,15 +148,13 @@ def create_service( ) # check if the inference service is already running - try: - model_inference_service_status, _ = self._service_client.get_service_status( - service_name=service_name, - include_message=False, - statement_params=statement_params, - ) - model_inference_service_exists = model_inference_service_status == service_sql.ServiceStatus.READY - except exceptions.SnowparkSQLException: - model_inference_service_exists = False + model_inference_service_exists = self._check_if_service_exists( + database_name=service_database_name, + schema_name=service_schema_name, + service_name=service_name, + service_status_list_if_exists=[service_sql.ServiceStatus.READY], + statement_params=statement_params, + ) # deploy the model service query_id, async_job = self._service_client.deploy_model( @@ -157,39 +163,55 @@ def create_service( statement_params=statement_params, ) - # stream service logs in a thread - services = [ - ServiceLogInfo(service_name=self._get_model_build_service_name(query_id), container_name="model-build"), - 
ServiceLogInfo(service_name=service_name, container_name="model-inference"), - ] - exception_queue: queue.Queue = queue.Queue() # type: ignore[type-arg] - log_thread = self._start_service_log_streaming( - async_job, services, model_inference_service_exists, exception_queue, statement_params - ) - log_thread.join() - - try: - # non-blocking check for an exception - exception = exception_queue.get(block=False) - if exception: - raise exception - except queue.Empty: - pass + # TODO(hayu): Remove the version check after Snowflake 8.37.0 release + if snowflake_env.get_current_snowflake_version( + self._session, statement_params=statement_params + ) >= version.parse("8.37.0"): + # stream service logs in a thread + model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id)) + model_build_service = ServiceLogInfo( + database_name=service_database_name, + schema_name=service_schema_name, + service_name=model_build_service_name, + container_name="model-build", + ) + model_inference_service = ServiceLogInfo( + database_name=service_database_name, + schema_name=service_schema_name, + service_name=service_name, + container_name="model-inference", + ) + services = [model_build_service, model_inference_service] + log_thread = self._start_service_log_streaming( + async_job, services, model_inference_service_exists, force_rebuild, statement_params + ) + log_thread.join() + else: + while not async_job.is_done(): + time.sleep(5) - return service_name + res = cast(str, cast(List[row.Row], async_job.result())[0][0]) + module_logger.info(f"Inference service {service_name} deployment complete: {res}") + return res def _start_service_log_streaming( self, async_job: snowpark.AsyncJob, services: List[ServiceLogInfo], model_inference_service_exists: bool, - exception_queue: queue.Queue, # type: ignore[type-arg] + force_rebuild: bool, statement_params: Optional[Dict[str, Any]] = None, ) -> threading.Thread: """Start the service log streaming in a 
separate thread.""" log_thread = threading.Thread( target=self._stream_service_logs, - args=(async_job, services, model_inference_service_exists, exception_queue, statement_params), + args=( + async_job, + services, + model_inference_service_exists, + force_rebuild, + statement_params, + ), ) log_thread.start() return log_thread @@ -199,15 +221,17 @@ def _stream_service_logs( async_job: snowpark.AsyncJob, services: List[ServiceLogInfo], model_inference_service_exists: bool, - exception_queue: queue.Queue, # type: ignore[type-arg] + force_rebuild: bool, statement_params: Optional[Dict[str, Any]] = None, ) -> None: """Stream service logs while the async job is running.""" - def fetch_logs(service_name: str, container_name: str, offset: int) -> Tuple[str, int]: + def fetch_logs(service: ServiceLogInfo, offset: int) -> Tuple[str, int]: service_logs = self._service_client.get_service_logs( - service_name=service_name, - container_name=container_name, + database_name=service.database_name, + schema_name=service.schema_name, + service_name=service.service_name, + container_name=service.container_name, statement_params=statement_params, ) @@ -221,67 +245,121 @@ def fetch_logs(service_name: str, container_name: str, offset: int) -> Tuple[str return new_logs, new_offset - is_model_build_service_done = False - log_offset = 0 + def set_service_log_metadata_to_model_inference( + meta: ServiceLogMetadata, inference_service: ServiceLogInfo, msg: str + ) -> None: + model_inference_service_logger = service_logger.get_logger( # InferenceServiceName-InstanceId + f"{inference_service.display_service_name}-{inference_service.instance_id}", + service_logger.LogColor.BLUE, + ) + model_inference_service_logger.propagate = False + meta.service_logger = model_inference_service_logger + meta.service = inference_service + meta.service_status = None + meta.is_model_build_service_done = True + meta.log_offset = 0 + block_size = 180 + module_logger.info(msg) + module_logger.info("-" * 
block_size) + model_build_service, model_inference_service = services[0], services[1] - service_name, container_name = model_build_service.service_name, model_build_service.container_name - # BuildJobName - service_logger = get_logger(service_name) - service_logger.propagate = False + model_build_service_logger = service_logger.get_logger( # BuildJobName + model_build_service.display_service_name, service_logger.LogColor.GREEN + ) + model_build_service_logger.propagate = False + service_log_meta = ServiceLogMetadata( + service_logger=model_build_service_logger, + service=model_build_service, + service_status=None, + is_model_build_service_done=False, + log_offset=0, + ) while not async_job.is_done(): if model_inference_service_exists: time.sleep(5) continue try: - block_size = 180 + # check if using an existing model build image + if not force_rebuild and not service_log_meta.is_model_build_service_done: + model_build_service_exists = self._check_if_service_exists( + database_name=model_build_service.database_name, + schema_name=model_build_service.schema_name, + service_name=model_build_service.service_name, + statement_params=statement_params, + ) + new_model_inference_service_exists = self._check_if_service_exists( + database_name=model_inference_service.database_name, + schema_name=model_inference_service.schema_name, + service_name=model_inference_service.service_name, + statement_params=statement_params, + ) + if not model_build_service_exists and new_model_inference_service_exists: + set_service_log_metadata_to_model_inference( + service_log_meta, + model_inference_service, + "Model Inference image build is not rebuilding the image and using previously built image.", + ) + continue + service_status, message = self._service_client.get_service_status( - service_name=service_name, include_message=True, statement_params=statement_params + database_name=service_log_meta.service.database_name, + schema_name=service_log_meta.service.schema_name, + 
service_name=service_log_meta.service.service_name, + include_message=True, + statement_params=statement_params, ) - logger.info(f"Inference service {service_name} is {service_status.value}.") + if (service_status != service_sql.ServiceStatus.READY) or ( + service_status != service_log_meta.service_status + ): + service_log_meta.service_status = service_status + module_logger.info( + f"{'Inference' if service_log_meta.is_model_build_service_done else 'Image build'} service " + f"{service_log_meta.service.display_service_name} is " + f"{service_log_meta.service_status.value}." + ) + module_logger.info(f"Service message: {message}") - new_logs, new_offset = fetch_logs(service_name, container_name, log_offset) + new_logs, new_offset = fetch_logs( + service_log_meta.service, + service_log_meta.log_offset, + ) if new_logs: - service_logger.info(new_logs) - log_offset = new_offset + service_log_meta.service_logger.info(new_logs) + service_log_meta.log_offset = new_offset # check if model build service is done - if not is_model_build_service_done: + if not service_log_meta.is_model_build_service_done: service_status, _ = self._service_client.get_service_status( + database_name=model_build_service.database_name, + schema_name=model_build_service.schema_name, service_name=model_build_service.service_name, include_message=False, statement_params=statement_params, ) if service_status == service_sql.ServiceStatus.DONE: - is_model_build_service_done = True - log_offset = 0 - service_name = model_inference_service.service_name - container_name = model_inference_service.container_name - # InferenceServiceName-InstanceId - service_logger = get_logger(f"{service_name}-{model_inference_service.instance_id}") - service_logger.propagate = False - logger.info(f"Model build service {model_build_service.service_name} complete.") - logger.info("-" * block_size) - except ValueError: - logger.warning(f"Unknown service status: {service_status.value}") + 
set_service_log_metadata_to_model_inference( + service_log_meta, + model_inference_service, + f"Image build service {model_build_service.display_service_name} complete.", + ) except Exception as ex: - logger.warning(f"Caught an exception when logging: {repr(ex)}") + pattern = r"002003 \(02000\)" # error code: service does not exist + is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException) + contains_msg = any(msg in str(ex) for msg in ["Pending scheduling", "Waiting to start"]) + matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None + if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)): + module_logger.warning(f"Caught an exception when logging: {repr(ex)}") time.sleep(5) if model_inference_service_exists: - logger.info(f"Inference service {model_inference_service.service_name} is already RUNNING.") + module_logger.info(f"Inference service {model_inference_service.display_service_name} is already RUNNING.") else: - self._finalize_logs(service_logger, services[-1], log_offset, statement_params) - - # catch exceptions from the deploy model execution - try: - res = cast(List[row.Row], async_job.result()) - logger.info(f"Model deployment for inference service {model_inference_service.service_name} complete.") - logger.info(res[0][0]) - except Exception as ex: - exception_queue.put(ex) + self._finalize_logs( + service_log_meta.service_logger, service_log_meta.service, service_log_meta.log_offset, statement_params + ) def _finalize_logs( self, @@ -292,7 +370,10 @@ def _finalize_logs( ) -> None: """Fetch service logs after the async job is done to ensure no logs are missed.""" try: + time.sleep(5) # wait for complete service logs service_logs = self._service_client.get_service_logs( + database_name=service.database_name, + schema_name=service.schema_name, service_name=service.service_name, container_name=service.container_name, statement_params=statement_params, @@ -301,12 +382,40 @@ 
def _finalize_logs( if len(service_logs) > offset: service_logger.info(service_logs[offset:]) except Exception as ex: - logger.warning(f"Caught an exception when logging: {repr(ex)}") + module_logger.warning(f"Caught an exception when logging: {repr(ex)}") @staticmethod def _get_model_build_service_name(query_id: str) -> str: """Get the model build service name through the server-side logic.""" - most_significant_bits = uuid.UUID(query_id).int >> 64 - md5_hash = hashlib.md5(str(most_significant_bits).encode()).hexdigest() - identifier = md5_hash[:6] + uuid = query_id.replace("-", "") + big_int = int(uuid, 16) + md5_hash = hashlib.md5(str(big_int).encode()).hexdigest() + identifier = md5_hash[:8] return ("model_build_" + identifier).upper() + + def _check_if_service_exists( + self, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + service_name: sql_identifier.SqlIdentifier, + service_status_list_if_exists: Optional[List[service_sql.ServiceStatus]] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> bool: + if service_status_list_if_exists is None: + service_status_list_if_exists = [ + service_sql.ServiceStatus.PENDING, + service_sql.ServiceStatus.READY, + service_sql.ServiceStatus.DONE, + service_sql.ServiceStatus.FAILED, + ] + try: + service_status, _ = self._service_client.get_service_status( + database_name=database_name, + schema_name=schema_name, + service_name=service_name, + include_message=False, + statement_params=statement_params, + ) + return any(service_status == status for status in service_status_list_if_exists) + except exceptions.SnowparkSQLException: + return False diff --git a/snowflake/ml/model/_client/ops/service_ops_test.py b/snowflake/ml/model/_client/ops/service_ops_test.py index eee46a37..f6f93897 100644 --- a/snowflake/ml/model/_client/ops/service_ops_test.py +++ b/snowflake/ml/model/_client/ops/service_ops_test.py @@ -1,4 +1,3 @@ -import hashlib import pathlib 
import uuid from typing import cast @@ -11,14 +10,19 @@ from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model._client.ops import service_ops from snowflake.ml.model._client.sql import service as service_sql -from snowflake.ml.test_utils import mock_session -from snowflake.snowpark import Session +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Session, row from snowflake.snowpark._internal import utils as snowpark_utils class ModelOpsTest(absltest.TestCase): def setUp(self) -> None: self.m_session = mock_session.MockSession(conn=None, test_case=self) + # TODO(hayu): Remove mock sql after Snowflake 8.37.0 release + query = "SELECT CURRENT_VERSION() AS CURRENT_VERSION" + sql_result = [row.Row(CURRENT_VERSION="8.37.0 1234567890ab")] + self.m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) + self.m_statement_params = {"test": "1"} self.c_session = cast(Session, self.m_session) self.m_ops = service_ops.ServiceOperator( @@ -48,7 +52,7 @@ def test_create_service(self) -> None: version_name=sql_identifier.SqlIdentifier("VERSION"), service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), - service_name=sql_identifier.SqlIdentifier("SERVICE"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), image_repo_database_name=sql_identifier.SqlIdentifier("IMAGE_REPO_DB"), @@ -58,6 +62,7 @@ def test_create_service(self) -> None: max_instances=1, gpu_requests="1", num_workers=1, + max_batch_rows=1024, force_rebuild=True, build_external_access_integration=sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION"), statement_params=self.m_statement_params, @@ -75,7 +80,7 @@ def test_create_service(self) -> None: 
version_name=sql_identifier.SqlIdentifier("VERSION"), service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), - service_name=sql_identifier.SqlIdentifier("SERVICE"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), image_repo_database_name=sql_identifier.SqlIdentifier("IMAGE_REPO_DB"), @@ -85,6 +90,7 @@ def test_create_service(self) -> None: max_instances=1, gpu="1", num_workers=1, + max_batch_rows=1024, force_rebuild=True, external_access_integration=sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION"), ) @@ -106,18 +112,28 @@ def test_create_service(self) -> None: statement_params=self.m_statement_params, ) mock_get_service_status.assert_called_once_with( - service_name="SERVICE", + database_name="SERVICE_DB", + schema_name="SERVICE_SCHEMA", + service_name="MYSERVICE", include_message=False, statement_params=self.m_statement_params, ) def test_get_model_build_service_name(self) -> None: - query_id = str(uuid.uuid4()) - most_significant_bits = uuid.UUID(query_id).int >> 64 - md5_hash = hashlib.md5(str(most_significant_bits).encode()).hexdigest() - identifier = md5_hash[:6] - service_name = ("model_build_" + identifier).upper() - self.assertEqual(self.m_ops._get_model_build_service_name(query_id), service_name) + query_id = "01b6fc10-0002-c121-0000-6ed10736311e" + """ + Java code to generate the expected value: + import java.math.BigInteger; + import org.apache.commons.codec.digest.DigestUtils; + String uuid = "01b6fc10-0002-c121-0000-6ed10736311e"; + String uuidString = uuid.replace("-", ""); + BigInteger bigInt = new BigInteger(uuidString, 16); + String identifier = DigestUtils.md5Hex(bigInt.toString()).substring(0, 8); + System.out.println(identifier); + """ + identifier = "81edd120" + expected = 
("model_build_" + identifier).upper() + self.assertEqual(self.m_ops._get_model_build_service_name(query_id), expected) if __name__ == "__main__": diff --git a/snowflake/ml/model/_client/service/model_deployment_spec.py b/snowflake/ml/model/_client/service/model_deployment_spec.py index a7946366..5e3d0264 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec.py @@ -38,6 +38,7 @@ def save( max_instances: int, gpu: Optional[str], num_workers: Optional[int], + max_batch_rows: Optional[int], force_rebuild: bool, external_access_integration: sql_identifier.SqlIdentifier, ) -> None: @@ -79,6 +80,9 @@ def save( if num_workers: service_dict["num_workers"] = num_workers + if max_batch_rows: + service_dict["max_batch_rows"] = max_batch_rows + # model deployment spec model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict( models=[model_dict], diff --git a/snowflake/ml/model/_client/service/model_deployment_spec_schema.py b/snowflake/ml/model/_client/service/model_deployment_spec_schema.py index 2dd58ce6..c9f71a6e 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec_schema.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec_schema.py @@ -22,6 +22,7 @@ class ServiceDict(TypedDict): max_instances: Required[int] gpu: NotRequired[str] num_workers: NotRequired[int] + max_batch_rows: NotRequired[int] class ModelDeploymentSpecDict(TypedDict): diff --git a/snowflake/ml/model/_client/service/model_deployment_spec_test.py b/snowflake/ml/model/_client/service/model_deployment_spec_test.py index ae7b7aa1..5f99d074 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec_test.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec_test.py @@ -29,6 +29,7 @@ def test_minimal(self) -> None: max_instances=1, gpu=None, num_workers=None, + max_batch_rows=None, force_rebuild=False, 
external_access_integration=sql_identifier.SqlIdentifier("external_access_integration"), ) @@ -77,6 +78,7 @@ def test_minimal_case_sensitive(self) -> None: max_instances=1, gpu=None, num_workers=None, + max_batch_rows=None, force_rebuild=False, external_access_integration=sql_identifier.SqlIdentifier( "external_access_integration", case_sensitive=True @@ -125,6 +127,7 @@ def test_full(self) -> None: max_instances=10, gpu="1", num_workers=10, + max_batch_rows=1024, force_rebuild=True, external_access_integration=sql_identifier.SqlIdentifier("external_access_integration"), ) @@ -149,6 +152,7 @@ def test_full(self) -> None: "max_instances": 10, "gpu": "1", "num_workers": 10, + "max_batch_rows": 1024, }, }, ) diff --git a/snowflake/ml/model/_client/sql/BUILD.bazel b/snowflake/ml/model/_client/sql/BUILD.bazel index 9eca12d0..db1a5988 100644 --- a/snowflake/ml/model/_client/sql/BUILD.bazel +++ b/snowflake/ml/model/_client/sql/BUILD.bazel @@ -108,6 +108,7 @@ py_library( deps = [ ":_base", "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", ], ) diff --git a/snowflake/ml/model/_client/sql/_base.py b/snowflake/ml/model/_client/sql/_base.py index 9edff9e6..0a599749 100644 --- a/snowflake/ml/model/_client/sql/_base.py +++ b/snowflake/ml/model/_client/sql/_base.py @@ -2,6 +2,7 @@ from snowflake.ml._internal.utils import identifier, sql_identifier from snowflake.snowpark import session +from snowflake.snowpark._internal import utils as snowpark_utils class _BaseSQLClient: @@ -32,3 +33,7 @@ def fully_qualified_object_name( return identifier.get_schema_level_object_identifier( actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier() ) + + @staticmethod + def get_tmp_name_with_prefix(prefix: str) -> str: + return f"{prefix}_{snowpark_utils.generate_random_alphanumeric().upper()}" diff --git a/snowflake/ml/model/_client/sql/model.py 
b/snowflake/ml/model/_client/sql/model.py index 9e4ac7f1..5646adac 100644 --- a/snowflake/ml/model/_client/sql/model.py +++ b/snowflake/ml/model/_client/sql/model.py @@ -15,6 +15,7 @@ class ModelSQLClient(_base._BaseSQLClient): MODEL_VERSION_METADATA_COL_NAME = "metadata" MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec" MODEL_VERSION_ALIASES_COL_NAME = "aliases" + MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services" def show_models( self, diff --git a/snowflake/ml/model/_client/sql/model_version.py b/snowflake/ml/model/_client/sql/model_version.py index 6af568da..91f14bf7 100644 --- a/snowflake/ml/model/_client/sql/model_version.py +++ b/snowflake/ml/model/_client/sql/model_version.py @@ -298,7 +298,9 @@ def invoke_function_method( ) -> dataframe.DataFrame: with_statements = [] if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0: - INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT" + INTERMEDIATE_TABLE_NAME = ModelVersionSQLClient.get_tmp_name_with_prefix( + "SNOWPARK_ML_MODEL_INFERENCE_INPUT" + ) with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})") else: actual_database_name = database_name or self._database_name @@ -316,9 +318,9 @@ def invoke_function_method( statement_params=statement_params, ) - INTERMEDIATE_OBJ_NAME = "TMP_RESULT" + INTERMEDIATE_OBJ_NAME = ModelVersionSQLClient.get_tmp_name_with_prefix("TMP_RESULT") - module_version_alias = "MODEL_VERSION_ALIAS" + module_version_alias = ModelVersionSQLClient.get_tmp_name_with_prefix("MODEL_VERSION_ALIAS") with_statements.append( f"{module_version_alias} AS " f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}" @@ -375,7 +377,9 @@ def invoke_table_function_method( ) -> dataframe.DataFrame: with_statements = [] if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0: - INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT" + INTERMEDIATE_TABLE_NAME 
= ( + f"SNOWPARK_ML_MODEL_INFERENCE_INPUT_{snowpark_utils.generate_random_alphanumeric().upper()}" + ) with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})") else: actual_database_name = database_name or self._database_name @@ -393,7 +397,7 @@ def invoke_table_function_method( statement_params=statement_params, ) - module_version_alias = "MODEL_VERSION_ALIAS" + module_version_alias = f"MODEL_VERSION_ALIAS_{snowpark_utils.generate_random_alphanumeric().upper()}" with_statements.append( f"{module_version_alias} AS " f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}" diff --git a/snowflake/ml/model/_client/sql/model_version_test.py b/snowflake/ml/model/_client/sql/model_version_test.py index 2271eb30..7a7b9b0e 100644 --- a/snowflake/ml/model/_client/sql/model_version_test.py +++ b/snowflake/ml/model/_client/sql/model_version_test.py @@ -341,21 +341,21 @@ def test_invoke_function_method(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( - """WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + """WITH MODEL_VERSION_ALIAS_ABCDEF0123 AS MODEL TEMP."test".MODEL VERSION V1 SELECT *, - MODEL_VERSION_ALIAS!PREDICT(COL1, COL2) AS TMP_RESULT + MODEL_VERSION_ALIAS_ABCDEF0123!PREDICT(COL1, COL2) AS TMP_RESULT_ABCDEF0123 FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123""", m_df, ) - m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT_ABCDEF0123") c_session = cast(Session, self.m_session) mock_writer = mock.MagicMock() m_df.__setattr__("write", mock_writer) m_df.add_query("queries", "query_1") m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( - snowpark_utils, "random_name_for_temp_object", 
return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): model_version_sql.ModelVersionSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -371,7 +371,6 @@ def test_invoke_function_method(self) -> None: returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], statement_params=m_statement_params, ) - mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) mock_save_as_table.assert_called_once_with( table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', mode="errorifexists", @@ -383,21 +382,21 @@ def test_invoke_function_method_fully_qualified(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( - """WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + """WITH MODEL_VERSION_ALIAS_ABCDEF0123 AS MODEL TEMP."test".MODEL VERSION V1 SELECT *, - MODEL_VERSION_ALIAS!PREDICT(COL1, COL2) AS TMP_RESULT + MODEL_VERSION_ALIAS_ABCDEF0123!PREDICT(COL1, COL2) AS TMP_RESULT_ABCDEF0123 FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123""", m_df, ) - m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT_ABCDEF0123") c_session = cast(Session, self.m_session) mock_writer = mock.MagicMock() m_df.__setattr__("write", mock_writer) m_df.add_query("queries", "query_1") m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( - snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): model_version_sql.ModelVersionSQLClient( c_session, 
database_name=sql_identifier.SqlIdentifier("foo"), @@ -413,7 +412,6 @@ def test_invoke_function_method_fully_qualified(self) -> None: returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], statement_params=m_statement_params, ) - mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) mock_save_as_table.assert_called_once_with( table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', mode="errorifexists", @@ -425,21 +423,21 @@ def test_invoke_function_method_1(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( - """WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + """WITH MODEL_VERSION_ALIAS_ABCDEF0123 AS MODEL TEMP."test".MODEL VERSION V1 SELECT *, - MODEL_VERSION_ALIAS!PREDICT(COL1, COL2) AS TMP_RESULT + MODEL_VERSION_ALIAS_ABCDEF0123!PREDICT(COL1, COL2) AS TMP_RESULT_ABCDEF0123 FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123""", m_df, ) - m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT_ABCDEF0123") c_session = cast(Session, self.m_session) mock_writer = mock.MagicMock() m_df.__setattr__("write", mock_writer) m_df.add_query("queries", "query_1") m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( - snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): model_version_sql.ModelVersionSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -455,7 +453,6 @@ def test_invoke_function_method_1(self) -> None: returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], statement_params=m_statement_params, ) - 
mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) mock_save_as_table.assert_called_once_with( table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', mode="errorifexists", @@ -467,40 +464,41 @@ def test_invoke_function_method_2(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( - """WITH SNOWPARK_ML_MODEL_INFERENCE_INPUT AS (query_1), - MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + """WITH SNOWPARK_ML_MODEL_INFERENCE_INPUT_ABCDEF0123 AS (query_1), + MODEL_VERSION_ALIAS_ABCDEF0123 AS MODEL TEMP."test".MODEL VERSION V1 SELECT *, - MODEL_VERSION_ALIAS!PREDICT(COL1, COL2) AS TMP_RESULT - FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT""", + MODEL_VERSION_ALIAS_ABCDEF0123!PREDICT(COL1, COL2) AS TMP_RESULT_ABCDEF0123 + FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT_ABCDEF0123""", m_df, ) - m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT_ABCDEF0123") c_session = cast(Session, self.m_session) m_df.add_query("queries", "query_1") - model_version_sql.ModelVersionSQLClient( - c_session, - database_name=sql_identifier.SqlIdentifier("TEMP"), - schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - ).invoke_function_method( - database_name=None, - schema_name=None, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("V1"), - method_name=sql_identifier.SqlIdentifier("PREDICT"), - input_df=cast(DataFrame, m_df), - input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], - returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], - statement_params=m_statement_params, - ) + with mock.patch.object(snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123"): + model_version_sql.ModelVersionSQLClient( + c_session, 
+ database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).invoke_function_method( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=cast(DataFrame, m_df), + input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], + returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], + statement_params=m_statement_params, + ) def test_invoke_table_function_method_no_partition_col(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( - """WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + """WITH MODEL_VERSION_ALIAS_ABCDEF0123 AS MODEL TEMP."test".MODEL VERSION V1 SELECT *, FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123, - TABLE(MODEL_VERSION_ALIAS!EXPLAIN(COL1, COL2)) + TABLE(MODEL_VERSION_ALIAS_ABCDEF0123!EXPLAIN(COL1, COL2)) """, m_df, ) @@ -511,8 +509,8 @@ def test_invoke_table_function_method_no_partition_col(self) -> None: m_df.add_query("queries", "query_1") m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( - snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): model_version_sql.ModelVersionSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -530,7 +528,6 @@ def test_invoke_table_function_method_no_partition_col(self) -> None: statement_params=m_statement_params, is_partitioned=False, ) - mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) mock_save_as_table.assert_called_once_with( 
table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', mode="errorifexists", @@ -543,10 +540,10 @@ def test_invoke_table_function_method_partition_col(self) -> None: m_df = mock_data_frame.MockDataFrame() partition_column = "partition_col" self.m_session.add_mock_sql( - f"""WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + f"""WITH MODEL_VERSION_ALIAS_ABCDEF0123 AS MODEL TEMP."test".MODEL VERSION V1 SELECT *, FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123, - TABLE(MODEL_VERSION_ALIAS!PREDICT_TABLE(COL1, COL2) OVER (PARTITION BY {partition_column})) + TABLE(MODEL_VERSION_ALIAS_ABCDEF0123!PREDICT_TABLE(COL1, COL2) OVER (PARTITION BY {partition_column})) """, m_df, ) @@ -557,8 +554,8 @@ def test_invoke_table_function_method_partition_col(self) -> None: m_df.add_query("queries", "query_1") m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( - snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): model_version_sql.ModelVersionSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -575,7 +572,6 @@ def test_invoke_table_function_method_partition_col(self) -> None: partition_column=sql_identifier.SqlIdentifier(partition_column), statement_params=m_statement_params, ) - mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) mock_save_as_table.assert_called_once_with( table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', mode="errorifexists", @@ -588,10 +584,10 @@ def test_invoke_table_function_method_partition_col_fully_qualified(self) -> Non m_df = mock_data_frame.MockDataFrame() partition_column = "partition_col" self.m_session.add_mock_sql( - f"""WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + f"""WITH MODEL_VERSION_ALIAS_ABCDEF0123 AS 
MODEL TEMP."test".MODEL VERSION V1 SELECT *, FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123, - TABLE(MODEL_VERSION_ALIAS!PREDICT_TABLE(COL1, COL2) OVER (PARTITION BY {partition_column})) + TABLE(MODEL_VERSION_ALIAS_ABCDEF0123!PREDICT_TABLE(COL1, COL2) OVER (PARTITION BY {partition_column})) """, m_df, ) @@ -602,8 +598,8 @@ def test_invoke_table_function_method_partition_col_fully_qualified(self) -> Non m_df.add_query("queries", "query_1") m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( - snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): model_version_sql.ModelVersionSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("foo"), @@ -620,7 +616,6 @@ def test_invoke_table_function_method_partition_col_fully_qualified(self) -> Non partition_column=sql_identifier.SqlIdentifier(partition_column), statement_params=m_statement_params, ) - mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) mock_save_as_table.assert_called_once_with( table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', mode="errorifexists", diff --git a/snowflake/ml/model/_client/sql/service.py b/snowflake/ml/model/_client/sql/service.py index 063e853b..43293010 100644 --- a/snowflake/ml/model/_client/sql/service.py +++ b/snowflake/ml/model/_client/sql/service.py @@ -3,10 +3,13 @@ import textwrap from typing import Any, Dict, List, Optional, Tuple +from packaging import version + from snowflake import snowpark from snowflake.ml._internal.utils import ( identifier, query_result_checker, + snowflake_env, sql_identifier, ) from snowflake.ml.model._client.sql import _base @@ -92,15 +95,8 @@ def invoke_function_method( actual_database_name = database_name or self._database_name actual_schema_name = schema_name 
or self._schema_name - function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()]) - fully_qualified_function_name = identifier.get_schema_level_object_identifier( - actual_database_name.identifier(), - actual_schema_name.identifier(), - function_name, - ) - if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0: - INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT" + INTERMEDIATE_TABLE_NAME = ServiceSQLClient.get_tmp_name_with_prefix("SNOWPARK_ML_MODEL_INFERENCE_INPUT") with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})") else: tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE) @@ -116,7 +112,7 @@ def invoke_function_method( statement_params=statement_params, ) - INTERMEDIATE_OBJ_NAME = "TMP_RESULT" + INTERMEDIATE_OBJ_NAME = ServiceSQLClient.get_tmp_name_with_prefix("TMP_RESULT") with_sql = f"WITH {','.join(with_statements)}" if with_statements else "" args_sql_list = [] @@ -124,6 +120,22 @@ def invoke_function_method( args_sql_list.append(input_arg_value) args_sql = ", ".join(args_sql_list) + if snowflake_env.get_current_snowflake_version( + self._session, statement_params=statement_params + ) >= version.parse("8.39.0"): + fully_qualified_service_name = self.fully_qualified_object_name( + actual_database_name, actual_schema_name, service_name + ) + fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}" + + else: + function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()]) + fully_qualified_function_name = identifier.get_schema_level_object_identifier( + actual_database_name.identifier(), + actual_schema_name.identifier(), + function_name, + ) + sql = textwrap.dedent( f"""{with_sql} SELECT *, @@ -154,7 +166,9 @@ def invoke_function_method( def get_service_logs( self, *, - service_name: str, + database_name: 
Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + service_name: sql_identifier.SqlIdentifier, instance_id: str = "0", container_name: str, statement_params: Optional[Dict[str, Any]] = None, @@ -163,7 +177,11 @@ def get_service_logs( rows = ( query_result_checker.SqlResultValidator( self._session, - f"CALL {system_func}('{service_name}', '{instance_id}', '{container_name}')", + ( + f"CALL {system_func}(" + f"'{self.fully_qualified_object_name(database_name, schema_name, service_name)}', '{instance_id}', " + f"'{container_name}')" + ), statement_params=statement_params, ) .has_dimensions(expected_rows=1, expected_cols=1) @@ -174,7 +192,9 @@ def get_service_logs( def get_service_status( self, *, - service_name: str, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + service_name: sql_identifier.SqlIdentifier, include_message: bool = False, statement_params: Optional[Dict[str, Any]] = None, ) -> Tuple[ServiceStatus, Optional[str]]: @@ -182,7 +202,7 @@ def get_service_status( rows = ( query_result_checker.SqlResultValidator( self._session, - f"CALL {system_func}('{service_name}')", + f"CALL {system_func}('{self.fully_qualified_object_name(database_name, schema_name, service_name)}')", statement_params=statement_params, ) .has_dimensions(expected_rows=1, expected_cols=1) @@ -194,3 +214,17 @@ def get_service_status( message = metadata["message"] if include_message else None return service_status, message return ServiceStatus.UNKNOWN, None + + def drop_service( + self, + *, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + service_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + query_result_checker.SqlResultValidator( + self._session, + f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}", + 
statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() diff --git a/snowflake/ml/model/_client/sql/service_test.py b/snowflake/ml/model/_client/sql/service_test.py index 7827e108..38806a37 100644 --- a/snowflake/ml/model/_client/sql/service_test.py +++ b/snowflake/ml/model/_client/sql/service_test.py @@ -102,21 +102,24 @@ def test_deploy_model(self) -> None: def test_invoke_function_method(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() + m_df0 = mock_data_frame.MockDataFrame(collect_result=[Row(CURRENT_VERSION="1")]) + self.m_session.add_mock_sql("SELECT CURRENT_VERSION() AS CURRENT_VERSION", m_df0) + self.m_session.add_mock_sql( """SELECT *, - TEMP."test".SERVICE_PREDICT(COL1, COL2) AS TMP_RESULT + TEMP."test".SERVICE_PREDICT(COL1, COL2) AS TMP_RESULT_ABCDEF0123 FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123""", m_df, ) - m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT_ABCDEF0123") c_session = cast(Session, self.m_session) mock_writer = mock.MagicMock() m_df.__setattr__("write", mock_writer) m_df.add_query("queries", "query_1") m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( - snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): service_sql.ServiceSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -131,7 +134,6 @@ def test_invoke_function_method(self) -> None: returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], statement_params=m_statement_params, ) - mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) 
mock_save_as_table.assert_called_once_with( table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', mode="errorifexists", @@ -141,14 +143,17 @@ def test_invoke_function_method(self) -> None: def test_invoke_function_method_1(self) -> None: m_statement_params = {"test": "1"} + m_df0 = mock_data_frame.MockDataFrame(collect_result=[Row(CURRENT_VERSION="1")]) + self.m_session.add_mock_sql("SELECT CURRENT_VERSION() AS CURRENT_VERSION", m_df0) + m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( """SELECT *, - FOO."bar"."service_PREDICT"(COL1, COL2) AS TMP_RESULT + FOO."bar"."service_PREDICT"(COL1, COL2) AS TMP_RESULT_ABCDEF0123 FROM FOO."bar".SNOWPARK_TEMP_TABLE_ABCDEF0123""", m_df, ) - m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT_ABCDEF0123") c_session = cast(Session, self.m_session) mock_writer = mock.MagicMock() m_df.__setattr__("write", mock_writer) @@ -156,7 +161,9 @@ def test_invoke_function_method_1(self) -> None: m_df.add_query("queries", "query_2") with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" - ) as mock_random_name_for_temp_object: + ) as mock_random_name_for_temp_object, mock.patch.object( + snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123" + ): service_sql.ServiceSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -181,31 +188,34 @@ def test_invoke_function_method_1(self) -> None: def test_invoke_function_method_2(self) -> None: m_statement_params = {"test": "1"} + m_df0 = mock_data_frame.MockDataFrame(collect_result=[Row(CURRENT_VERSION="1")]) + self.m_session.add_mock_sql("SELECT CURRENT_VERSION() AS CURRENT_VERSION", m_df0) m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( - """WITH 
SNOWPARK_ML_MODEL_INFERENCE_INPUT AS (query_1) + """WITH SNOWPARK_ML_MODEL_INFERENCE_INPUT_ABCDEF0123 AS (query_1) SELECT *, - TEMP."test".SERVICE_PREDICT(COL1, COL2) AS TMP_RESULT - FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT""", + TEMP."test".SERVICE_PREDICT(COL1, COL2) AS TMP_RESULT_ABCDEF0123 + FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT_ABCDEF0123""", m_df, ) - m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT_ABCDEF0123") c_session = cast(Session, self.m_session) m_df.add_query("queries", "query_1") - service_sql.ServiceSQLClient( - c_session, - database_name=sql_identifier.SqlIdentifier("TEMP"), - schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - ).invoke_function_method( - database_name=None, - schema_name=None, - service_name=sql_identifier.SqlIdentifier("SERVICE"), - method_name=sql_identifier.SqlIdentifier("PREDICT"), - input_df=cast(DataFrame, m_df), - input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], - returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], - statement_params=m_statement_params, - ) + with mock.patch.object(snowpark_utils, "generate_random_alphanumeric", return_value="ABCDEF0123"): + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).invoke_function_method( + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("SERVICE"), + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=cast(DataFrame, m_df), + input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], + returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], + statement_params=m_statement_params, + ) def test_get_service_logs(self) -> None: 
m_statement_params = {"test": "1"} @@ -214,7 +224,7 @@ def test_get_service_logs(self) -> None: m_df = mock_data_frame.MockDataFrame(collect_result=[row(m_res)], collect_statement_params=m_statement_params) self.m_session.add_mock_sql( - """CALL SYSTEM$GET_SERVICE_LOGS('SERVICE', '0', 'model-container')""", + """CALL SYSTEM$GET_SERVICE_LOGS('TEMP."test".MYSERVICE', '0', 'model-container')""", copy.deepcopy(m_df), ) c_session = cast(Session, self.m_session) @@ -224,7 +234,9 @@ def test_get_service_logs(self) -> None: database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), ).get_service_logs( - service_name="SERVICE", + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), instance_id="0", container_name="model-container", statement_params=m_statement_params, @@ -238,7 +250,7 @@ def test_get_service_status_include_message(self) -> None: m_res = (m_service_status, m_message) status_res = ( f'[{{"status":"{m_service_status.value}","message":"{m_message}",' - '"containerName":"model-inference","instanceId":"0","serviceName":"SERVICE",' + '"containerName":"model-inference","instanceId":"0","serviceName":"MYSERVICE",' '"image":"image_url","restartCount":0,"startTime":""}]' ) row = Row("SYSTEM$GET_SERVICE_STATUS") @@ -247,7 +259,7 @@ def test_get_service_status_include_message(self) -> None: ) self.m_session.add_mock_sql( - """CALL SYSTEM$GET_SERVICE_STATUS('SERVICE')""", + """CALL SYSTEM$GET_SERVICE_STATUS('TEMP."test".MYSERVICE')""", copy.deepcopy(m_df), ) c_session = cast(Session, self.m_session) @@ -256,7 +268,9 @@ def test_get_service_status_include_message(self) -> None: database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), ).get_service_status( - service_name="SERVICE", + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), include_message=True, 
statement_params=m_statement_params, ) @@ -269,7 +283,7 @@ def test_get_service_status_exclude_message(self) -> None: m_res = (m_service_status, None) status_res = ( f'[{{"status":"{m_service_status.value}","message":"{m_message}",' - '"containerName":"model-inference","instanceId":"0","serviceName":"SERVICE",' + '"containerName":"model-inference","instanceId":"0","serviceName":"MYSERVICE",' '"image":"image_url","restartCount":0,"startTime":""}]' ) row = Row("SYSTEM$GET_SERVICE_STATUS") @@ -278,7 +292,7 @@ def test_get_service_status_exclude_message(self) -> None: ) self.m_session.add_mock_sql( - """CALL SYSTEM$GET_SERVICE_STATUS('SERVICE')""", + """CALL SYSTEM$GET_SERVICE_STATUS('TEMP."test".MYSERVICE')""", copy.deepcopy(m_df), ) c_session = cast(Session, self.m_session) @@ -287,7 +301,9 @@ def test_get_service_status_exclude_message(self) -> None: database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), ).get_service_status( - service_name="SERVICE", + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), include_message=False, statement_params=m_statement_params, ) @@ -299,7 +315,7 @@ def test_get_service_status_no_status(self) -> None: m_res = (service_sql.ServiceStatus.UNKNOWN, None) status_res = ( f'[{{"status":"","message":"{m_message}","containerName":"model-inference","instanceId":"0",' - '"serviceName":"SERVICE","image":"image_url","restartCount":0,"startTime":""}]' + '"serviceName":"MYSERVICE","image":"image_url","restartCount":0,"startTime":""}]' ) row = Row("SYSTEM$GET_SERVICE_STATUS") m_df = mock_data_frame.MockDataFrame( @@ -307,7 +323,7 @@ def test_get_service_status_no_status(self) -> None: ) self.m_session.add_mock_sql( - """CALL SYSTEM$GET_SERVICE_STATUS('SERVICE')""", + """CALL SYSTEM$GET_SERVICE_STATUS('TEMP."test".MYSERVICE')""", copy.deepcopy(m_df), ) c_session = cast(Session, self.m_session) @@ -316,12 +332,36 @@ def 
test_get_service_status_no_status(self) -> None: database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), ).get_service_status( - service_name="SERVICE", + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), include_message=False, statement_params=m_statement_params, ) self.assertEqual(res, m_res) + def test_drop_service(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Service MYSERVICE successfully dropped.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql( + """DROP SERVICE TEMP."test".MYSERVICE""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).drop_service( + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + statement_params=m_statement_params, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/BUILD.bazel b/snowflake/ml/model/_deploy_client/image_builds/BUILD.bazel deleted file mode 100644 index e1b3e968..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/BUILD.bazel +++ /dev/null @@ -1,100 +0,0 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "base_image_builder", - srcs = ["base_image_builder.py"], -) - -py_library( - name = "client_image_builder", - srcs = ["client_image_builder.py"], - deps = [ - ":base_image_builder", - ":docker_context", - "//snowflake/ml/_internal/container_services/image_registry:credential", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:query_result_checker", - 
"//snowflake/ml/model/_packager/model_meta", - ], -) - -py_library( - name = "server_image_builder", - srcs = ["server_image_builder.py"], - data = [ - "templates/image_build_job_spec_template", - "templates/kaniko_shell_script_template", - ], - deps = [ - ":base_image_builder", - ":docker_context", - "//snowflake/ml/_internal:file_utils", - "//snowflake/ml/_internal/container_services/image_registry:registry_client", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:identifier", - "//snowflake/ml/model/_deploy_client/utils:constants", - "//snowflake/ml/model/_deploy_client/utils:snowservice_client", - ], -) - -py_library( - name = "docker_context", - srcs = ["docker_context.py"], - data = [ - "gunicorn_run.sh", - "templates/dockerfile_template", - ":inference_server", - ], - deps = [ - "//snowflake/ml/_internal:file_utils", - "//snowflake/ml/_internal/utils:identifier", - "//snowflake/ml/model/_deploy_client/utils:constants", - "//snowflake/ml/model/_packager/model_meta", - ], -) - -py_test( - name = "client_image_builder_test", - srcs = ["client_image_builder_test.py"], - deps = [ - ":client_image_builder", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/test_utils:exception_utils", - "//snowflake/ml/test_utils:mock_session", - ], -) - -py_test( - name = "server_image_builder_test", - srcs = ["server_image_builder_test.py"], - data = [ - "test_fixtures/kaniko_shell_script_fixture.sh", - ], - deps = [ - ":server_image_builder", - "//snowflake/ml/test_utils:mock_session", - ], -) - -py_test( - name = "docker_context_test", - srcs = ["docker_context_test.py"], - data = [ - "test_fixtures/dockerfile_test_fixture", - "test_fixtures/dockerfile_test_fixture_with_CUDA", - "test_fixtures/dockerfile_test_fixture_with_model", - ], - deps = [ - ":docker_context", - "//snowflake/ml/model:_api", - ], -) - -filegroup( - name = "inference_server", - srcs = [ - "//snowflake/ml/model/_deploy_client/image_builds/inference_server:main.py", - ], -) 
diff --git a/snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py deleted file mode 100644 index ba213bc2..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +++ /dev/null @@ -1,12 +0,0 @@ -from abc import ABC, abstractmethod - - -class ImageBuilder(ABC): - """ - Abstract class encapsulating image building and upload to model registry. - """ - - @abstractmethod - def build_and_upload_image(self) -> None: - """Builds and uploads an image to the model registry.""" - pass diff --git a/snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py deleted file mode 100644 index 11878bc2..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +++ /dev/null @@ -1,249 +0,0 @@ -import json -import logging -import os -import shutil -import subprocess -import tempfile -import time -from enum import Enum -from typing import List - -from snowflake import snowpark -from snowflake.ml._internal.container_services.image_registry import credential -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.ml.model._deploy_client.image_builds import base_image_builder - -logger = logging.getLogger(__name__) - - -class Platform(Enum): - LINUX_AMD64 = "linux/amd64" - - -class ClientImageBuilder(base_image_builder.ImageBuilder): - """ - Client-side image building and upload to model registry. - - Usage requirements: - Requires prior installation and running of Docker with BuildKit. See installation instructions in - https://docs.docker.com/engine/install/ - """ - - def __init__( - self, - *, - context_dir: str, - full_image_name: str, - image_repo: str, - session: snowpark.Session, - ) -> None: - """Initialization - - Args: - context_dir: Local docker context dir. 
- full_image_name: Full image name consists of image name and image tag. - image_repo: Path to image repository. - session: Snowpark session - """ - self.context_dir = context_dir - self.full_image_name = full_image_name - self.image_repo = image_repo - self.session = session - - def build_and_upload_image(self) -> None: - """Builds and uploads an image to the model registry. - - Raises: - SnowflakeMLException: Occurs when failed to build image or push to image registry. - """ - - def _setup_docker_config(docker_config_dir: str, registry_cred: str) -> None: - """Set up a temporary docker config, which is used for running all docker commands. The format of config - is based on the format that is compatible with docker credential helper: - { - "auths": { - "https://index.docker.io/v1/": { - "auth": "" - } - } - } - - Docker will try to find cli-plugins in the config dir, and then fallback to - /usr/local/lib/docker/cli-plugins OR /usr/local/libexec/docker/cli-plugins - /usr/lib/docker/cli-plugins OR /usr/libexec/docker/cli-plugins - To prevent the case that none of them exists, we copy the cli-plugins in the current config directory, - which is defined in DOCKER_CONFIG and default to $HOME/.docker to our temp config dir. - - Args: - docker_config_dir: Path to docker configuration directory, which stores the temporary session token. - registry_cred: image registry basic auth credential. 
- """ - orig_docker_config_dir = os.getenv("DOCKER_CONFIG", os.path.join(os.path.expanduser("~"), ".docker")) - if os.path.exists(os.path.join(orig_docker_config_dir, "cli-plugins")): - shutil.copytree( - os.path.join(orig_docker_config_dir, "cli-plugins"), - os.path.join(docker_config_dir, "cli-plugins"), - symlinks=True, - ) - content = {"auths": {self.full_image_name: {"auth": registry_cred}}} - config_path = os.path.join(docker_config_dir, "config.json") - with open(config_path, "w", encoding="utf-8") as file: - json.dump(content, file) - - def _cleanup_local_image(docker_config_dir: str) -> None: - try: - image_exist_command = ["docker", "image", "inspect", self.full_image_name] - self._run_docker_commands(image_exist_command) - except Exception: - # Image does not exist, probably due to failed build step - pass - else: - commands = ["docker", "--config", docker_config_dir, "rmi", self.full_image_name] - logger.debug(f"Removing local image: {self.full_image_name}") - self._run_docker_commands(commands) - - self.validate_docker_client_env() - with credential.generate_image_registry_credential( - self.session - ) as registry_cred, tempfile.TemporaryDirectory() as docker_config_dir: - try: - _setup_docker_config(docker_config_dir=docker_config_dir, registry_cred=registry_cred) - start = time.time() - self._build_and_tag(docker_config_dir) - end = time.time() - logger.info(f"Time taken to build the image on the client: {end - start:.2f} seconds") - - except Exception as e: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_DOCKER_ERROR, - original_exception=RuntimeError("Failed to build docker image."), - ) from e - else: - try: - start = time.time() - self._upload(docker_config_dir) - end = time.time() - logger.info(f"Time taken to upload the image to image registry: {end - start:.2f} seconds") - except Exception as e: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_DOCKER_ERROR, - 
original_exception=RuntimeError("Failed to upload docker image to registry."), - ) from e - finally: - _cleanup_local_image(docker_config_dir) - - def validate_docker_client_env(self) -> None: - """Ensure docker client is running and BuildKit is enabled. Note that Buildx always uses BuildKit. - - Ensure docker daemon is running through the "docker info" command on shell. When docker daemon is running, - return code will be 0, else return code will be 1. - - Ensure BuildKit is enabled by checking "docker buildx version". - - Raises: - SnowflakeMLException: Occurs when Docker is not installed or is not running. - - """ - try: - self._run_docker_commands(["docker", "info"]) - except Exception: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.CLIENT_DEPENDENCY_MISSING_ERROR, - original_exception=ConnectionError( - "Failed to initialize Docker client. Please ensure Docker is installed and running." - ), - ) - - try: - self._run_docker_commands(["docker", "buildx", "version"]) - except Exception: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.CLIENT_DEPENDENCY_MISSING_ERROR, - original_exception=ConnectionError( - "Please ensured Docker is installed with BuildKit by following " - "https://docs.docker.com/build/buildkit/#getting-started" - ), - ) - - def _build_and_tag(self, docker_config_dir: str) -> None: - """Constructs the Docker context directory and then builds a Docker image based on that context. - - Args: - docker_config_dir: Path to docker configuration directory, which stores the temporary session token. - """ - self._build_image_from_context(docker_config_dir=docker_config_dir) - - def _run_docker_commands(self, commands: List[str]) -> None: - """Run docker commands in a new child process. - - Args: - commands: List of commands to run. - - Raises: - SnowflakeMLException: Occurs when docker commands failed to execute. 
- """ - proc = subprocess.Popen( - commands, cwd=os.getcwd(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, shell=False - ) - output_lines = [] - - if proc.stdout: - for line in iter(proc.stdout.readline, ""): - output_lines.append(line) - logger.debug(line) - - if proc.wait(): - for line in output_lines: - logger.error(line) - - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_DOCKER_ERROR, - original_exception=RuntimeError(f"Docker command failed: {' '.join(commands)}"), - ) - - def _build_image_from_context(self, docker_config_dir: str, *, platform: Platform = Platform.LINUX_AMD64) -> None: - """Builds a Docker image based on provided context. - - Args: - docker_config_dir: Path to docker configuration directory, which stores the temporary session token. - platform: Target platform for the build output, in the format "os[/arch[/variant]]". - """ - - commands = [ - "docker", - "--config", - docker_config_dir, - "buildx", - "build", - "--platform", - platform.value, - "--tag", - f"{self.full_image_name}", - self.context_dir, - ] - - self._run_docker_commands(commands) - - def _upload(self, docker_config_dir: str) -> None: - """ - Uploads image to the image registry. This process requires a "docker login" followed by a "docker push". Remove - local image at the end of the upload operation to save up local space. Image cache is kept for more performant - built experience at the cost of small storage footprint. - - By default, Docker overwrites the local Docker config file "/.docker/config.json" whenever a docker login - occurs. However, to ensure better isolation between Snowflake-managed Docker credentials and the user's own - Docker credentials, we will not use the default Docker config. Instead, we will write the username and session - token to a temporary file and use "docker --config" so that it only applies to the specific Docker command being - executed, without affecting the user's local Docker setup. 
The credential file will be automatically removed - at the end of upload operation. - - Args: - docker_config_dir: Path to docker configuration directory, which stores the temporary session token. - """ - commands = ["docker", "--config", docker_config_dir, "login", self.full_image_name] - self._run_docker_commands(commands) - - logger.debug(f"Pushing image to image repo {self.full_image_name}") - commands = ["docker", "--config", docker_config_dir, "push", self.full_image_name] - self._run_docker_commands(commands) diff --git a/snowflake/ml/model/_deploy_client/image_builds/client_image_builder_test.py b/snowflake/ml/model/_deploy_client/image_builds/client_image_builder_test.py deleted file mode 100644 index 5dab2c98..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/client_image_builder_test.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import cast - -from absl.testing import absltest -from absl.testing.absltest import mock - -from snowflake import snowpark -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.ml.model._deploy_client.image_builds import client_image_builder -from snowflake.ml.test_utils import exception_utils, mock_session - - -class ClientImageBuilderTestCase(absltest.TestCase): - def setUp(self) -> None: - super().setUp() - self.m_session = cast(snowpark.session.Session, mock_session.MockSession(conn=None, test_case=self)) - self.full_image_name = "mock_full_image_name" - self.image_repo = "mock_image_repo" - self.context_dir = "/tmp/context_dir" - - self.client_image_builder = client_image_builder.ClientImageBuilder( - context_dir=self.context_dir, - full_image_name=self.full_image_name, - image_repo=self.image_repo, - session=self.m_session, - ) - - def test_throw_exception_when_docker_is_not_running(self) -> None: - with mock.patch.object(self.client_image_builder, "_run_docker_commands") as m_run_docker_commands: - m_run_docker_commands.side_effect = 
snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_DOCKER_ERROR, original_exception=ConnectionError() - ) - with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=ConnectionError): - self.client_image_builder.build_and_upload_image() - m_run_docker_commands.assert_called_once_with(["docker", "info"]) - - def test_build(self) -> None: - m_docker_config_dir = "mock_docker_config_dir" - - with mock.patch.object(self.client_image_builder, "_build_image_from_context") as m_build_image_from_context: - self.client_image_builder._build_and_tag(m_docker_config_dir) - m_build_image_from_context.assert_called_once_with(docker_config_dir=m_docker_config_dir) - - def test_build_image_from_context(self) -> None: - with mock.patch.object(self.client_image_builder, "_run_docker_commands") as m_run_docker_commands: - m_run_docker_commands.return_value = None - m_docker_config_dir = "mock_docker_config_dir" - self.client_image_builder._build_image_from_context(docker_config_dir=m_docker_config_dir) - - expected_commands = [ - "docker", - "--config", - m_docker_config_dir, - "buildx", - "build", - "--platform", - "linux/amd64", - "--tag", - self.full_image_name, - self.context_dir, - ] - - m_run_docker_commands.assert_called_once() - actual_commands = m_run_docker_commands.call_args.args[0] - self.assertListEqual(expected_commands, actual_commands) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/docker_context.py b/snowflake/ml/model/_deploy_client/image_builds/docker_context.py deleted file mode 100644 index bb82eddb..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/docker_context.py +++ /dev/null @@ -1,130 +0,0 @@ -import os -import posixpath -import shutil -import string -from typing import Optional - -import importlib_resources - -from snowflake.ml._internal import file_utils -from snowflake.ml._internal.utils import identifier -from 
snowflake.ml.model._deploy_client import image_builds -from snowflake.ml.model._deploy_client.utils import constants -from snowflake.ml.model._packager.model_meta import model_meta -from snowflake.snowpark import FileOperation, Session - - -class DockerContext: - """ - Constructs the Docker context directory required for image building. - """ - - def __init__( - self, - context_dir: str, - model_meta: model_meta.ModelMetadata, - session: Optional[Session] = None, - model_zip_stage_path: Optional[str] = None, - ) -> None: - """Initialization - - Args: - context_dir: Path to context directory. - model_meta: Model Metadata. - session: Snowpark session. - model_zip_stage_path: Path to model zip file on stage. - """ - self.context_dir = context_dir - self.model_meta = model_meta - assert (session is None) == (model_zip_stage_path is None) - self.session = session - self.model_zip_stage_path = model_zip_stage_path - - def build(self) -> None: - """ - Generates and/or moves resources into the Docker context directory.Rename the random model directory name to - constant "model_dir" instead for better readability. - """ - self._generate_inference_code() - self._copy_entrypoint_script_to_docker_context() - self._copy_model_env_dependency_to_docker_context() - self._generate_docker_file() - - def _copy_entrypoint_script_to_docker_context(self) -> None: - """Copy gunicorn_run.sh entrypoint to docker context directory.""" - script_path = importlib_resources.files(image_builds).joinpath(constants.ENTRYPOINT_SCRIPT) - target_path = os.path.join(self.context_dir, constants.ENTRYPOINT_SCRIPT) - - with open(script_path, encoding="utf-8") as source_file, file_utils.open_file(target_path, "w") as target_file: - target_file.write(source_file.read()) - - def _copy_model_env_dependency_to_docker_context(self) -> None: - """ - Convert model dependencies to files from model metadata. 
- """ - self.model_meta.save(self.context_dir) - - def _generate_docker_file(self) -> None: - """ - Generates dockerfile based on dockerfile template. - """ - docker_file_path = os.path.join(self.context_dir, "Dockerfile") - docker_file_template = ( - importlib_resources.files(image_builds).joinpath("templates/dockerfile_template").read_text("utf-8") - ) - - if self.model_zip_stage_path is not None: - norm_stage_path = posixpath.normpath(identifier.remove_prefix(self.model_zip_stage_path, "@")) - assert self.session - fop = FileOperation(self.session) - # The explicit download here is inefficient but a compromise. - # We could in theory reuse the download needed for metadata extraction, but it's hacky and will go away. - # Ideally, the model download should happen as part of the server side image build, - # but it requires have our own image builder since there's need to be logic downloading model - # into the context directory. - get_res_list = fop.get(stage_location=self.model_zip_stage_path, target_directory=self.context_dir) - assert len(get_res_list) == 1, f"Single zip file should be returned, but got {len(get_res_list)} files." 
- local_zip_file_path = os.path.basename(get_res_list[0].file) - copy_model_statement = f"COPY {local_zip_file_path} ./{norm_stage_path}" - extra_env_statement = f"ENV MODEL_ZIP_STAGE_PATH={norm_stage_path}" - else: - copy_model_statement = "" - extra_env_statement = "" - - with open(docker_file_path, "w", encoding="utf-8") as dockerfile: - base_image = "mambaorg/micromamba:1.4.3" - tag = base_image.split(":")[1] - assert tag != constants.LATEST_IMAGE_TAG, ( - "Base image tag should not be 'latest' as it might cause false" "positive image cache hit" - ) - dockerfile_content = string.Template(docker_file_template).safe_substitute( - { - "base_image": "mambaorg/micromamba:1.4.3", - "model_env_folder": constants.MODEL_ENV_FOLDER, - "inference_server_dir": constants.INFERENCE_SERVER_DIR, - "entrypoint_script": constants.ENTRYPOINT_SCRIPT, - # Instead of omitting this ENV var when no CUDA required, we explicitly set it to empty to override - # as no CUDA is detected thus it won't be affected by the existence of CUDA in base image. - # https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-virtual.html - "cuda_override_env": self.model_meta.env.cuda_version if self.model_meta.env.cuda_version else "", - "copy_model_statement": copy_model_statement, - "extra_env_statement": extra_env_statement, - } - ) - dockerfile.write(dockerfile_content) - - def _generate_inference_code(self) -> None: - """ - Generates inference code based on the app template and creates a folder named 'server' to house the inference - server code. 
- """ - with importlib_resources.as_file( - importlib_resources.files(image_builds).joinpath(constants.INFERENCE_SERVER_DIR) - ) as inference_server_folder_path: - destination_folder_path = os.path.join(self.context_dir, constants.INFERENCE_SERVER_DIR) - ignore_patterns = shutil.ignore_patterns("BUILD.bazel", "*test.py", "*.\\.*", "__pycache__") - file_utils.copytree( - inference_server_folder_path, - destination_folder_path, - ignore=ignore_patterns, - ) diff --git a/snowflake/ml/model/_deploy_client/image_builds/docker_context_test.py b/snowflake/ml/model/_deploy_client/image_builds/docker_context_test.py deleted file mode 100644 index acd9b5a3..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/docker_context_test.py +++ /dev/null @@ -1,217 +0,0 @@ -import os -import re -import shutil -import tempfile -from unittest import mock - -import sklearn.base -import sklearn.datasets as datasets -from absl.testing import absltest -from sklearn import neighbors - -from snowflake.ml.model._deploy_client.image_builds import docker_context -from snowflake.ml.model._deploy_client.utils import constants -from snowflake.ml.model._packager import model_packager -from snowflake.snowpark import FileOperation, GetResult, Session - -_IRIS = datasets.load_iris(as_frame=True) -_IRIS_X = _IRIS.data -_IRIS_Y = _IRIS.target - - -def _get_sklearn_model() -> "sklearn.base.BaseEstimator": - knn_model = neighbors.KNeighborsClassifier() - knn_model.fit(_IRIS_X, _IRIS_Y) - return knn_model - - -class DockerContextTest(absltest.TestCase): - def setUp(self) -> None: - self.context_dir = tempfile.mkdtemp() - self.model_dir = tempfile.mkdtemp() - - self.packager = model_packager.ModelPackager(self.model_dir) - self.packager.save( - name="model", - model=_get_sklearn_model(), - sample_input_data=_IRIS_X, - ) - assert self.packager.meta - self.model_meta = self.packager.meta - - self.docker_context = docker_context.DockerContext(self.context_dir, model_meta=self.model_meta) - - def 
tearDown(self) -> None: - shutil.rmtree(self.model_dir) - shutil.rmtree(self.context_dir) - - def test_build_results_in_correct_docker_context_file_structure(self) -> None: - expected_files = [ - "Dockerfile", - constants.INFERENCE_SERVER_DIR, - constants.ENTRYPOINT_SCRIPT, - "runtimes", - "env", - "model.yaml", - ] - self.docker_context.build() - generated_files = os.listdir(self.context_dir) - self.assertCountEqual(expected_files, generated_files) - - actual_inference_files = os.listdir(os.path.join(self.context_dir, constants.INFERENCE_SERVER_DIR)) - self.assertCountEqual(["main.py"], actual_inference_files) - - model_env_dir = os.path.join(self.context_dir, "env") - self.assertTrue(os.path.exists(model_env_dir)) - - def test_docker_file_content(self) -> None: - self.docker_context.build() - dockerfile_path = os.path.join(self.context_dir, "Dockerfile") - dockerfile_fixture_path = os.path.join(os.path.dirname(__file__), "test_fixtures", "dockerfile_test_fixture") - with open(dockerfile_path, encoding="utf-8") as dockerfile, open( - dockerfile_fixture_path, encoding="utf-8" - ) as expected_dockerfile: - actual = dockerfile.read() - expected = expected_dockerfile.read() - - # Define a regular expression pattern to match comment lines - comment_pattern = r"\s*#.*$" - # Remove comments - actual = re.sub(comment_pattern, "", actual, flags=re.MULTILINE) - self.assertEqual(actual, expected, "Generated dockerfile is not aligned with the docker template") - - -class DockerContextTestCuda(absltest.TestCase): - def setUp(self) -> None: - self.context_dir = tempfile.mkdtemp() - self.model_dir = tempfile.mkdtemp() - - self.packager = model_packager.ModelPackager(self.model_dir) - self.packager.save( - name="model", - model=_get_sklearn_model(), - sample_input_data=_IRIS_X, - ) - assert self.packager.meta - self.model_meta = self.packager.meta - - self.model_meta.env.cuda_version = "11.7.1" - - self.docker_context = docker_context.DockerContext(self.context_dir, 
model_meta=self.model_meta) - - def tearDown(self) -> None: - shutil.rmtree(self.model_dir) - shutil.rmtree(self.context_dir) - - def test_build_results_in_correct_docker_context_file_structure(self) -> None: - expected_files = [ - "Dockerfile", - constants.INFERENCE_SERVER_DIR, - constants.ENTRYPOINT_SCRIPT, - "env", - "runtimes", - "model.yaml", - ] - self.docker_context.build() - generated_files = os.listdir(self.context_dir) - self.assertCountEqual(expected_files, generated_files) - - actual_inference_files = os.listdir(os.path.join(self.context_dir, constants.INFERENCE_SERVER_DIR)) - self.assertCountEqual(["main.py"], actual_inference_files) - - model_env_dir = os.path.join(self.context_dir, "env") - self.assertTrue(os.path.exists(model_env_dir)) - - def test_docker_file_content(self) -> None: - self.docker_context.build() - dockerfile_path = os.path.join(self.context_dir, "Dockerfile") - dockerfile_fixture_path = os.path.join( - os.path.dirname(__file__), "test_fixtures", "dockerfile_test_fixture_with_CUDA" - ) - with open(dockerfile_path, encoding="utf-8") as dockerfile, open( - dockerfile_fixture_path, encoding="utf-8" - ) as expected_dockerfile: - actual = dockerfile.read() - expected = expected_dockerfile.read() - - # Define a regular expression pattern to match comment lines - comment_pattern = r"\s*#.*$" - # Remove comments - actual = re.sub(comment_pattern, "", actual, flags=re.MULTILINE) - self.assertEqual(actual, expected, "Generated dockerfile is not aligned with the docker template") - - -class DockerContextTestModelWeights(absltest.TestCase): - def setUp(self) -> None: - self.context_dir = tempfile.mkdtemp() - self.model_dir = tempfile.mkdtemp() - - self.packager = model_packager.ModelPackager(self.model_dir) - self.packager.save( - name="model", - model=_get_sklearn_model(), - sample_input_data=_IRIS_X, - ) - assert self.packager.meta - self.model_meta = self.packager.meta - - self.model_meta.env.cuda_version = "11.7.1" - - self.mock_session = 
absltest.mock.MagicMock(spec=Session) - self.model_zip_stage_path = "@model_repo/model.zip" - - self.docker_context = docker_context.DockerContext( - self.context_dir, - model_meta=self.model_meta, - session=self.mock_session, - model_zip_stage_path=self.model_zip_stage_path, - ) - - def tearDown(self) -> None: - shutil.rmtree(self.model_dir) - shutil.rmtree(self.context_dir) - - def test_build_results_in_correct_docker_context_file_structure(self) -> None: - get_results = [GetResult(file="/tmp/model.zip", size="1", status="yes", message="hi")] - with mock.patch.object(FileOperation, "get", return_value=get_results): - expected_files = [ - "Dockerfile", - constants.INFERENCE_SERVER_DIR, - constants.ENTRYPOINT_SCRIPT, - "env", - "runtimes", - "model.yaml", - ] - self.docker_context.build() - generated_files = os.listdir(self.context_dir) - self.assertCountEqual(expected_files, generated_files) - - actual_inference_files = os.listdir(os.path.join(self.context_dir, constants.INFERENCE_SERVER_DIR)) - self.assertCountEqual(["main.py"], actual_inference_files) - - model_env_dir = os.path.join(self.context_dir, "env") - self.assertTrue(os.path.exists(model_env_dir)) - - def test_docker_file_content(self) -> None: - get_results = [GetResult(file="/tmp/model.zip", size="1", status="yes", message="hi")] - with mock.patch.object(FileOperation, "get", return_value=get_results): - self.docker_context.build() - dockerfile_path = os.path.join(self.context_dir, "Dockerfile") - dockerfile_fixture_path = os.path.join( - os.path.dirname(__file__), "test_fixtures", "dockerfile_test_fixture_with_model" - ) - with open(dockerfile_path, encoding="utf-8") as dockerfile, open( - dockerfile_fixture_path, encoding="utf-8" - ) as expected_dockerfile: - actual = dockerfile.read() - expected = expected_dockerfile.read() - - # Define a regular expression pattern to match comment lines - comment_pattern = r"\s*#.*$" - # Remove comments - actual = re.sub(comment_pattern, "", actual, 
flags=re.MULTILINE) - self.assertEqual(actual, expected, "Generated dockerfile is not aligned with the docker template") - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh b/snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh deleted file mode 100644 index e380c51d..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash -set -eu - -OS=$(uname) - -if [[ ${OS} = "Linux" ]]; then - NUM_CORES=$(nproc) -elif [[ ${OS} = "Darwin" ]]; then - # macOS - NUM_CORES=$(sysctl -n hw.ncpu) -elif [[ ${OS} = "Windows" ]]; then - NUM_CORES=$(wmic cpu get NumberOfCores | grep -Eo '[0-9]+') -else - echo "Unsupported operating system: ${OS}" - exit 1 -fi - -# Check if the "NUM_WORKERS" variable is set by the user -if [[ -n "${NUM_WORKERS-}" && "${NUM_WORKERS}" != "None" ]]; then - # If the user has set the "num_workers" variable, use it to overwrite the default value - FINAL_NUM_WORKERS=${NUM_WORKERS} -else - # Based on the Gunicorn documentation, set the number of workers to number_of_cores * 2 + 1. This assumption is - # based on an ideal scenario where one core is handling two processes simultaneously, while one process is dedicated to - # IO operations and the other process is performing compute tasks. - # However, in case when the model is large, we will run into OOM error as each process will need to load the model - # into memory. In such cases, we require the user to pass in "num_workers" to overwrite the default. - FINAL_NUM_WORKERS=$((NUM_CORES * 2 + 1)) -fi - -echo "Number of CPU cores: $NUM_CORES" -echo "Setting number of workers to $FINAL_NUM_WORKERS" - -# Exclude preload option as it won't work with non-thread-safe model, and no easy way to detect whether model is -# thread-safe or not. Defer the optimization later. 
-exec /opt/conda/bin/gunicorn -w "$FINAL_NUM_WORKERS" -k uvicorn.workers.UvicornWorker -b 0.0.0.0:5000 --timeout 600 inference_server.main:app diff --git a/snowflake/ml/model/_deploy_client/image_builds/inference_server/BUILD.bazel b/snowflake/ml/model/_deploy_client/image_builds/inference_server/BUILD.bazel deleted file mode 100644 index e9fff968..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/inference_server/BUILD.bazel +++ /dev/null @@ -1,46 +0,0 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") - -package(default_visibility = ["//visibility:public"]) - -exports_files([ - "main.py", -]) - -py_library( - name = "main", - srcs = ["main.py"], - deps = [ - "//snowflake/ml/model:_api", - "//snowflake/ml/model:custom_model", - "//snowflake/ml/model:type_hints", - ], -) - -py_test( - name = "main_test", - srcs = ["main_test.py"], - deps = [ - ":main", - "//snowflake/ml/_internal:file_utils", - "//snowflake/ml/model/_packager/model_meta", - ], -) - -py_test( - name = "main_vllm_test", - srcs = ["main_vllm_test.py"], - compatible_with_snowpark = False, - require_gpu = True, - deps = [ - ":main", - "//snowflake/ml/_internal:file_utils", - "//snowflake/ml/model/models:llm_model", - ], -) - -py_test( - name = "gpu_test", - srcs = ["gpu_test.py"], - compatible_with_snowpark = False, - require_gpu = True, -) diff --git a/snowflake/ml/model/_deploy_client/image_builds/inference_server/gpu_test.py b/snowflake/ml/model/_deploy_client/image_builds/inference_server/gpu_test.py deleted file mode 100644 index 4891e97c..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/inference_server/gpu_test.py +++ /dev/null @@ -1,12 +0,0 @@ -from absl.testing import absltest - - -class GPUTest(absltest.TestCase): - def test_gpu(self): - import torch - - self.assertEqual(torch.cuda.is_available(), True) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py 
b/snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py deleted file mode 100644 index 67426462..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +++ /dev/null @@ -1,268 +0,0 @@ -import asyncio -import http -import logging -import os -import sys -import tempfile -import threading -import time -import traceback -import zipfile -from enum import Enum -from typing import Dict, List, Optional, cast - -import pandas as pd -from gunicorn import arbiter -from starlette import applications, concurrency, requests, responses, routing - - -class _ModelLoadingState(Enum): - """ - Enum class to represent various model loading state. - """ - - LOADING = "loading" - SUCCEEDED = "succeeded" - FAILED = "failed" - - -class CustomThread(threading.Thread): - """ - Custom Thread implementation that overrides Thread.run. - - This is necessary because the default Thread implementation suppresses exceptions in child threads. The standard - behavior involves the Thread class catching exceptions and throwing a SystemExit exception, which requires - Thread.join to terminate the process. To address this, we overwrite Thread.run and use os._exit instead. - - We throw specific error code "Arbiter.APP_LOAD_ERROR" such that Gunicorn Arbiter master process will be killed, - which then trigger the container to be marked as failed. This ensures the container becomes ready when all workers - loaded the model successfully. - """ - - def run(self) -> None: - try: - super().run() - # Keep the daemon thread alive to avoid destroy - # state attached to thread when loading the model. 
- while True: - time.sleep(60) - except Exception: - logger.error(traceback.format_exc()) - os._exit(arbiter.Arbiter.APP_LOAD_ERROR) - - -logger = logging.getLogger(__name__) -_LOADED_MODEL = None -_LOADED_META = None -_MODEL_CODE_DIR = "code" -_MODEL_LOADING_STATE = _ModelLoadingState.LOADING -_MODEL_LOADING_EVENT = threading.Event() -_CONCURRENT_REQUESTS_MAX: Optional[int] = None -_CONCURRENT_COUNTER = 0 -_CONCURRENT_COUNTER_LOCK = asyncio.Lock() -TARGET_METHOD = None - - -def _run_setup() -> None: - """Set up logging and load model into memory.""" - # Align the application logger's handler with Gunicorn's to capture logs from all processes. - gunicorn_logger = logging.getLogger("gunicorn.error") - logger.handlers = gunicorn_logger.handlers - logger.setLevel(gunicorn_logger.level) - - logger.info(f"ENV: {os.environ}") - - global _LOADED_MODEL - global _LOADED_META - global _MODEL_LOADING_STATE - global _MODEL_LOADING_EVENT - global _CONCURRENT_REQUESTS_MAX - global TARGET_METHOD - - try: - model_zip_stage_path = os.getenv("MODEL_ZIP_STAGE_PATH") - assert model_zip_stage_path, "Missing environment variable MODEL_ZIP_STAGE_PATH" - - TARGET_METHOD = os.getenv("TARGET_METHOD") - - _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", "1") - _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env) - - with tempfile.TemporaryDirectory() as tmp_dir: - if zipfile.is_zipfile(model_zip_stage_path): - extracted_dir = os.path.join(tmp_dir, "extracted_model_dir") - logger.info(f"Extracting model zip from {model_zip_stage_path} to {extracted_dir}") - with zipfile.ZipFile(model_zip_stage_path, "r") as model_zip: - if len(model_zip.namelist()) > 1: - model_zip.extractall(extracted_dir) - else: - raise RuntimeError(f"No model zip found at stage path: {model_zip_stage_path}") - logger.info(f"Loading model from {extracted_dir} into memory") - - sys.path.insert(0, os.path.join(extracted_dir, _MODEL_CODE_DIR)) - - # TODO (Server-side Model Rollout): - # Keep try 
block only - # SPCS spec will convert all environment variables as strings. - use_gpu = os.environ.get("SNOWML_USE_GPU", "False").lower() == "true" - try: - from snowflake.ml.model._packager import model_packager - - pk = model_packager.ModelPackager(extracted_dir) - pk.load( - as_custom_model=True, - meta_only=False, - options={"use_gpu": use_gpu}, - ) - _LOADED_MODEL = pk.model - _LOADED_META = pk.meta - except ImportError as e: - if e.name and not e.name.startswith("snowflake.ml"): - raise e - # Legacy model support - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - if hasattr(model_api, "_load_model_for_deploy"): - _LOADED_MODEL, _LOADED_META = model_api._load_model_for_deploy(extracted_dir) - else: - _LOADED_MODEL, meta_LOADED_META = model_api._load( - local_dir_path=extracted_dir, - as_custom_model=True, - options={"use_gpu": use_gpu}, - ) - _MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED - logger.info("Successfully loaded model into memory") - _MODEL_LOADING_EVENT.set() - except Exception as e: - _MODEL_LOADING_STATE = _ModelLoadingState.FAILED - raise e - - -async def ready(request: requests.Request) -> responses.JSONResponse: - """Check if the application is ready to serve requests. - - This endpoint is used to determine the readiness of the application to handle incoming requests. It returns an HTTP - 200 status code only when the model has been successfully loaded into memory. If the model has not yet been loaded, - it responds with an HTTP 503 status code, which signals to the readiness probe to continue probing until the - application becomes ready or until the client's timeout is reached. - - Args: - request: - The HTTP request object. - - Returns: - A JSON response with status information: - - HTTP 200 status code and {"status": "ready"} when the model is loaded and the application is ready. - - HTTP 503 status code and {"status": "not ready"} when the model is not yet loaded. 
- - """ - if _MODEL_LOADING_STATE == _ModelLoadingState.SUCCEEDED: - return responses.JSONResponse({"status": "ready"}) - return responses.JSONResponse({"status": "not ready"}, status_code=http.HTTPStatus.SERVICE_UNAVAILABLE) - - -def _do_predict(input_json: Dict[str, List[List[object]]]) -> responses.JSONResponse: - from snowflake.ml.model.model_signature import FeatureSpec - - assert _LOADED_MODEL, "model is not loaded" - assert _LOADED_META, "model metadata is not loaded" - assert TARGET_METHOD, "Missing environment variable TARGET_METHOD" - - try: - features = cast(List[FeatureSpec], _LOADED_META.signatures[TARGET_METHOD].inputs) - dtype_map = {feature.name: feature.as_dtype() for feature in features} - input_cols = [spec.name for spec in features] - output_cols = [spec.name for spec in _LOADED_META.signatures[TARGET_METHOD].outputs] - assert "data" in input_json, "missing data field in the request input" - # The expression x[1:] is used to exclude the index of the data row. - input_data = [x[1] for x in input_json["data"]] - df = pd.json_normalize(input_data).astype(dtype=dtype_map) - x = df[input_cols] - assert len(input_data) != 0 and not all(not row for row in input_data), "empty data" - except Exception as e: - error_message = f"Input data malformed: {str(e)}\n{traceback.format_exc()}" - logger.error(f"Failed request with error: {error_message}") - return responses.JSONResponse({"error": error_message}, status_code=http.HTTPStatus.BAD_REQUEST) - - try: - predictions_df = getattr(_LOADED_MODEL, TARGET_METHOD)(x) - predictions_df.columns = output_cols - # Use _ID to keep the order of prediction result and associated features. 
- _KEEP_ORDER_COL_NAME = "_ID" - if _KEEP_ORDER_COL_NAME in df.columns: - predictions_df[_KEEP_ORDER_COL_NAME] = df[_KEEP_ORDER_COL_NAME] - response = {"data": [[i, row] for i, row in enumerate(predictions_df.to_dict(orient="records"))]} - return responses.JSONResponse(response) - except Exception as e: - error_message = f"Prediction failed: {str(e)}\n{traceback.format_exc()}" - logger.error(f"Failed request with error: {error_message}") - return responses.JSONResponse({"error": error_message}, status_code=http.HTTPStatus.BAD_REQUEST) - - -async def predict(request: requests.Request) -> responses.JSONResponse: - """Endpoint to make predictions based on input data. - - Args: - request: The input data is expected to be in the following JSON format: - { - "data": [ - [0, {'_ID': 0, 'input_feature_0': 0.0, 'input_feature_1': 1.0}], - [1, {'_ID': 1, 'input_feature_0': 2.0, 'input_feature_1': 3.0}], - } - Each row is represented as a list, where the first element denotes the index of the row. - - Returns: - Two possible responses: - For success, return a JSON response - { - "data": [ - [0, {'_ID': 0, 'output': 1}], - [1, {'_ID': 1, 'output': 2}] - ] - }, - The first element of each resulting list denotes the index of the row, and the rest of the elements - represent the prediction results for that row. - For an error, return {"error": error_message, "status_code": http_response_status_code}. 
- """ - _MODEL_LOADING_EVENT.wait() # Ensure model is indeed loaded into memory - - global _CONCURRENT_COUNTER - global _CONCURRENT_COUNTER_LOCK - - input_json = await request.json() - - if _CONCURRENT_REQUESTS_MAX: - async with _CONCURRENT_COUNTER_LOCK: - if _CONCURRENT_COUNTER >= int(_CONCURRENT_REQUESTS_MAX): - return responses.JSONResponse( - {"error": "Too many requests"}, status_code=http.HTTPStatus.TOO_MANY_REQUESTS - ) - - async with _CONCURRENT_COUNTER_LOCK: - _CONCURRENT_COUNTER += 1 - - resp = await concurrency.run_in_threadpool(_do_predict, input_json) - - async with _CONCURRENT_COUNTER_LOCK: - _CONCURRENT_COUNTER -= 1 - - return resp - - -def run_app() -> applications.Starlette: - # TODO[shchen]: SNOW-893654. Before SnowService supports Startup probe, or extends support for Readiness probe - # with configurable failureThreshold, we will have to load the model in a separate thread in order to prevent - # gunicorn worker timeout. - model_loading_worker = CustomThread(target=_run_setup, daemon=True) - model_loading_worker.start() - - routes = [ - routing.Route("/health", endpoint=ready, methods=["GET"]), - routing.Route("/predict", endpoint=predict, methods=["POST"]), - ] - return applications.Starlette(routes=routes) - - -app = run_app() diff --git a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main_test.py b/snowflake/ml/model/_deploy_client/image_builds/inference_server/main_test.py deleted file mode 100644 index 444336c1..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main_test.py +++ /dev/null @@ -1,219 +0,0 @@ -import contextlib -import http -import os - -import pandas as pd -import sklearn.datasets as datasets -import sklearn.neighbors as neighbors -from absl.testing import absltest -from absl.testing.absltest import mock -from starlette import testclient - -from snowflake.ml._internal import file_utils -from snowflake.ml.model import custom_model -from snowflake.ml.model._packager import 
model_packager - - -class MainTest(absltest.TestCase): - """ - This test utilizes TestClient, powered by httpx, to send requests to the Starlette application. - """ - - def setUp(self) -> None: - super().setUp() - self.model_zip_path = self.setup_model() - - def setup_model(self) -> str: - iris = datasets.load_iris(as_frame=True) - x = iris.data - y = iris.target - knn_model = neighbors.KNeighborsClassifier() - knn_model.fit(x, y) - - class TestCustomModel(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame(knn_model.predict(input)) - - model = TestCustomModel(custom_model.ModelContext()) - tmpdir = self.create_tempdir() - tmpdir_for_zip = self.create_tempdir() - zip_full_path = os.path.join(tmpdir_for_zip.full_path, "model.zip") - model_packager.ModelPackager(tmpdir.full_path).save( - name="test_model", - model=model, - sample_input_data=x, - metadata={"author": "halu", "version": "1"}, - ) - file_utils.make_archive(zip_full_path, tmpdir.full_path) - return zip_full_path - - def test_setup_import(self) -> None: - with mock.patch.dict( - os.environ, - { - "TARGET_METHOD": "predict", - "MODEL_ZIP_STAGE_PATH": self.model_zip_path, - }, - ): - with mock.patch.object( - model_packager.ModelPackager, - "load", - side_effect=ImportError("Cannot import transformers", name="transformers"), - ): - from main import _run_setup - - with self.assertRaisesRegex(ImportError, "Cannot import transformers"): - _run_setup() - - @contextlib.contextmanager - def common_helper(self): # type: ignore[no-untyped-def] - with mock.patch.dict( - os.environ, - { - "TARGET_METHOD": "predict", - "MODEL_ZIP_STAGE_PATH": self.model_zip_path, - }, - ): - import main - - client = testclient.TestClient(main.app) - yield main, client - - def test_ready_endpoint_after_model_successfully_loaded(self) -> None: - with 
self.common_helper() as (_, client): - response = client.get("/health") - self.assertEqual(response.status_code, http.HTTPStatus.OK) - self.assertEqual(response.json(), {"status": "ready"}) - - def test_ready_endpoint_during_model_loading(self) -> None: - with self.common_helper() as (main, client): - with mock.patch("main._MODEL_LOADING_STATE", main._ModelLoadingState.LOADING): - response = client.get("/health") - self.assertEqual(response.status_code, http.HTTPStatus.SERVICE_UNAVAILABLE) - self.assertEqual(response.json(), {"status": "not ready"}) - - def test_ready_endpoint_after_model_loading_failed(self) -> None: - with self.common_helper() as (main, client): - with mock.patch("main._MODEL_LOADING_STATE", main._ModelLoadingState.FAILED): - response = client.get("/health") - self.assertEqual(response.status_code, http.HTTPStatus.SERVICE_UNAVAILABLE) - self.assertEqual(response.json(), {"status": "not ready"}) - - def test_predict_endpoint_happy_path(self) -> None: - with self.common_helper() as (_, client): - # Construct data input based on external function data input format - data = { - "data": [ - [ - 0, - { - "_ID": 0, - "sepal length (cm)": 5.1, - "sepal width (cm)": 3.5, - "petal length (cm)": 4.2, - "petal width (cm)": 1.3, - }, - ], - [ - 1, - { - "_ID": 1, - "sepal length (cm)": 4.7, - "sepal width (cm)": 3.2, - "petal length (cm)": 4.1, - "petal width (cm)": 4.2, - }, - ], - ] - } - - response = client.post("/predict", json=data) - self.assertEqual(response.status_code, http.HTTPStatus.OK) - expected_response = { - "data": [[0, {"output_feature_0": 1, "_ID": 0}], [1, {"output_feature_0": 2, "_ID": 1}]] - } - self.assertEqual(response.json(), expected_response) - - def test_predict_endpoint_with_invalid_input(self) -> None: - with self.common_helper() as (_, client): - response = client.post("/predict", json={}) - self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST) - self.assertRegex(response.text, "Input data malformed: missing data 
field in the request input") - - response = client.post("/predict", json={"data": []}) - self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST) - self.assertRegex(response.text, "Input data malformed") - - # Input data with indexes only. - response = client.post("/predict", json={"data": [[0], [1]]}) - self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST) - self.assertRegex(response.text, "Input data malformed") - - response = client.post( - "/predict", - json={ - "foo": [ - [1, 2], - [2, 3], - ] - }, - ) - self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST) - self.assertRegex(response.text, "Input data malformed: missing data field in the request input") - - def test_predict_with_misshaped_data(self) -> None: - with self.common_helper() as (_, client): - data = { - "data": [ - [ - 0, - { - "_ID": 0, - "sepal length (cm)": 5.1, - "sepal width (cm)": 3.5, - "petal length (cm)": 4.2, - }, - ], - [ - 1, - { - "_ID": 1, - "sepal length (cm)": 4.7, - "sepal width (cm)": 3.2, - "petal length (cm)": 4.1, - }, - ], - ] - } - - response = client.post("/predict", json=data) - self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST) - self.assertRegex(response.text, r"Input data malformed: .*dtype mappings argument.*") - - def test_predict_with_incorrect_data_type(self) -> None: - with self.common_helper() as (_, client): - data = { - "data": [ - [ - 0, - { - "_ID": 0, - "sepal length (cm)": "a", - "sepal width (cm)": "b", - "petal length (cm)": "c", - "petal width (cm)": "d", - }, - ] - ] - } - response = client.post("/predict", json=data) - self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST) - self.assertRegex(response.text, "Input data malformed: could not convert string to float") - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main_vllm_test.py 
b/snowflake/ml/model/_deploy_client/image_builds/inference_server/main_vllm_test.py deleted file mode 100644 index 645eaed5..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main_vllm_test.py +++ /dev/null @@ -1,124 +0,0 @@ -import contextlib -import http -import logging -import os -import tempfile -from typing import Any, Dict, List - -from absl.testing import absltest -from absl.testing.absltest import mock -from starlette import testclient - -from snowflake.ml._internal import file_utils -from snowflake.ml.model._packager import model_packager -from snowflake.ml.model.models import llm - -logger = logging.getLogger(__name__) - - -class MainVllmTest(absltest.TestCase): - @classmethod - def setUpClass(cls) -> None: - cls.cache_dir = tempfile.TemporaryDirectory() - cls._original_hf_home = os.getenv("HF_HOME", None) - os.environ["HF_HOME"] = cls.cache_dir.name - - @classmethod - def tearDownClass(cls) -> None: - if cls._original_hf_home: - os.environ["HF_HOME"] = cls._original_hf_home - else: - del os.environ["HF_HOME"] - cls.cache_dir.cleanup() - - def setUp(self) -> None: - super().setUp() - - def setup_lora_model(self) -> str: - import peft - - ft_model = peft.AutoPeftModelForCausalLM.from_pretrained( # type: ignore[attr-defined] - "peft-internal-testing/opt-350m-lora", - device_map="auto", - ) - tmpdir = self.create_tempdir().full_path - ft_model.save_pretrained(tmpdir) - options = llm.LLMOptions( - max_batch_size=100, - ) - model = llm.LLM(tmpdir, options=options) - tmpdir = self.create_tempdir() - tmpdir_for_zip = self.create_tempdir() - zip_full_path = os.path.join(tmpdir_for_zip.full_path, "model.zip") - model_packager.ModelPackager(tmpdir.full_path).save( - name="test_model", - model=model, - metadata={"author": "halu", "version": "1"}, - ) - file_utils.make_archive(zip_full_path, tmpdir.full_path) - return zip_full_path - - def setup_pretrain_model(self) -> str: - options = llm.LLMOptions( - max_batch_size=100, - 
enable_tp=True, - ) - model = llm.LLM("facebook/opt-350m", options=options) - tmpdir = self.create_tempdir() - tmpdir_for_zip = self.create_tempdir() - zip_full_path = os.path.join(tmpdir_for_zip.full_path, "model.zip") - model_packager.ModelPackager(tmpdir.full_path).save( - name="test_model", - model=model, - metadata={"author": "halu", "version": "1"}, - ) - file_utils.make_archive(zip_full_path, tmpdir.full_path) - return zip_full_path - - @contextlib.contextmanager - def common_helper(self, model_zip_path): # type: ignore[no-untyped-def] - with mock.patch.dict( - os.environ, - { - "TARGET_METHOD": "infer", - "MODEL_ZIP_STAGE_PATH": model_zip_path, - }, - ): - import main - - client = testclient.TestClient(main.app) - yield main, client - - def generate_data(self, dfl: List[str]) -> Dict[str, Any]: - res = [] - for i, v in enumerate(dfl): - res.append( - [ - i, - { - "_ID": i, - "input": v, - }, - ] - ) - return {"data": res} - - def test_happy_lora_case(self) -> None: - model_zip_path = self.setup_lora_model() - with self.common_helper(model_zip_path) as (_, client): - prompts = ["1+1=", "2+2="] - data = self.generate_data(prompts) - response = client.post("/predict", json=data) - self.assertEqual(response.status_code, http.HTTPStatus.OK) - - def test_happy_pretrain_case(self) -> None: - model_zip_path = self.setup_pretrain_model() - with self.common_helper(model_zip_path) as (_, client): - prompts = ["1+1=", "2+2="] - data = self.generate_data(prompts) - response = client.post("/predict", json=data) - self.assertEqual(response.status_code, http.HTTPStatus.OK) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py deleted file mode 100644 index 3bc5828c..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +++ /dev/null @@ -1,215 +0,0 @@ -import logging -import os -import 
posixpath -from string import Template -from typing import List - -import importlib_resources - -from snowflake import snowpark -from snowflake.ml._internal import file_utils -from snowflake.ml._internal.container_services.image_registry import ( - registry_client as image_registry_client, -) -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.ml._internal.utils import identifier -from snowflake.ml.model._deploy_client import image_builds -from snowflake.ml.model._deploy_client.image_builds import base_image_builder -from snowflake.ml.model._deploy_client.utils import constants, snowservice_client - -logger = logging.getLogger(__name__) - - -class ServerImageBuilder(base_image_builder.ImageBuilder): - """ - Server-side image building and upload to model registry. - """ - - def __init__( - self, - *, - context_dir: str, - full_image_name: str, - image_repo: str, - session: snowpark.Session, - artifact_stage_location: str, - compute_pool: str, - job_name: str, - external_access_integrations: List[str], - ) -> None: - """Initialization - - Args: - context_dir: Local docker context dir. - full_image_name: Full image name consists of image name and image tag. - image_repo: Path to image repository. - session: Snowpark session - artifact_stage_location: Spec file and future deployment related artifacts will be stored under - {stage}/models/{model_id} - compute_pool: The compute pool used to run docker image build workload. - job_name: job_name to use. - external_access_integrations: EAIs for network connection. 
- """ - self.context_dir = context_dir - self.image_repo = image_repo - self.full_image_name = full_image_name - self.session = session - self.artifact_stage_location = artifact_stage_location - self.compute_pool = compute_pool - self.external_access_integrations = external_access_integrations - self.job_name = job_name - self.client = snowservice_client.SnowServiceClient(session) - - assert artifact_stage_location.startswith( - "@" - ), f"stage path should start with @, actual: {artifact_stage_location}" - - def build_and_upload_image(self) -> None: - """ - Builds and uploads an image to the model registry. - """ - logger.info("Starting server-side image build") - self._build_image_in_remote_job() - - def _build_image_in_remote_job(self) -> None: - context_tarball_stage_location = f"{self.artifact_stage_location}/{constants.CONTEXT}.tar.gz" - spec_stage_location = f"{self.artifact_stage_location}/{constants.IMAGE_BUILD_JOB_SPEC_TEMPLATE}.yaml" - kaniko_shell_script_stage_location = f"{self.artifact_stage_location}/{constants.KANIKO_SHELL_SCRIPT_NAME}" - - self._compress_and_upload_docker_context_tarball(context_tarball_stage_location=context_tarball_stage_location) - - self._construct_and_upload_docker_entrypoint_script( - context_tarball_stage_location=context_tarball_stage_location - ) - - # This is more of a workaround to support non-spcs-registry images. - # TODO[shchen] remove such logic when first-party-image is supported on snowservice registry. - # The regular Kaniko image doesn't include a shell; only the debug image comes with a shell. 
We need a shell - # as we use an sh script to launch Kaniko - kaniko_image = "/".join([self.image_repo.rstrip("/"), constants.KANIKO_IMAGE]) - registry_client = image_registry_client.ImageRegistryClient(self.session, kaniko_image) - if registry_client.image_exists(kaniko_image): - logger.debug(f"Kaniko image already existed at {kaniko_image}, skipping uploading") - else: - # Following Digest is corresponding to v1.16.0-debug tag. Note that we cannot copy from image that contains - # tag as the underlying image blob copying API supports digest only. - registry_client.copy_image( - source_image_with_digest="gcr.io/kaniko-project/executor@sha256:" - "b8c0977f88f24dbd7cbc2ffe5c5f824c410ccd0952a72cc066efc4b6dfbb52b6", - dest_image_with_tag=kaniko_image, - ) - self._construct_and_upload_job_spec( - base_image=kaniko_image, - kaniko_shell_script_stage_location=kaniko_shell_script_stage_location, - ) - self._launch_kaniko_job(spec_stage_location) - - def _construct_and_upload_docker_entrypoint_script(self, context_tarball_stage_location: str) -> None: - """Construct a shell script that invokes logic to uncompress the docker context tarball, then invoke Kaniko - executor to build images and push to image registry; the script will also ensure the docker credential(used to - authenticate to image registry) stays up-to-date when session token refreshes. - - Args: - context_tarball_stage_location: Path context directory stage location. - """ - kaniko_shell_script_template = ( - importlib_resources.files(image_builds) - .joinpath(f"templates/{constants.KANIKO_SHELL_SCRIPT_TEMPLATE}") - .read_text("utf-8") - ) - - kaniko_shell_file = os.path.join(self.context_dir, constants.KANIKO_SHELL_SCRIPT_NAME) - - with file_utils.open_file(kaniko_shell_file, "w+") as script_file: - normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@")) - params = { - # Remove @ in the beginning, append "/" to denote root directory. 
- "tar_from": "/" + posixpath.normpath(identifier.remove_prefix(context_tarball_stage_location, "@")), - # Remove @ in the beginning, append "/" to denote root directory. - "tar_to": "/" + normed_artifact_stage_path, - "context_dir": f"dir:///{normed_artifact_stage_path}/{constants.CONTEXT}", - "image_repo": self.image_repo, - # All models will be sharing the same layer cache from the image_repo/cache directory. - "cache_repo": f"{self.image_repo.rstrip('/')}/cache", - "image_destination": self.full_image_name, - } - template = Template(kaniko_shell_script_template) - script = template.safe_substitute(params) - script_file.write(script) - logger.debug(f"script content: \n\n {script}") - self.session.file.put( - local_file_name=kaniko_shell_file, - stage_location=self.artifact_stage_location, - auto_compress=False, - overwrite=True, - ) - - def _compress_and_upload_docker_context_tarball(self, context_tarball_stage_location: str) -> None: - try: - with file_utils._create_tar_gz_stream( - source_dir=self.context_dir, arcname=constants.CONTEXT - ) as input_stream: - self.session.file.put_stream( - input_stream=input_stream, - stage_location=context_tarball_stage_location, - auto_compress=False, - overwrite=True, - ) - except Exception as e: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWPARK_ERROR, - original_exception=RuntimeError( - "Exception occurred when compressing docker context dir as tarball and upload to stage." 
- ), - ) from e - - def _construct_and_upload_job_spec(self, base_image: str, kaniko_shell_script_stage_location: str) -> None: - assert kaniko_shell_script_stage_location.startswith( - "@" - ), f"stage path should start with @, actual: {kaniko_shell_script_stage_location}" - - spec_template = ( - importlib_resources.files(image_builds) - .joinpath(f"templates/{constants.IMAGE_BUILD_JOB_SPEC_TEMPLATE}") - .read_text("utf-8") - ) - - spec_file_path = os.path.join(self.context_dir, f"{constants.IMAGE_BUILD_JOB_SPEC_TEMPLATE}.yaml") - - with file_utils.open_file(spec_file_path, "w+") as spec_file: - assert self.artifact_stage_location.startswith("@") - normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@")) - (db, schema, stage, path) = identifier.parse_snowflake_stage_path(normed_artifact_stage_path) - content = Template(spec_template).safe_substitute( - { - "base_image": base_image, - "container_name": constants.KANIKO_CONTAINER_NAME, - "stage": identifier.get_schema_level_object_identifier(db, schema, stage), - # Remove @ in the beginning, append "/" to denote root directory. 
- "script_path": "/" - + posixpath.normpath(identifier.remove_prefix(kaniko_shell_script_stage_location, "@")), - "mounted_token_path": constants.SPCS_MOUNTED_TOKEN_PATH, - } - ) - spec_file.write(content) - spec_file.seek(0) - logger.debug(f"Kaniko job spec file: \n\n {spec_file.read()}") - - self.session.file.put( - local_file_name=spec_file_path, - stage_location=self.artifact_stage_location, - auto_compress=False, - overwrite=True, - ) - - def _launch_kaniko_job(self, spec_stage_location: str) -> None: - logger.debug(f"Submitting job {self.job_name} for building docker image with kaniko") - self.client.create_job( - job_name=self.job_name, - compute_pool=self.compute_pool, - spec_stage_location=spec_stage_location, - external_access_integrations=self.external_access_integrations, - ) diff --git a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder_test.py b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder_test.py deleted file mode 100644 index 68270e98..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder_test.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -import tempfile - -from absl.testing import absltest -from absl.testing.absltest import mock - -from snowflake.ml.model._deploy_client.image_builds import server_image_builder -from snowflake.ml.model._deploy_client.utils import constants - - -class ServerImageBuilderTestCase(absltest.TestCase): - def setUp(self) -> None: - super().setUp() - self.image_repo = "mock_image_repo" - self.artifact_stage_location = "@stage/models/id" - self.compute_pool = "test_pool" - self.job_name = "abcd" - self.context_tarball_stage_location = f"{self.artifact_stage_location}/context.tar.gz" - self.full_image_name = "org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest" - self.eais = ["eai_1"] - - @mock.patch( # type: ignore[misc] - "snowflake.ml.model._deploy_client.image_builds.server_image_builder.snowpark.Session" - ) - def 
test_construct_and_upload_docker_entrypoint_script(self, m_session_class: mock.MagicMock) -> None: - m_session = m_session_class.return_value - mock_file_put = mock.MagicMock() - m_session.file.put = mock_file_put - - with tempfile.TemporaryDirectory() as context_dir: - builder = server_image_builder.ServerImageBuilder( - context_dir=context_dir, - full_image_name=self.full_image_name, - image_repo=self.image_repo, - session=m_session, - artifact_stage_location=self.artifact_stage_location, - compute_pool=self.compute_pool, - job_name=self.job_name, - external_access_integrations=self.eais, - ) - - shell_file_path = os.path.join(context_dir, constants.KANIKO_SHELL_SCRIPT_NAME) - fixture_path = os.path.join(os.path.dirname(__file__), "test_fixtures", "kaniko_shell_script_fixture.sh") - builder._construct_and_upload_docker_entrypoint_script( - context_tarball_stage_location=self.context_tarball_stage_location - ) - m_session.file.put.assert_called_once_with( - local_file_name=shell_file_path, - stage_location=self.artifact_stage_location, - auto_compress=False, - overwrite=True, - ) - - with open(shell_file_path, encoding="utf-8") as shell_file, open(fixture_path, encoding="utf-8") as fixture: - actual = shell_file.read() - expected = fixture.read() - self.assertEqual(actual, expected, "Generated image build shell script is not the same") - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template b/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template deleted file mode 100644 index 7d6494ec..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +++ /dev/null @@ -1,53 +0,0 @@ -# Note that base image tag should not be 'latest' as it might cause false positive image cache hit. 
-FROM ${base_image} as build - -COPY ${model_env_folder}/conda.yml conda.yml -COPY ${model_env_folder}/requirements.txt requirements.txt -COPY ${inference_server_dir} ./${inference_server_dir} -COPY ${entrypoint_script} ./${entrypoint_script} - -USER root -RUN if id mambauser >/dev/null 2>&1; then \ - echo "mambauser already exists."; \ - else \ - # Set environment variables - export USER=mambauser && \ - export UID=1000 && \ - export HOME=/home/$USER && \ - echo "Creating $USER user..." && \ - adduser --disabled-password \ - --gecos "A non-root user for running inference server" \ - --uid $UID \ - --home $HOME \ - $USER; \ - fi - -RUN chmod +rx conda.yml -RUN chmod +rx requirements.txt -RUN chmod +x ./${entrypoint_script} - -USER mambauser - -# Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time. -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -ARG MAMBA_NO_LOW_SPEED_LIMIT=1 - -# Bitsandbytes uses this ENVVAR to determine CUDA library location -ENV CONDA_PREFIX=/opt/conda - -# The micromamba image comes with an empty environment named base. -# CONDA_OVERRIDE_CUDA ref https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-virtual.html -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="${cuda_override_env}" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy - -${copy_model_statement} - -${extra_env_statement} - -# Expose the port on which the Starlette app will run. 
-EXPOSE 5000 - -CMD ["./${entrypoint_script}"] diff --git a/snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template b/snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template deleted file mode 100644 index ff7c28b4..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +++ /dev/null @@ -1,38 +0,0 @@ -spec: - container: - - name: "${container_name}" - image: "${base_image}" - command: - - sh - args: - - -c - - | - wait_for_file() { - file_path="$1" - timeout="$2" - elapsed_time=0 - while [ ! -f "${file_path}" ]; do - if [ "${elapsed_time}" -ge "${timeout}" ]; then - echo "Error: ${file_path} not found within ${timeout} seconds. Exiting." - exit 1 - fi - elapsed_time=$((elapsed_time + 1)) - remaining_time=$((timeout - elapsed_time)) - echo "Awaiting the mounting of ${file_path}. Wait time remaining: ${remaining_time} seconds" - sleep 1 - done - } - wait_for_file "${script_path}" 300 - wait_for_file "${mounted_token_path}" 300 - chmod +x "${script_path}" - sh "${script_path}" - volumeMounts: - - name: vol1 - mountPath: /local/user/vol1 - - name: stagemount - mountPath: "/${stage}" - volume: - - name: vol1 - source: local # only local emptyDir volume is supported - - name: stagemount - source: "@${stage}" diff --git a/snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template b/snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template deleted file mode 100644 index 70a4e5fe..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +++ /dev/null @@ -1,105 +0,0 @@ -#!/bin/sh - -# Set the file path to monitor -REGISTRY_CRED_PATH="/kaniko/.docker/config.json" -SESSION_TOKEN_PATH="/snowflake/session/token" - -# Function to gracefully terminate the file monitoring job -cleanup() { - echo "Stopping file monitoring job..." 
- trap - INT TERM # Remove the signal handlers - kill -- -$$$ # Kill the entire process group. Extra $ to escape, the generated shell script should have two $. -} - -# SNOW-990976, This is an additional safety check to ensure token file exists, on top of the token file check upon -# launching SPCS job. This additional check could provide value in cases things go wrong with token refresh that result -# in token file to disappear. -wait_till_token_file_exists() { - timeout=60 # 1 minute timeout - elapsed_time=0 - - while [ ! -f "${SESSION_TOKEN_PATH}" ] && [ "$elapsed_time" -lt "$timeout" ]; do - sleep 1 - elapsed_time=$((elapsed_time + 1)) - remaining_time=$((timeout - elapsed_time)) - echo "Waiting for token file to exist. Wait time remaining: ${remaining_time} seconds." - done - - if [ ! -f "${SESSION_TOKEN_PATH}" ]; then - echo "Error: Token file '${SESSION_TOKEN_PATH}' does not show up within the ${timeout} seconds timeout period." - exit 1 - fi -} - -generate_registry_cred() { - wait_till_token_file_exists - AUTH_TOKEN=$(printf '0auth2accesstoken:%s' "$(cat ${SESSION_TOKEN_PATH})" | base64); - echo '{"auths":{"$image_repo":{"auth":"'"$AUTH_TOKEN"'"}}}' | tr -d '\n' > $REGISTRY_CRED_PATH; -} - -on_session_token_change() { - wait_till_token_file_exists - # Get the initial checksum of the file - CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') - # Run the command once before the loop - echo "Monitoring session token changes in the background..." - ( - while true; do - wait_till_token_file_exists - # Get the current checksum of the file - CURRENT_CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') - if [ "${CURRENT_CHECKSUM}" != "${CHECKSUM}" ]; then - # Session token file has changed, regenerate registry credential. - echo "Session token has changed. Regenerating registry auth credentials." 
- generate_registry_cred - CHECKSUM="${CURRENT_CHECKSUM}" - fi - # Wait for a short period of time before checking again - sleep 1 - done - ) -} - -run_kaniko() { - # Run the Kaniko command in the foreground - echo "Starting Kaniko command..." - - # Set cache ttl to a large value as snowservice registry doesn't support deleting cache anyway. - # Compression level set to 1 for fastest compression/decompression speed at the cost of compression ration. - /kaniko/executor \ - --dockerfile Dockerfile \ - --context ${context_dir} \ - --destination=${image_destination} \ - --cache=true \ - --compressed-caching=false \ - --cache-copy-layers=false \ - --use-new-run \ - --snapshot-mode=redo \ - --cache-repo=${cache_repo} \ - --cache-run-layers=true \ - --cache-ttl=8760h \ - --push-retry=3 \ - --image-fs-extract-retry=5 \ - --compression=zstd \ - --compression-level=1 \ - --log-timestamp -} - -setup() { - tar -C "${tar_to}" -xf "${tar_from}"; - generate_registry_cred - # Set up the signal handlers - trap cleanup TERM -} - -setup - -# Running kaniko job on the foreground and session token monitoring on the background. When session token changes, -# overwrite the existing registry cred file with the new session token. -on_session_token_change & -run_kaniko - -# Capture the exit code from the previous kaniko command. -KANIKO_EXIT_CODE=$? -# Exit with the same exit code as the Kaniko command. This then triggers the cleanup function. 
-exit $KANIKO_EXIT_CODE diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture deleted file mode 100644 index 1ae0ec8e..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture +++ /dev/null @@ -1,39 +0,0 @@ - -FROM mambaorg/micromamba:1.4.3 as build - -COPY env/conda.yml conda.yml -COPY env/requirements.txt requirements.txt -COPY inference_server ./inference_server -COPY gunicorn_run.sh ./gunicorn_run.sh - -USER root -RUN if id mambauser >/dev/null 2>&1; then \ - echo "mambauser already exists."; \ - else \ - export USER=mambauser && \ - export UID=1000 && \ - export HOME=/home/$USER && \ - echo "Creating $USER user..." && \ - adduser --disabled-password \ - --gecos "A non-root user for running inference server" \ - --uid $UID \ - --home $HOME \ - $USER; \ - fi - -RUN chmod +rx conda.yml -RUN chmod +rx requirements.txt -RUN chmod +x ./gunicorn_run.sh - -USER mambauser -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -ARG MAMBA_NO_LOW_SPEED_LIMIT=1 -ENV CONDA_PREFIX=/opt/conda -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy -EXPOSE 5000 - -CMD ["./gunicorn_run.sh"] diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA deleted file mode 100644 index 0a0dff19..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA +++ /dev/null @@ -1,39 +0,0 @@ - -FROM mambaorg/micromamba:1.4.3 as build - -COPY env/conda.yml conda.yml -COPY env/requirements.txt requirements.txt -COPY inference_server 
./inference_server -COPY gunicorn_run.sh ./gunicorn_run.sh - -USER root -RUN if id mambauser >/dev/null 2>&1; then \ - echo "mambauser already exists."; \ - else \ - export USER=mambauser && \ - export UID=1000 && \ - export HOME=/home/$USER && \ - echo "Creating $USER user..." && \ - adduser --disabled-password \ - --gecos "A non-root user for running inference server" \ - --uid $UID \ - --home $HOME \ - $USER; \ - fi - -RUN chmod +rx conda.yml -RUN chmod +rx requirements.txt -RUN chmod +x ./gunicorn_run.sh - -USER mambauser -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -ARG MAMBA_NO_LOW_SPEED_LIMIT=1 -ENV CONDA_PREFIX=/opt/conda -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy -EXPOSE 5000 - -CMD ["./gunicorn_run.sh"] diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model deleted file mode 100644 index 8619b7b0..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model +++ /dev/null @@ -1,43 +0,0 @@ - -FROM mambaorg/micromamba:1.4.3 as build - -COPY env/conda.yml conda.yml -COPY env/requirements.txt requirements.txt -COPY inference_server ./inference_server -COPY gunicorn_run.sh ./gunicorn_run.sh - -USER root -RUN if id mambauser >/dev/null 2>&1; then \ - echo "mambauser already exists."; \ - else \ - export USER=mambauser && \ - export UID=1000 && \ - export HOME=/home/$USER && \ - echo "Creating $USER user..." 
&& \ - adduser --disabled-password \ - --gecos "A non-root user for running inference server" \ - --uid $UID \ - --home $HOME \ - $USER; \ - fi - -RUN chmod +rx conda.yml -RUN chmod +rx requirements.txt -RUN chmod +x ./gunicorn_run.sh - -USER mambauser -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -ARG MAMBA_NO_LOW_SPEED_LIMIT=1 -ENV CONDA_PREFIX=/opt/conda -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy - -COPY model.zip ./model_repo/model.zip - -ENV MODEL_ZIP_STAGE_PATH=model_repo/model.zip -EXPOSE 5000 - -CMD ["./gunicorn_run.sh"] diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/kaniko_shell_script_fixture.sh b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/kaniko_shell_script_fixture.sh deleted file mode 100644 index 36a9e449..00000000 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/kaniko_shell_script_fixture.sh +++ /dev/null @@ -1,105 +0,0 @@ -#!/bin/sh - -# Set the file path to monitor -REGISTRY_CRED_PATH="/kaniko/.docker/config.json" -SESSION_TOKEN_PATH="/snowflake/session/token" - -# Function to gracefully terminate the file monitoring job -cleanup() { - echo "Stopping file monitoring job..." - trap - INT TERM # Remove the signal handlers - kill -- -$$ # Kill the entire process group. Extra $ to escape, the generated shell script should have two $. -} - -# SNOW-990976, This is an additional safety check to ensure token file exists, on top of the token file check upon -# launching SPCS job. This additional check could provide value in cases things go wrong with token refresh that result -# in token file to disappear. -wait_till_token_file_exists() { - timeout=60 # 1 minute timeout - elapsed_time=0 - - while [ ! 
-f "${SESSION_TOKEN_PATH}" ] && [ "$elapsed_time" -lt "$timeout" ]; do - sleep 1 - elapsed_time=$((elapsed_time + 1)) - remaining_time=$((timeout - elapsed_time)) - echo "Waiting for token file to exist. Wait time remaining: ${remaining_time} seconds." - done - - if [ ! -f "${SESSION_TOKEN_PATH}" ]; then - echo "Error: Token file '${SESSION_TOKEN_PATH}' does not show up within the ${timeout} seconds timeout period." - exit 1 - fi -} - -generate_registry_cred() { - wait_till_token_file_exists - AUTH_TOKEN=$(printf '0auth2accesstoken:%s' "$(cat ${SESSION_TOKEN_PATH})" | base64); - echo '{"auths":{"mock_image_repo":{"auth":"'"$AUTH_TOKEN"'"}}}' | tr -d '\n' > $REGISTRY_CRED_PATH; -} - -on_session_token_change() { - wait_till_token_file_exists - # Get the initial checksum of the file - CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') - # Run the command once before the loop - echo "Monitoring session token changes in the background..." - ( - while true; do - wait_till_token_file_exists - # Get the current checksum of the file - CURRENT_CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') - if [ "${CURRENT_CHECKSUM}" != "${CHECKSUM}" ]; then - # Session token file has changed, regenerate registry credential. - echo "Session token has changed. Regenerating registry auth credentials." - generate_registry_cred - CHECKSUM="${CURRENT_CHECKSUM}" - fi - # Wait for a short period of time before checking again - sleep 1 - done - ) -} - -run_kaniko() { - # Run the Kaniko command in the foreground - echo "Starting Kaniko command..." - - # Set cache ttl to a large value as snowservice registry doesn't support deleting cache anyway. - # Compression level set to 1 for fastest compression/decompression speed at the cost of compression ration. 
- /kaniko/executor \ - --dockerfile Dockerfile \ - --context dir:///stage/models/id/context \ - --destination=org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest \ - --cache=true \ - --compressed-caching=false \ - --cache-copy-layers=false \ - --use-new-run \ - --snapshot-mode=redo \ - --cache-repo=mock_image_repo/cache \ - --cache-run-layers=true \ - --cache-ttl=8760h \ - --push-retry=3 \ - --image-fs-extract-retry=5 \ - --compression=zstd \ - --compression-level=1 \ - --log-timestamp -} - -setup() { - tar -C "/stage/models/id" -xf "/stage/models/id/context.tar.gz"; - generate_registry_cred - # Set up the signal handlers - trap cleanup TERM -} - -setup - -# Running kaniko job on the foreground and session token monitoring on the background. When session token changes, -# overwrite the existing registry cred file with the new session token. -on_session_token_change & -run_kaniko - -# Capture the exit code from the previous kaniko command. -KANIKO_EXIT_CODE=$? -# Exit with the same exit code as the Kaniko command. This then triggers the cleanup function. 
-exit $KANIKO_EXIT_CODE diff --git a/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel b/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel deleted file mode 100644 index 7579d010..00000000 --- a/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel +++ /dev/null @@ -1,51 +0,0 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "deploy_options", - srcs = ["deploy_options.py"], - deps = [ - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/model/_deploy_client/utils:constants", - ], -) - -py_library( - name = "deploy", - srcs = ["deploy.py"], - data = [ - "templates/service_spec_template", - "templates/service_spec_template_with_model", - ], - deps = [ - ":deploy_options", - ":instance_types", - "//snowflake/ml/_internal:env_utils", - "//snowflake/ml/_internal/container_services/image_registry:registry_client", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:identifier", - "//snowflake/ml/_internal/utils:spcs_attribution_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_deploy_client/image_builds:base_image_builder", - "//snowflake/ml/model/_deploy_client/image_builds:client_image_builder", - "//snowflake/ml/model/_deploy_client/image_builds:server_image_builder", - "//snowflake/ml/model/_deploy_client/utils:snowservice_client", - "//snowflake/ml/model/_packager/model_meta", - ], -) - -py_library( - name = "instance_types", - srcs = ["instance_types.py"], -) - -py_test( - name = "deploy_test", - srcs = ["deploy_test.py"], - deps = [ - ":deploy", - "//snowflake/ml/test_utils:exception_utils", - "//snowflake/ml/test_utils:mock_session", - ], -) diff --git a/snowflake/ml/model/_deploy_client/snowservice/deploy.py b/snowflake/ml/model/_deploy_client/snowservice/deploy.py deleted file mode 100644 index e28cc38a..00000000 --- a/snowflake/ml/model/_deploy_client/snowservice/deploy.py +++ /dev/null @@ -1,611 +0,0 @@ 
-import copy -import logging -import os -import posixpath -import string -import tempfile -import time -from contextlib import contextmanager -from typing import Any, Dict, Generator, Optional, cast - -import importlib_resources -import yaml -from packaging import requirements -from typing_extensions import Unpack - -from snowflake.ml._internal import env_utils, file_utils -from snowflake.ml._internal.container_services.image_registry import ( - registry_client as image_registry_client, -) -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.ml._internal.utils import ( - identifier, - query_result_checker, - spcs_attribution_utils, -) -from snowflake.ml.model import type_hints -from snowflake.ml.model._deploy_client import snowservice -from snowflake.ml.model._deploy_client.image_builds import ( - base_image_builder, - client_image_builder, - docker_context, - server_image_builder, -) -from snowflake.ml.model._deploy_client.snowservice import deploy_options, instance_types -from snowflake.ml.model._deploy_client.utils import constants, snowservice_client -from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema -from snowflake.snowpark import Session - -logger = logging.getLogger(__name__) - - -@contextmanager -def _debug_aware_tmp_directory(debug_dir: Optional[str] = None) -> Generator[str, None, None]: - """Debug-aware directory provider. - - Args: - debug_dir: A folder for deploymement context. 
- - Yields: - A directory path to write deployment artifacts - """ - create_temp = False - if debug_dir: - directory_path = debug_dir - else: - temp_dir_context = tempfile.TemporaryDirectory() - directory_path = temp_dir_context.name - create_temp = True - try: - yield directory_path - finally: - if create_temp: - temp_dir_context.cleanup() - - -def _deploy( - session: Session, - *, - model_id: str, - model_meta: model_meta.ModelMetadata, - service_func_name: str, - model_zip_stage_path: str, - deployment_stage_path: str, - target_method: str, - **kwargs: Unpack[type_hints.SnowparkContainerServiceDeployOptions], -) -> type_hints.SnowparkContainerServiceDeployDetails: - """Entrypoint for model deployment to SnowService. This function will trigger a docker image build followed by - workflow deployment to SnowService. - - Args: - session: Snowpark session - model_id: Unique hex string of length 32, provided by model registry. - model_meta: Model Metadata. - service_func_name: The service function name in SnowService associated with the created service. - model_zip_stage_path: Path to model zip file in stage. Note that this path has a "@" prefix. - deployment_stage_path: Path to stage containing deployment artifacts. - target_method: The name of the target method to be deployed. - **kwargs: various SnowService deployment options. - - Returns: - Deployment details for SPCS. - - Raises: - SnowflakeMLException: Raised when model_id is empty. - SnowflakeMLException: Raised when service_func_name is empty. - SnowflakeMLException: Raised when model_stage_file_path is empty. 
- """ - snowpark_logger = logging.getLogger("snowflake.snowpark") - snowflake_connector_logger = logging.getLogger("snowflake.connector") - snowpark_log_level = snowpark_logger.level - snowflake_connector_log_level = snowflake_connector_logger.level - - query_result = ( - query_result_checker.SqlResultValidator( - session, - query="SHOW PARAMETERS LIKE 'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT' IN SESSION", - ) - .has_dimensions(expected_rows=1) - .validate() - ) - prev_format = query_result[0].value - - try: - # Setting appropriate log level to prevent console from being polluted by vast amount of snowpark and snowflake - # connector logging. - snowpark_logger.setLevel(logging.WARNING) - snowflake_connector_logger.setLevel(logging.WARNING) - - # Query format change is needed to ensure session token obtained from the session object is valid. - session.sql("ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'json'").collect() - if not model_id: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - 'Must provide a non-empty string for "model_id" when deploying to Snowpark Container Services' - ), - ) - if not service_func_name: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - 'Must provide a non-empty string for "service_func_name"' - " when deploying to Snowpark Container Services" - ), - ) - if not model_zip_stage_path: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - 'Must provide a non-empty string for "model_stage_file_path"' - " when deploying to Snowpark Container Services" - ), - ) - if not deployment_stage_path: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - 'Must provide a non-empty string for "deployment_stage_path"' - " when deploying to Snowpark 
Container Services" - ), - ) - - # Remove full qualified name to avoid double quotes corrupting the service spec - model_zip_stage_path = model_zip_stage_path.replace('"', "") - deployment_stage_path = deployment_stage_path.replace('"', "") - - assert model_zip_stage_path.startswith("@"), f"stage path should start with @, actual: {model_zip_stage_path}" - assert deployment_stage_path.startswith("@"), f"stage path should start with @, actual: {deployment_stage_path}" - options = deploy_options.SnowServiceDeployOptions.from_dict(cast(Dict[str, Any], kwargs)) - - model_meta_deploy = copy.deepcopy(model_meta) - # Set conda-forge as backup channel for SPCS deployment - if "conda-forge" not in model_meta_deploy.env._conda_dependencies: - model_meta_deploy.env._conda_dependencies["conda-forge"] = [] - # Snowflake connector needs pyarrow to work correctly. - env_utils.append_conda_dependency( - model_meta_deploy.env._conda_dependencies, - (env_utils.DEFAULT_CHANNEL_NAME, requirements.Requirement("pyarrow")), - ) - if options.use_gpu: - # Make mypy happy - assert options.num_gpus is not None - if model_meta_deploy.env.cuda_version is None: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - "You are requesting GPUs for models that do not use a GPU or does not have CUDA version set." - ), - ) - if model_meta.env.cuda_version: - model_meta_deploy.env.generate_env_for_cuda() - else: - # If user does not need GPU, we set this copies cuda_version to None, thus when Image builder gets a - # not-None cuda_version, it gets to know that GPU is required. 
- model_meta_deploy.env._cuda_version = None - - _validate_compute_pool(session, options=options) - - # TODO[shchen]: SNOW-863701, Explore ways to prevent entire model zip being downloaded during deploy step - # (for both warehouse and snowservice deployment) - # One alternative is for model registry to duplicate the model metadata and env dependency storage from model - # zip so that we don't have to pull down the entire model zip. - ss_deployment = SnowServiceDeployment( - session=session, - model_id=model_id, - model_meta=model_meta_deploy, - service_func_name=service_func_name, - model_zip_stage_path=model_zip_stage_path, # Pass down model_zip_stage_path for service spec file - deployment_stage_path=deployment_stage_path, - target_method=target_method, - options=options, - ) - return ss_deployment.deploy() - finally: - session.sql(f"ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = '{prev_format}'").collect() - # Preserve the original logging level. - snowpark_logger.setLevel(snowpark_log_level) - snowflake_connector_logger.setLevel(snowflake_connector_log_level) - - -def _validate_compute_pool(session: Session, *, options: deploy_options.SnowServiceDeployOptions) -> None: - # Remove full qualified name to avoid double quotes, which does not work well in desc compute pool syntax. 
- compute_pool = options.compute_pool.replace('"', "") - sql = f"DESC COMPUTE POOL {compute_pool}" - result = ( - query_result_checker.SqlResultValidator( - session=session, - query=sql, - ) - .has_column("instance_family") - .has_column("state") - .has_column("auto_resume") - .has_dimensions(expected_rows=1) - .validate() - ) - - state = result[0]["state"] - auto_resume = bool(result[0]["auto_resume"]) - - if state == "SUSPENDED": - if not auto_resume: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_SNOWPARK_COMPUTE_POOL, - original_exception=RuntimeError( - "The compute pool you are requesting to use is suspended without auto-resume enabled" - ), - ) - - elif state not in ["STARTING", "ACTIVE", "IDLE"]: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_SNOWPARK_COMPUTE_POOL, - original_exception=RuntimeError( - "The compute pool you are requesting to use is not in the ACTIVE/IDLE/STARTING status." - ), - ) - - if state in ["SUSPENDED", "STARTING"]: - logger.warning(f"The compute pool you are requesting is in {state} state. We are waiting it to be ready.") - - if options.use_gpu: - assert options.num_gpus is not None - request_gpus = options.num_gpus - instance_family = result[0]["instance_family"] - if instance_family in instance_types.INSTANCE_TYPE_TO_GPU_COUNT: - gpu_capacity = instance_types.INSTANCE_TYPE_TO_GPU_COUNT[instance_family] - if request_gpus > gpu_capacity: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_SNOWPARK_COMPUTE_POOL, - original_exception=RuntimeError( - f"GPU request exceeds instance capability; {instance_family} instance type has total " - f"capacity of {gpu_capacity} GPU, yet a request was made for {request_gpus} GPUs." 
- ), - ) - else: - logger.warning(f"Unknown instance type: {instance_family}, skipping GPU validation") - - -def _get_or_create_image_repo(session: Session, *, service_func_name: str, image_repo: Optional[str] = None) -> str: - def _sanitize_dns_url(url: str) -> str: - # Align with existing SnowService image registry url standard. - return url.lower() - - if image_repo: - return _sanitize_dns_url(image_repo) - - try: - conn = session._conn._conn - # We try to use the same db and schema as the service function locates, as we could retrieve those information - # if that is a fully qualified one. If not we use the current session one. - (_db, _schema, _) = identifier.parse_schema_level_object_identifier(service_func_name) - db = _db if _db is not None else conn._database - schema = _schema if _schema is not None else conn._schema - assert isinstance(db, str) and isinstance(schema, str) - - client = snowservice_client.SnowServiceClient(session) - client.create_image_repo(identifier.get_schema_level_object_identifier(db, schema, constants.SNOWML_IMAGE_REPO)) - sql = f"SHOW IMAGE REPOSITORIES LIKE '{constants.SNOWML_IMAGE_REPO}' IN SCHEMA {'.'.join([db, schema])}" - result = ( - query_result_checker.SqlResultValidator( - session=session, - query=sql, - ) - .has_column("repository_url") - .has_dimensions(expected_rows=1) - .validate() - ) - repository_url = result[0]["repository_url"] - return str(repository_url) - except Exception as e: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWPARK_CONTAINER_SERVICE_ERROR, - original_exception=RuntimeError("Failed to retrieve image repo URL"), - ) from e - - -class SnowServiceDeployment: - """ - Class implementation that encapsulates image build and workflow deployment to SnowService - """ - - def __init__( - self, - session: Session, - model_id: str, - model_meta: model_meta.ModelMetadata, - service_func_name: str, - model_zip_stage_path: str, - deployment_stage_path: str, - target_method: 
str, - options: deploy_options.SnowServiceDeployOptions, - ) -> None: - """Initialization - - Args: - session: Snowpark session - model_id: Unique hex string of length 32, provided by model registry; if not provided, auto-generate one for - resource naming.The model_id serves as an idempotent key throughout the deployment workflow. - model_meta: Model Metadata. - service_func_name: The service function name in SnowService associated with the created service. - model_zip_stage_path: Path to model zip file in stage. - deployment_stage_path: Path to stage containing deployment artifacts. - target_method: The name of the target method to be deployed. - options: A SnowServiceDeployOptions object containing deployment options. - """ - - self.session = session - self.id = model_id - self.model_meta = model_meta - self.service_func_name = service_func_name - self.model_zip_stage_path = model_zip_stage_path - self.options = options - self.target_method = target_method - (db, schema, _) = identifier.parse_schema_level_object_identifier(service_func_name) - - self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}") - self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}") - # Spec file and future deployment related artifacts will be stored under {stage}/models/{model_id} - self._model_artifact_stage_location = posixpath.join(deployment_stage_path, "models", self.id) - self.debug_dir: Optional[str] = None - if self.options.debug_mode: - self.debug_dir = tempfile.mkdtemp() - logger.warning(f"Debug model is enabled, deployment artifacts will be available in {self.debug_dir}") - - def deploy(self) -> type_hints.SnowparkContainerServiceDeployDetails: - """ - This function triggers image build followed by workflow deployment to SnowService. - - Returns: - Deployment details. - """ - if self.options.prebuilt_snowflake_image: - logger.warning(f"Skipped image build. 
Use prebuilt image: {self.options.prebuilt_snowflake_image}") - service_function_sql = self._deploy_workflow(self.options.prebuilt_snowflake_image) - else: - with _debug_aware_tmp_directory(debug_dir=self.debug_dir) as context_dir: - extra_kwargs = {} - if self.options.model_in_image: - extra_kwargs = { - "session": self.session, - "model_zip_stage_path": self.model_zip_stage_path, - } - dc = docker_context.DockerContext( - context_dir=context_dir, - model_meta=self.model_meta, - **extra_kwargs, # type: ignore[arg-type] - ) - dc.build() - image_repo = _get_or_create_image_repo( - self.session, service_func_name=self.service_func_name, image_repo=self.options.image_repo - ) - full_image_name = self._get_full_image_name(image_repo=image_repo, context_dir=context_dir) - registry_client = image_registry_client.ImageRegistryClient(self.session, full_image_name) - - if not self.options.force_image_build and registry_client.image_exists(full_image_name=full_image_name): - logger.warning( - f"Similar environment detected. Using existing image {full_image_name} to skip image " - f"build. To disable this feature, set 'force_image_build=True' in deployment options" - ) - else: - logger.warning( - "Building the Docker image and deploying to Snowpark Container Service. " - "This process may take anywhere from a few minutes to a longer period for GPU-based models." - ) - start = time.time() - self._build_and_upload_image( - context_dir=context_dir, image_repo=image_repo, full_image_name=full_image_name - ) - end = time.time() - logger.info(f"Time taken to build and upload image to registry: {end - start:.2f} seconds") - logger.warning( - f"Image successfully built! For future model deployments, the image will be reused if " - f"possible, saving model deployment time. To enforce using the same image, include " - f"'prebuilt_snowflake_image': '{full_image_name}' in the deploy() function's options." 
- ) - - # Adding the model name as an additional tag to the existing image, excluding the version to prevent - # excessive tags and also due to version not available in current model metadata. This will allow - # users to associate images with specific models and perform relevant image registry actions. In the - # event that model dependencies change across versions, a new image hash will be computed, resulting in - # a new image. - try: - registry_client.add_tag_to_remote_image( - original_full_image_name=full_image_name, new_tag=self.model_meta.name - ) - except Exception as e: - # Proceed to the deployment with a warning message. - logger.warning(f"Failed to add tag {self.model_meta.name} to image {full_image_name}: {str(e)}") - service_function_sql = self._deploy_workflow(full_image_name) - - rows = self.session.sql(f"DESCRIBE SERVICE {self._service_name}").collect() - service_info = rows[0].as_dict() if rows and rows[0] else None - return type_hints.SnowparkContainerServiceDeployDetails( - service_info=service_info, - service_function_sql=service_function_sql, - ) - - def _get_full_image_name(self, image_repo: str, context_dir: str) -> str: - """Return a valid full image name that consists of image name and tag. e.g - org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest - - Args: - image_repo: image repo path, e.g. org-account.registry.snowflakecomputing.com/db/schema/repo - context_dir: the local docker context directory, which consists everything needed to build the docker image. - - Returns: - Full image name. - """ - image_repo = _get_or_create_image_repo( - self.session, service_func_name=self.service_func_name, image_repo=self.options.image_repo - ) - - # We skip "MODEL_METADATA_FILE" as it contains information that will always lead to cache misses. This isn't an - # issue because model dependency is also captured in the model env/ folder, which will be hashed. 
The aim is to - # reuse the same Docker image even if the user logs a similar model without new dependencies. - docker_context_dir_hash = file_utils.hash_directory( - context_dir, ignore_hidden=True, excluded_files=[model_meta.MODEL_METADATA_FILE] - ) - # By default, we associate a 'latest' tag with each of our created images for easy existence checking. - # Additional tags are added for readability. - return f"{image_repo}/{docker_context_dir_hash}:{constants.LATEST_IMAGE_TAG}" - - def _build_and_upload_image(self, context_dir: str, image_repo: str, full_image_name: str) -> None: - """Handles image build and upload to image registry. - - Args: - context_dir: the local docker context directory, which consists everything needed to build the docker image. - image_repo: image repo path, e.g. org-account.registry.snowflakecomputing.com/db/schema/repo - full_image_name: Full image name consists of image name and image tag. - """ - image_builder: base_image_builder.ImageBuilder - if self.options.enable_remote_image_build: - image_builder = server_image_builder.ServerImageBuilder( - context_dir=context_dir, - full_image_name=full_image_name, - image_repo=image_repo, - session=self.session, - artifact_stage_location=self._model_artifact_stage_location, - compute_pool=self.options.compute_pool, - job_name=self._job_name, - external_access_integrations=self.options.external_access_integrations, - ) - else: - image_builder = client_image_builder.ClientImageBuilder( - context_dir=context_dir, full_image_name=full_image_name, image_repo=image_repo, session=self.session - ) - image_builder.build_and_upload_image() - - def _prepare_and_upload_artifacts_to_stage(self, image: str) -> None: - """Constructs and upload service spec to stage. - - Args: - image: Name of the image to create SnowService container from. 
- """ - if self.options.model_in_image: - spec_template = ( - importlib_resources.files(snowservice) - .joinpath("templates/service_spec_template_with_model") - .read_text("utf-8") - ) - else: - spec_template = ( - importlib_resources.files(snowservice).joinpath("templates/service_spec_template").read_text("utf-8") - ) - - with _debug_aware_tmp_directory(self.debug_dir) as dir_path: - spec_file_path = os.path.join(dir_path, f"{constants.SERVICE_SPEC}.yaml") - - with open(spec_file_path, "w+", encoding="utf-8") as spec_file: - assert self.model_zip_stage_path.startswith("@") - norm_stage_path = posixpath.normpath(identifier.remove_prefix(self.model_zip_stage_path, "@")) - # Ensure model stage path has root prefix as stage mount will it mount it to root. - absolute_model_stage_path = os.path.join("/", norm_stage_path) - (db, schema, stage, path) = identifier.parse_snowflake_stage_path(norm_stage_path) - substitutes = { - "image": image, - "predict_endpoint_name": constants.PREDICT, - "model_stage": identifier.get_schema_level_object_identifier(db, schema, stage), - "model_zip_stage_path": absolute_model_stage_path, - "inference_server_container_name": constants.INFERENCE_SERVER_CONTAINER, - "target_method": self.target_method, - "num_workers": self.options.num_workers, - "use_gpu": self.options.use_gpu, - "enable_ingress": self.options.enable_ingress, - } - if self.options.model_in_image: - del substitutes["model_stage"] - del substitutes["model_zip_stage_path"] - content = string.Template(spec_template).substitute(substitutes) - content_dict = yaml.safe_load(content) - if self.options.use_gpu: - container = content_dict["spec"]["container"][0] - # TODO[shchen]: SNOW-871538, external dependency that only single GPU is supported on SnowService. - # GPU limit has to be specified in order to trigger the workload to be run on GPU in SnowService. 
- container["resources"] = { - "limits": {"nvidia.com/gpu": self.options.num_gpus}, - "requests": {"nvidia.com/gpu": self.options.num_gpus}, - } - - # Make LLM use case sequential - if any( - model_blob_meta.model_type == "huggingface_pipeline" or model_blob_meta.model_type == "llm" - for model_blob_meta in self.model_meta.models.values() - ): - container["env"]["_CONCURRENT_REQUESTS_MAX"] = 1 - - yaml.dump(content_dict, spec_file) - logger.debug("Create service spec: \n, %s", content_dict) - - self.session.file.put( - local_file_name=spec_file_path, - stage_location=self._model_artifact_stage_location, - auto_compress=False, - overwrite=True, - ) - logger.debug( - f"Uploaded spec file {os.path.basename(spec_file_path)} " f"to {self._model_artifact_stage_location}" - ) - - def _get_max_batch_rows(self) -> Optional[int]: - # To avoid too large batch in HF LLM case - max_batch_rows = None - if self.options.use_gpu: - for model_blob_meta in self.model_meta.models.values(): - batch_size = None - if model_blob_meta.model_type == "huggingface_pipeline": - model_blob_options_hf = cast( - model_meta_schema.HuggingFacePipelineModelBlobOptions, model_blob_meta.options - ) - batch_size = model_blob_options_hf["batch_size"] - if model_blob_meta.model_type == "llm": - model_blob_options_llm = cast(model_meta_schema.LLMModelBlobOptions, model_blob_meta.options) - batch_size = model_blob_options_llm["batch_size"] - if batch_size: - if max_batch_rows is None: - max_batch_rows = batch_size - else: - max_batch_rows = min(batch_size, max_batch_rows) - return max_batch_rows - - def _deploy_workflow(self, image: str) -> str: - """This function handles workflow deployment to SnowService with the given image. - - Args: - image: Name of the image to create SnowService container from. 
- - Returns: - service function sql - """ - - self._prepare_and_upload_artifacts_to_stage(image) - client = snowservice_client.SnowServiceClient(self.session) - spec_stage_location = posixpath.join( - self._model_artifact_stage_location.rstrip("/"), f"{constants.SERVICE_SPEC}.yaml" - ) - client.create_or_replace_service( - service_name=self._service_name, - compute_pool=self.options.compute_pool, - spec_stage_location=spec_stage_location, - min_instances=self.options.min_instances, - max_instances=self.options.max_instances, - external_access_integrations=self.options.external_access_integrations, - ) - logger.info(f"Wait for service {self._service_name} to become ready...") - client.block_until_resource_is_ready( - resource_name=self._service_name, resource_type=constants.ResourceType.SERVICE - ) - logger.info(f"Service {self._service_name} is ready. Creating service function...") - - spcs_attribution_utils.record_service_start(self.session, self._service_name) - - service_function_sql = client.create_or_replace_service_function( - service_func_name=self.service_func_name, - service_name=self._service_name, - endpoint_name=constants.PREDICT, - max_batch_rows=self._get_max_batch_rows(), - ) - logger.info(f"Service function {self.service_func_name} is created. 
Deployment completed successfully!") - return service_function_sql diff --git a/snowflake/ml/model/_deploy_client/snowservice/deploy_options.py b/snowflake/ml/model/_deploy_client/snowservice/deploy_options.py deleted file mode 100644 index 3d8ae330..00000000 --- a/snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +++ /dev/null @@ -1,116 +0,0 @@ -import inspect -import logging -from typing import Any, Dict, List, Optional - -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) -from snowflake.ml.model._deploy_client.utils import constants - -logger = logging.getLogger(__name__) - - -class SnowServiceDeployOptions: - def __init__( - self, - compute_pool: str, - *, - external_access_integrations: List[str], - image_repo: Optional[str] = None, - min_instances: Optional[int] = 1, - max_instances: Optional[int] = 1, - prebuilt_snowflake_image: Optional[str] = None, - num_gpus: Optional[int] = 0, - num_workers: Optional[int] = None, - enable_remote_image_build: Optional[bool] = True, - force_image_build: Optional[bool] = False, - model_in_image: Optional[bool] = False, - debug_mode: Optional[bool] = False, - enable_ingress: Optional[bool] = False, - ) -> None: - """Initialization - - When updated, please ensure the type hint is updated accordingly at: //snowflake/ml/model/type_hints - - Args: - compute_pool: SnowService compute pool name. Please refer to official doc for how to create a - compute pool: - https://docs.snowflake.com/en/developer-guide/snowpark-container-services/working-with-compute-pool - external_access_integrations: External Access Integrations name used to build image and deploy the model. - Please refer to the doc for how to create an External Access Integrations: https://docs.snowflake.com/ - developer-guide/snowpark-container-services/additional-considerations-services-jobs - #configuring-network-capabilities . 
- To make sure your image could be built, access to the following endpoint must be allowed. - docker.com:80, docker.com:443, anaconda.com:80, anaconda.com:443, anaconda.org:80, anaconda.org:443, - pypi.org:80, pypi.org:443 - image_repo: SnowService image repo path. e.g. "///". Default to auto - inferred based on session information. - min_instances: Minimum number of service replicas. Default to 1. - max_instances: Maximum number of service replicas. Default to 1. - prebuilt_snowflake_image: When provided, the image-building step is skipped, and the pre-built image from - Snowflake is used as is. This option is for users who consistently use the same image for multiple use - cases, allowing faster deployment. The snowflake image used for deployment is logged to the console for - future use. Default to None. - num_gpus: Number of GPUs to be used for the service. Default to 0. - num_workers: Number of workers used for model inference. Please ensure that the number of workers is set - lower than the total available memory divided by the size of model to prevent memory-related issues. - Default is number of CPU cores * 2 + 1. - enable_remote_image_build: When set to True, will enable image build on a remote SnowService job. - Default is True. - force_image_build: When set to True, an image rebuild will occur. The default is False, which means the - system will automatically check whether a previously built image can be reused - model_in_image: When set to True, image would container full model weights. The default if False, which - means image without model weights and we do stage mount to access weights. - debug_mode: When set to True, deployment artifacts will be persisted in a local temp directory. - enable_ingress: When set to True, will expose HTTP endpoint for access to the predict method of the created - service. Default to False. 
- """ - - self.compute_pool = compute_pool - self.image_repo = image_repo - self.min_instances = min_instances - self.max_instances = max_instances - self.prebuilt_snowflake_image = prebuilt_snowflake_image - self.num_gpus = num_gpus - self.num_workers = num_workers - self.enable_remote_image_build = enable_remote_image_build - self.force_image_build = force_image_build - self.model_in_image = model_in_image - self.debug_mode = debug_mode - self.enable_ingress = enable_ingress - self.external_access_integrations = external_access_integrations - - if self.num_workers is None and self.use_gpu: - logger.info("num_workers has been defaulted to 1 when using GPU.") - self.num_workers = 1 - - @property - def use_gpu(self) -> bool: - return self.num_gpus is not None and self.num_gpus > 0 - - @classmethod - def from_dict(cls, options_dict: Dict[str, Any]) -> "SnowServiceDeployOptions": - """Construct SnowServiceDeployOptions instance based from an option dictionary. - - Args: - options_dict: The dict containing various deployment options. - - Raises: - SnowflakeMLException: When required option is missing. 
- - Returns: - A SnowServiceDeployOptions object - """ - required_options = [constants.COMPUTE_POOL] - missing_keys = [key for key in required_options if options_dict.get(key) is None] - if missing_keys: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - f"Must provide options when deploying to Snowpark Container Services: {', '.join(missing_keys)}" - ), - ) - supported_options_keys = inspect.signature(cls.__init__).parameters.keys() - filtered_options = {k: v for k, v in options_dict.items() if k in supported_options_keys} - return cls(**filtered_options) diff --git a/snowflake/ml/model/_deploy_client/snowservice/deploy_test.py b/snowflake/ml/model/_deploy_client/snowservice/deploy_test.py deleted file mode 100644 index a1244bec..00000000 --- a/snowflake/ml/model/_deploy_client/snowservice/deploy_test.py +++ /dev/null @@ -1,613 +0,0 @@ -from typing import Any, Dict, Optional, cast - -from absl.testing import absltest -from absl.testing.absltest import mock - -from snowflake import snowpark -from snowflake.ml.model._deploy_client.snowservice import deploy_options -from snowflake.ml.model._deploy_client.snowservice.deploy import ( - SnowServiceDeployment, - _deploy, - _get_or_create_image_repo, -) -from snowflake.ml.model._deploy_client.utils import constants -from snowflake.ml.test_utils import exception_utils, mock_data_frame, mock_session -from snowflake.snowpark import row, session - - -class Connection: - def __init__(self, account: str, database: str, schema: str) -> None: - self.account = account - self._database = database - self._schema = schema - - -class DeployTestCase(absltest.TestCase): - def setUp(self) -> None: - super().setUp() - self.m_session = mock_session.MockSession(conn=None, test_case=self) - self.options: Dict[str, Any] = { - "compute_pool": "mock_compute_pool", - "image_repo": "mock_image_repo", - "external_access_integrations": ["eai_1"], - } - - 
self.m_session.add_mock_sql( - query="SHOW PARAMETERS LIKE 'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT' IN SESSION", - result=mock_data_frame.MockDataFrame([row.Row(value="arrow")]), - ) - - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'json'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - - def _get_mocked_compute_pool_res( - self, - state: Optional[str] = "IDLE", - instance_family: Optional[str] = "STANDARD_1", - auto_resume: Optional[bool] = False, - ) -> mock_data_frame.MockDataFrame: - return mock_data_frame.MockDataFrame( - [ - row.Row( - name="mock_compute_pool", - state=state, - min_nodes=1, - max_nodes=1, - instance_family=instance_family, - auto_resume=auto_resume, - ) - ] - ) - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_model_id(self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock) -> None: - m_deployment = m_deployment_class.return_value - m_model_meta = m_model_meta_class.return_value - - m_model_zip_stage_path = "@mock_model_zip_stage_path/model.zip" - m_deployment_stage_path = "@mock_model_deployment_stage_path" - - self.m_session.add_mock_sql( - query=f"DESC COMPUTE POOL {self.options['compute_pool']}", - result=self._get_mocked_compute_pool_res(), - ) - - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - - _deploy( - session=cast(session.Session, self.m_session), - model_id="provided_model_id", - model_meta=m_model_meta, - service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - target_method=constants.PREDICT, - **self.options, - ) - - 
m_deployment_class.assert_called_once_with( - session=self.m_session, - model_id="provided_model_id", - service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - model_meta=m_model_meta, - target_method=constants.PREDICT, - options=mock.ANY, - ) - m_deployment.deploy.assert_called_once() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_compute_pool_in_starting( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - - m_model_zip_stage_path = "@mock_model_zip_stage_path/model.zip" - m_deployment_stage_path = "@mock_model_deployment_stage_path" - m_deployment = m_deployment_class.return_value - - self.m_session.add_mock_sql( - query=f"DESC COMPUTE POOL {self.options['compute_pool']}", - result=self._get_mocked_compute_pool_res(state="STARTING"), - ) - - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - - _deploy( - session=cast(session.Session, self.m_session), - model_id="provided_model_id", - model_meta=m_model_meta, - service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - target_method=constants.PREDICT, - **self.options, - ) - - m_deployment.deploy.assert_called_once() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_not_ready_compute_pool( - self, m_deployment_class: mock.MagicMock, 
m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - - m_model_zip_stage_path = "@mock_model_zip_stage_path/model.zip" - m_deployment_stage_path = "@mock_model_deployment_stage_path" - - self.m_session.add_mock_sql( - query=f"DESC COMPUTE POOL {self.options['compute_pool']}", - result=self._get_mocked_compute_pool_res(state="STOPPED"), - ) - - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - - with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError): - _deploy( - session=cast(session.Session, self.m_session), - model_id="provided_model_id", - model_meta=m_model_meta, - service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - target_method=constants.PREDICT, - **self.options, - ) - - m_deployment_class.assert_not_called() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_not_auto_resume_compute_pool( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - - m_model_zip_stage_path = "@mock_model_zip_stage_path/model.zip" - m_deployment_stage_path = "@mock_model_deployment_stage_path" - - self.m_session.add_mock_sql( - query=f"DESC COMPUTE POOL {self.options['compute_pool']}", - result=self._get_mocked_compute_pool_res(state="SUSPENDED", auto_resume=False), - ) - - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - - with exception_utils.assert_snowml_exceptions(self, 
expected_original_error_type=RuntimeError): - _deploy( - session=cast(session.Session, self.m_session), - model_id="provided_model_id", - model_meta=m_model_meta, - service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - target_method=constants.PREDICT, - **self.options, - ) - - m_deployment_class.assert_not_called() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_compute_pool_in_suspended_state_with_auto_resume( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - - m_model_zip_stage_path = "@mock_model_zip_stage_path/model.zip" - m_deployment_stage_path = "@mock_model_deployment_stage_path" - m_deployment = m_deployment_class.return_value - - self.m_session.add_mock_sql( - query=f"DESC COMPUTE POOL {self.options['compute_pool']}", - result=self._get_mocked_compute_pool_res(state="SUSPENDED", auto_resume=True), - ) - - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - - _deploy( - session=cast(session.Session, self.m_session), - model_id="provided_model_id", - model_meta=m_model_meta, - service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - target_method=constants.PREDICT, - **self.options, - ) - - m_deployment.deploy.assert_called_once() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def 
test_deploy_with_empty_model_id( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=ValueError): - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - - _deploy( - session=cast(session.Session, self.m_session), - service_func_name="mock_service_func", - model_id="", - model_meta=m_model_meta, - model_zip_stage_path="@mock_model_zip_stage_path/model.zip", - deployment_stage_path="@mock_model_deployment_stage_path", - target_method=constants.PREDICT, - **self.options, - ) - - m_deployment_class.assert_not_called() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_missing_required_options( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - with exception_utils.assert_snowml_exceptions( - self, expected_original_error_type=ValueError, expected_regex="compute_pool" - ): - options: Dict[str, Any] = {} - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - _deploy( - session=cast(session.Session, self.m_session), - service_func_name="mock_service_func", - model_id="mock_model_id", - model_meta=m_model_meta, - model_zip_stage_path="@mock_model_zip_stage_path/model.zip", - deployment_stage_path="@mock_model_deployment_stage_path", - target_method=constants.PREDICT, - **options, - ) - m_deployment_class.assert_not_called() - - 
@mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_over_requested_gpus( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - m_model_meta.cuda_version = "11.7" - with exception_utils.assert_snowml_exceptions( - self, expected_original_error_type=RuntimeError, expected_regex="GPU request exceeds instance capability" - ): - self.m_session.add_mock_sql( - query=f"DESC COMPUTE POOL {self.options['compute_pool']}", - result=self._get_mocked_compute_pool_res(instance_family="GPU_3"), - ) - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - _deploy( - session=cast(session.Session, self.m_session), - service_func_name="mock_service_func", - model_id="mock_model_id", - model_meta=m_model_meta, - model_zip_stage_path="@mock_model_zip_stage_path/model.zip", - deployment_stage_path="@mock_model_deployment_stage_path", - target_method=constants.PREDICT, - num_gpus=2, - **self.options, - ) - m_deployment_class.assert_not_called() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_over_requested_gpus_no_cuda( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock - ) -> None: - m_model_meta = m_model_meta_class.return_value - m_model_meta.env.cuda_version = None - with exception_utils.assert_snowml_exceptions( - self, - expected_original_error_type=ValueError, - expected_regex="You are requesting GPUs for models that do not use a GPU or does not have CUDA 
version set", - ): - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - _deploy( - session=cast(session.Session, self.m_session), - service_func_name="mock_service_func", - model_id="mock_model_id", - model_meta=m_model_meta, - model_zip_stage_path="@mock_model_zip_stage_path/model.zip", - deployment_stage_path="@mock_model_deployment_stage_path", - target_method=constants.PREDICT, - num_gpus=2, - **self.options, - ) - m_deployment_class.assert_not_called() - - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.copy.deepcopy") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.SnowServiceDeployment") # type: ignore[misc] - def test_deploy_with_gpu_validation_and_unknown_instance_type( - self, m_deployment_class: mock.MagicMock, m_model_meta_class: mock.MagicMock, m_deepcopy_func: mock.MagicMock - ) -> None: - m_deployment = m_deployment_class.return_value - m_model_meta = m_model_meta_class.return_value - m_model_meta.cuda_version = "11.7" - m_model_meta_deploy = m_deepcopy_func.return_value - m_model_zip_stage_path = "@mock_model_zip_stage_path/model.zip" - m_deployment_stage_path = "@mock_model_deployment_stage_path" - - unknown_instance_type = "GPU_UNKNOWN" - self.m_session.add_mock_sql( - query=f"DESC COMPUTE POOL {self.options['compute_pool']}", - result=self._get_mocked_compute_pool_res(instance_family=unknown_instance_type), - ) - self.m_session.add_mock_sql( - query="ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'", - result=mock_data_frame.MockDataFrame(collect_result=[]), - ) - with self.assertLogs(level="INFO") as cm: - _deploy( - session=cast(session.Session, self.m_session), - model_id="provided_model_id", - model_meta=m_model_meta, - 
service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - target_method=constants.PREDICT, - num_gpus=2, - **self.options, - ) - - self.assertListEqual( - cm.output, - [ - "INFO:snowflake.ml.model._deploy_client.snowservice.deploy_options:num_workers has been defaulted" - " to 1 when using GPU.", - ( - "WARNING:snowflake.ml.model._deploy_client.snowservice.deploy:Unknown " - "instance type: GPU_UNKNOWN, skipping GPU validation" - ), - ], - ) - - m_deployment_class.assert_called_once_with( - session=self.m_session, - model_id="provided_model_id", - model_meta=m_model_meta_deploy, - service_func_name="mock_service_func", - model_zip_stage_path=m_model_zip_stage_path, - deployment_stage_path=m_deployment_stage_path, - target_method=constants.PREDICT, - options=mock.ANY, - ) - m_deployment.deploy.assert_called_once() - - -class TestImageRepoCreate(absltest.TestCase): - def setUp(self) -> None: - super().setUp() - self.m_session = mock_session.MockSession(conn=None, test_case=self) - - @mock.patch( - "snowflake.ml.model._deploy_client.snowservice.deploy." 
"snowservice_client.SnowServiceClient" - ) # type: ignore[misc] - def test_get_or_create_image_repo(self, m_snowservice_client_class: mock.MagicMock) -> None: - db = "mock_db" - schema = "mock_schema" - session_db = "session_db" - session_schema = "session_schema" - repo_url = f"org-account.registry.snowflakecomputing.com/{db}/{schema}/{constants.SNOWML_IMAGE_REPO}" - session_repo_url = ( - f"org-account.registry-dev.snowflakecomputing.com" - f"/{session_db}/{session_schema}/{constants.SNOWML_IMAGE_REPO}" - ) - self.m_session.add_mock_sql( - query=f"SHOW IMAGE REPOSITORIES LIKE '{constants.SNOWML_IMAGE_REPO}' IN SCHEMA {'.'.join([db, schema])}", - result=mock_data_frame.MockDataFrame([row.Row(name=constants.SNOWML_IMAGE_REPO, repository_url=repo_url)]), - ) - self.m_session.add_mock_sql( - query=f"SHOW IMAGE REPOSITORIES LIKE '{constants.SNOWML_IMAGE_REPO}' " - f"IN SCHEMA {'.'.join([session_db, session_schema])}", - result=mock_data_frame.MockDataFrame( - [row.Row(name=constants.SNOWML_IMAGE_REPO, repository_url=session_repo_url)] - ), - ) - self.m_session._conn = mock.MagicMock() - self.m_session._conn._conn = Connection(account="account", database=session_db, schema=session_schema) - - # Test when image repo url is provided. - self.assertEqual( - _get_or_create_image_repo( - session=cast(session.Session, self.m_session), - service_func_name="func", - image_repo="dummy.dummy.snowflakecomputing.com/DB/SCHEMA/REPO", - ), - "dummy.dummy.snowflakecomputing.com/db/schema/repo", - ) - - # Test when image repo is constructed from db/schema inferred from the service func. 
- self.assertEqual( - _get_or_create_image_repo( - session=cast(session.Session, self.m_session), service_func_name=f"{db}.{schema}.func" - ), - repo_url, - ) - m_snowservice_client = m_snowservice_client_class.return_value - m_snowservice_client.create_image_repo.assert_called_with(f"{db}.{schema}.{constants.SNOWML_IMAGE_REPO}") - - # Test constructing image repo from session object, this happens when service func is missing db and/or schema. - self.assertEqual( - _get_or_create_image_repo(session=cast(session.Session, self.m_session), service_func_name="func"), - session_repo_url, - ) - m_snowservice_client.create_image_repo.assert_called_with( - f"{session_db}.{session_schema}.{constants.SNOWML_IMAGE_REPO}" - ) - - with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError): - # Cannot find image repo in the given db/schema. - self.m_session.add_mock_sql( - query=f"SHOW IMAGE REPOSITORIES LIKE " - f"'{constants.SNOWML_IMAGE_REPO}' IN SCHEMA {'.'.join([db, schema])}", - result=mock_data_frame.MockDataFrame([]), - ) - _get_or_create_image_repo( - session=cast(session.Session, self.m_session), service_func_name=f"{db}.{schema}.func" - ) - - -class SnowServiceDeploymentTestCase(absltest.TestCase): - @mock.patch("snowflake.ml.model._deploy_client.snowservice.deploy.model_meta.ModelMetadata") # type: ignore[misc] - def setUp(self, m_model_meta_class: mock.MagicMock) -> None: - super().setUp() - self.m_session = mock_session.MockSession(conn=None, test_case=self) - self.m_model_id = "provided_model_id" - self.m_service_func_name = "mock_db.mock_schema.provided_service_func_name" - self.m_model_zip_stage_path = "@provided_model_zip_stage_path/model.zip" - self.m_deployment_stage_path = "@mock_model_deployment_stage_path" - self.m_model_meta = m_model_meta_class.return_value - self.m_model_meta.cuda_version = None - self.model_name = "mock_model_name" - self.m_model_meta.name = self.model_name - self.m_options = { - "stage": 
"mock_stage", - "compute_pool": "mock_compute_pool", - "image_repo": "mock_image_repo", - "external_access_integrations": ["eai_a", "eai_b"], - } - - self.deployment = SnowServiceDeployment( - cast(session.Session, self.m_session), - model_id=self.m_model_id, - service_func_name=self.m_service_func_name, - model_meta=self.m_model_meta, - model_zip_stage_path=self.m_model_zip_stage_path, - deployment_stage_path=self.m_deployment_stage_path, - target_method=constants.PREDICT, - options=deploy_options.SnowServiceDeployOptions.from_dict(self.m_options), - ) - - self.m_session.add_mock_sql( - query=f"DESCRIBE SERVICE {self.deployment._service_name}", - result=mock_data_frame.MockDataFrame( - collect_result=[snowpark.Row(**{"name": self.deployment._service_name})] - ), - ) - - def test_service_name(self) -> None: - self.assertEqual(self.deployment._service_name, "mock_db.mock_schema.service_provided_model_id") - - @mock.patch( - "snowflake.ml.model._deploy_client.snowservice.deploy.image_registry_client.ImageRegistryClient" - ) # type: ignore[misc] - def test_deploy(self, m_image_registry_client: mock.MagicMock) -> None: - m_image_registry_client.return_value = mock.MagicMock() - m_client = m_image_registry_client.return_value - m_client.image_exists.return_value = False - m_client.add_tag_to_remote_image.return_value = None - - with mock.patch.object( - self.deployment, "_build_and_upload_image" - ) as m_build_and_upload_image, mock.patch.object( - self.deployment, "_deploy_workflow" - ) as m_deploy_workflow, mock.patch.object( - self.deployment, "_get_full_image_name" - ) as m_get_full_image_name: - full_image_name = "org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest" - m_deploy_workflow.return_value = ("service_spec", "sql") - m_get_full_image_name.return_value = full_image_name - - with self.assertLogs(level="WARNING") as cm: - self.deployment.deploy() - m_build_and_upload_image.assert_called_once() - 
m_deploy_workflow.assert_called_once_with(full_image_name) - self.assertEqual( - cm.output, - [ - ( - "WARNING:snowflake.ml.model._deploy_client.snowservice.deploy:Building the Docker image " - "and deploying to Snowpark Container Service. This process may take anywhere from a few " - "minutes to a longer period for GPU-based models." - ), - ( - f"WARNING:snowflake.ml.model._deploy_client.snowservice.deploy:Image successfully built! " - f"For future model deployments, the image will be reused if possible, saving model " - f"deployment time. To enforce using the same image, include 'prebuilt_snowflake_image': " - f"'{full_image_name}' in the deploy() function's options." - ), - ], - ) - - m_client.add_tag_to_remote_image.assert_called_once_with( - original_full_image_name=full_image_name, new_tag=self.model_name - ) - - @mock.patch( - "snowflake.ml.model._deploy_client.snowservice.deploy.image_registry_client.ImageRegistryClient" - ) # type: ignore[misc] - def test_deploy_with_image_already_exists_in_registry(self, m_image_registry_client: mock.MagicMock) -> None: - m_image_registry_client.return_value = mock.MagicMock() - m_client = m_image_registry_client.return_value - m_client.image_exists.return_value = True - m_client.add_tag_to_remote_image.return_value = None - - with mock.patch.object( - self.deployment, "_build_and_upload_image" - ) as m_build_and_upload_image, mock.patch.object( - self.deployment, "_deploy_workflow" - ) as m_deploy_workflow, mock.patch.object( - self.deployment, "_get_full_image_name" - ) as m_get_full_image_name: - full_image_name = "org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest" - m_get_full_image_name.return_value = full_image_name - m_deploy_workflow.return_value = ("service_spec", "sql") - - with self.assertLogs(level="WARNING") as cm: - self.deployment.deploy() - m_build_and_upload_image.assert_not_called() - m_deploy_workflow.assert_called_once_with(full_image_name) - - self.assertEqual( - cm.output, - [ - 
( - f"WARNING:snowflake.ml.model._deploy_client.snowservice.deploy:Similar environment " - f"detected. Using existing image {full_image_name} to skip image build. To disable this" - f" feature, set 'force_image_build=True' in deployment options" - ) - ], - ) - m_client.add_tag_to_remote_image.assert_called_once_with( - original_full_image_name=full_image_name, new_tag=self.model_name - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/_deploy_client/snowservice/instance_types.py b/snowflake/ml/model/_deploy_client/snowservice/instance_types.py deleted file mode 100644 index ab27b15b..00000000 --- a/snowflake/ml/model/_deploy_client/snowservice/instance_types.py +++ /dev/null @@ -1,10 +0,0 @@ -# Snowpark Container Service GPU instance type and corresponding GPU counts. -INSTANCE_TYPE_TO_GPU_COUNT = { - "GPU_3": 1, - "GPU_5": 1, - "GPU_7": 4, - "GPU_10": 8, - "GPU_NV_S": 1, - "GPU_NV_M": 4, - "GPU_NV_L": 8, -} diff --git a/snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template b/snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template deleted file mode 100644 index ce9d2fdd..00000000 --- a/snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +++ /dev/null @@ -1,28 +0,0 @@ -spec: - container: - - name: ${inference_server_container_name} - image: ${image} - env: - MODEL_ZIP_STAGE_PATH: ${model_zip_stage_path} - TARGET_METHOD: ${target_method} - NUM_WORKERS: ${num_workers} - SNOWML_USE_GPU: ${use_gpu} - readinessProbe: - port: 5000 - path: /health - volumeMounts: - - name: vol1 - mountPath: /local/user/vol1 - - name: stage - mountPath: ${model_stage} - endpoint: - - name: ${predict_endpoint_name} - port: 5000 - public: ${enable_ingress} - volume: - - name: vol1 - source: local # only local emptyDir volume is supported - - name: stage - source: "@${model_stage}" - uid: 1000 - gid: 1000 diff --git 
a/snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model b/snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model deleted file mode 100644 index 66e8bc1c..00000000 --- a/snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +++ /dev/null @@ -1,21 +0,0 @@ -spec: - container: - - name: ${inference_server_container_name} - image: ${image} - env: - TARGET_METHOD: ${target_method} - NUM_WORKERS: ${num_workers} - SNOWML_USE_GPU: ${use_gpu} - readinessProbe: - port: 5000 - path: /health - volumeMounts: - - name: vol1 - mountPath: /local/user/vol1 - endpoint: - - name: ${predict_endpoint_name} - port: 5000 - public: ${enable_ingress} - volume: - - name: vol1 - source: local # only local emptyDir volume is supported diff --git a/snowflake/ml/model/_deploy_client/utils/BUILD.bazel b/snowflake/ml/model/_deploy_client/utils/BUILD.bazel deleted file mode 100644 index 31a123fd..00000000 --- a/snowflake/ml/model/_deploy_client/utils/BUILD.bazel +++ /dev/null @@ -1,29 +0,0 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "constants", - srcs = ["constants.py"], -) - -py_library( - name = "snowservice_client", - srcs = ["snowservice_client.py"], - deps = [ - ":constants", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:log_stream_processor", - "//snowflake/ml/_internal/utils:uri", - ], -) - -py_test( - name = "snowservice_client_test", - srcs = ["snowservice_client_test.py"], - deps = [ - ":snowservice_client", - "//snowflake/ml/test_utils:exception_utils", - "//snowflake/ml/test_utils:mock_session", - ], -) diff --git a/snowflake/ml/model/_deploy_client/utils/constants.py b/snowflake/ml/model/_deploy_client/utils/constants.py deleted file mode 100644 index f32d2e19..00000000 --- a/snowflake/ml/model/_deploy_client/utils/constants.py +++ /dev/null @@ -1,48 +0,0 @@ 
-from enum import Enum - - -class ResourceType(Enum): - SERVICE = "service" - JOB = "job" - - -class ResourceStatus(Enum): - UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet. - PENDING = "PENDING" # resource set is being created, can't be used yet - READY = "READY" # resource set has been deployed. - DELETING = "DELETING" # resource set is being deleted - FAILED = "FAILED" # resource set has failed and cannot be used anymore - DONE = "DONE" # resource set has finished running - NOT_FOUND = "NOT_FOUND" # not found or deleted - INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error. - - -PREDICT = "predict" -STAGE = "stage" -COMPUTE_POOL = "compute_pool" -MIN_INSTANCES = "min_instances" -MAX_INSTANCES = "max_instances" -GPU_COUNT = "gpu" -OVERRIDDEN_BASE_IMAGE = "image" -ENDPOINT = "endpoint" -SERVICE_SPEC = "service_spec" -INFERENCE_SERVER_CONTAINER = "inference-server" - -"""Image build related constants""" -SNOWML_IMAGE_REPO = "snowml_repo" -MODEL_DIR = "model_dir" -INFERENCE_SERVER_DIR = "inference_server" -ENTRYPOINT_SCRIPT = "gunicorn_run.sh" -PROD_IMAGE_REGISTRY_DOMAIN = "snowflakecomputing.com" -PROD_IMAGE_REGISTRY_SUBDOMAIN = "registry" -DEV_IMAGE_REGISTRY_SUBDOMAIN = "registry-dev" -MODEL_ENV_FOLDER = "env" -CONDA_FILE = "conda.yml" -IMAGE_BUILD_JOB_SPEC_TEMPLATE = "image_build_job_spec_template" -KANIKO_SHELL_SCRIPT_TEMPLATE = "kaniko_shell_script_template" -CONTEXT = "context" -KANIKO_SHELL_SCRIPT_NAME = "kaniko_shell_script_fixture.sh" -KANIKO_CONTAINER_NAME = "kaniko" -LATEST_IMAGE_TAG = "latest" -KANIKO_IMAGE = "kaniko-project/executor:v1.16.0-debug" -SPCS_MOUNTED_TOKEN_PATH = "/snowflake/session/token" diff --git a/snowflake/ml/model/_deploy_client/utils/snowservice_client.py b/snowflake/ml/model/_deploy_client/utils/snowservice_client.py deleted file mode 100644 index 43a1f33f..00000000 --- a/snowflake/ml/model/_deploy_client/utils/snowservice_client.py +++ /dev/null @@ -1,280 +0,0 @@ 
import json
import logging
import textwrap
import time
from typing import List, Optional

from snowflake.ml._internal.exceptions import (
    error_codes,
    exceptions as snowml_exceptions,
)
from snowflake.ml._internal.utils import log_stream_processor, uri
from snowflake.ml.model._deploy_client.utils import constants
from snowflake.snowpark import Session

logger = logging.getLogger(__name__)


class SnowServiceClient:
    """
    SnowService client implementation: a Python wrapper for SnowService SQL queries.
    """

    def __init__(self, session: Session) -> None:
        """Initialization

        Args:
            session: Snowpark session
        """
        self.session = session

    def create_image_repo(self, repo_name: str) -> None:
        """Create the image repository unless it already exists.

        Args:
            repo_name: Name of the image repository.
        """
        self.session.sql(f"CREATE IMAGE REPOSITORY IF NOT EXISTS {repo_name}").collect()

    def create_or_replace_service(
        self,
        service_name: str,
        compute_pool: str,
        spec_stage_location: str,
        external_access_integrations: List[str],
        *,
        min_instances: Optional[int] = 1,
        max_instances: Optional[int] = 1,
    ) -> None:
        """Create or replace service. Since SnowService doesn't support the CREATE OR REPLACE service syntax, we will
        first attempt to drop the service if it exists, and then create the service. Please note that this approach may
        have side effects due to the lack of transaction support.

        Args:
            service_name: Name of the service.
            compute_pool: Name of the compute pool.
            spec_stage_location: Stage path for the service spec.
            external_access_integrations: EAIs for network connection.
            min_instances: Minimum number of service replicas.
            max_instances: Maximum number of service replicas.
        """
        stage, path = uri.get_stage_and_path(spec_stage_location)
        self._drop_service_if_exists(service_name)
        sql = textwrap.dedent(
            f"""
            CREATE SERVICE {service_name}
                IN COMPUTE POOL {compute_pool}
                FROM {stage}
                SPEC = '{path}'
                MIN_INSTANCES={min_instances}
                MAX_INSTANCES={max_instances}
                EXTERNAL_ACCESS_INTEGRATIONS = ({', '.join(external_access_integrations)})
            """
        )
        logger.info(f"Creating service {service_name}")
        logger.debug(f"Create service with SQL: \n {sql}")
        self.session.sql(sql).collect()

    def create_job(
        self, job_name: str, compute_pool: str, spec_stage_location: str, external_access_integrations: List[str]
    ) -> None:
        """Execute the job creation SQL command. The statement itself blocks until the job finishes, so it is
        submitted without waiting and the job is then polled, which lets us query the log in the meantime.

        Upon job failure, full job container log will be logged.

        Args:
            job_name: name of the job
            compute_pool: name of the compute pool
            spec_stage_location: path to the stage location where the spec is located at.
            external_access_integrations: EAIs for network connection.
        """
        stage, path = uri.get_stage_and_path(spec_stage_location)
        sql = textwrap.dedent(
            f"""
            EXECUTE JOB SERVICE
            IN COMPUTE POOL {compute_pool}
            FROM {stage}
            SPECIFICATION_FILE = '{path}'
            NAME = {job_name}
            EXTERNAL_ACCESS_INTEGRATIONS = ({', '.join(external_access_integrations)})
            """
        )
        logger.debug(f"Create job with SQL: \n {sql}")
        self.session.sql(sql).collect_nowait()
        self.block_until_resource_is_ready(
            resource_name=job_name,
            resource_type=constants.ResourceType.JOB,
            container_name=constants.KANIKO_CONTAINER_NAME,
            max_retries=240,
            retry_interval_secs=15,
        )

    def _drop_service_if_exists(self, service_name: str) -> None:
        """Drop service if it already exists.

        Args:
            service_name: Name of the service.
        """
        self.session.sql(f"DROP SERVICE IF EXISTS {service_name}").collect()

    def create_or_replace_service_function(
        self,
        service_func_name: str,
        service_name: str,
        *,
        endpoint_name: str = constants.PREDICT,
        path_at_service_endpoint: str = constants.PREDICT,
        max_batch_rows: Optional[int] = None,
    ) -> str:
        """Create or replace service function.

        Args:
            service_func_name: Name of the service function.
            service_name: Name of the service.
            endpoint_name: Name the service endpoint, declared in the service spec, indicating the listening port.
            path_at_service_endpoint: Specify the path/route at the service endpoint. Multiple paths can exist for a
                given endpoint. For example, an inference server listening on port 5000 may have paths like "/predict"
                and "/monitoring".
            max_batch_rows: Specify the MAX_BATCH_ROWS property of the service function, if None, leave unset

        Returns:
            The actual SQL for service function creation.
        """
        # A falsy value (None or 0) leaves the property out of the statement entirely.
        max_batch_rows_sql = ""
        if max_batch_rows:
            max_batch_rows_sql = f"MAX_BATCH_ROWS = {max_batch_rows}"

        sql = textwrap.dedent(
            f"""
            CREATE OR REPLACE FUNCTION {service_func_name}(input OBJECT)
                RETURNS OBJECT
                SERVICE={service_name}
                ENDPOINT={endpoint_name}
                {max_batch_rows_sql}
                AS '/{path_at_service_endpoint}'
            """
        )
        logger.debug(f"Create service function with SQL: \n {sql}")
        self.session.sql(sql).collect()
        logger.debug(f"Successfully created service function: {service_func_name}")
        return sql

    def block_until_resource_is_ready(
        self,
        resource_name: str,
        resource_type: constants.ResourceType,
        *,
        max_retries: int = 180,
        container_name: str = constants.INFERENCE_SERVER_CONTAINER,
        retry_interval_secs: int = 10,
    ) -> None:
        """Blocks execution until the specified resource is ready.
        Note that this is a best-effort approach because when launching a service, it's possible for it to initially
        fail due to a system error. However, SnowService may automatically retry and recover the service, leading to
        potential false-negative information.

        Args:
            resource_name: Name of the resource.
            resource_type: Type of the resource.
            container_name: The container to query the log from.
            max_retries: The maximum number of retries to check the resource readiness (default: 180).
            retry_interval_secs: The number of seconds to wait between each retry (default: 10).

        Raises:
            SnowflakeMLException: If the resource received the following status [failed, not_found, internal_error,
                deleting]
            SnowflakeMLException: If the resource does not reach the ready/done state within the specified number
                of retries.
        """
        assert resource_type == constants.ResourceType.SERVICE or resource_type == constants.ResourceType.JOB
        query_command = f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
        logger.warning(
            "Best-effort log streaming from SPCS will be enabled when python logging level is set to INFO. "
            f"Alternatively, you can also query the logs by running the query '{query_command}'"
        )
        lsp = log_stream_processor.LogStreamProcessor()
        terminal_statuses = (
            constants.ResourceStatus.FAILED,
            constants.ResourceStatus.NOT_FOUND,
            constants.ResourceStatus.INTERNAL_ERROR,
            constants.ResourceStatus.DELETING,
        )

        for attempt_idx in range(max_retries):
            # `logger.level` is the level set explicitly on this logger; it is NOTSET (0) unless someone called
            # setLevel on it, in which case the logs are fetched and handed to the stream processor on every poll.
            logs_fetched = logger.level <= logging.INFO
            resource_log = None
            if logs_fetched:
                resource_log = self.get_resource_log(
                    resource_name=resource_name,
                    resource_type=resource_type,
                    container_name=container_name,
                )
                lsp.process_new_logs(resource_log, log_level=logging.INFO)

            status = self.get_resource_status(resource_name=resource_name)

            if resource_type == constants.ResourceType.JOB and status == constants.ResourceStatus.DONE:
                return
            if resource_type == constants.ResourceType.SERVICE and status == constants.ResourceStatus.READY:
                return

            out_of_retries = attempt_idx >= max_retries - 1
            if status in terminal_statuses or out_of_retries:
                if not logs_fetched:
                    resource_log = self.get_resource_log(
                        resource_name=resource_name,
                        resource_type=resource_type,
                        container_name=container_name,
                    )
                if not logger.isEnabledFor(logging.INFO):
                    # Show full error log when the effective logging level is above INFO. The effective level is
                    # what matters here: with the level inherited from a parent logger, the INFO lines streamed
                    # above are filtered out and the failure would otherwise be reported without any container log.
                    logger.error(resource_log)

                error_message = "does not reach ready/done status" if out_of_retries else "failed"

                if resource_type == constants.ResourceType.SERVICE:
                    self._drop_service_if_exists(service_name=resource_name)

                raise snowml_exceptions.SnowflakeMLException(
                    error_code=error_codes.INTERNAL_SNOWPARK_CONTAINER_SERVICE_ERROR,
                    original_exception=RuntimeError(
                        f"{resource_type} {resource_name} {error_message}."
                        f"\nStatus: {status if status else ''} \n"
                    ),
                )
            time.sleep(retry_interval_secs)

    def get_resource_log(
        self, resource_name: str, resource_type: constants.ResourceType, container_name: str
    ) -> Optional[str]:
        """Fetch the container log of a resource, best effort.

        Args:
            resource_name: Name of the resource.
            resource_type: Type of the resource. Not used by the query.
            container_name: The container to query the log from.

        Returns:
            The log text, or None when the query fails or returns no row.
        """
        try:
            row = self.session.sql(
                f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
            ).collect()
            return str(row[0]["SYSTEM$GET_SERVICE_LOGS"])
        except Exception:
            # Deliberately silent: a missing log must not abort the readiness polling that calls this.
            return None

    def get_resource_status(self, resource_name: str) -> Optional[constants.ResourceStatus]:
        """Get resource status.

        Args:
            resource_name: Name of the resource.

        Returns:
            Optional[constants.ResourceStatus]: The status of the resource, or None if the resource status is empty,
                unknown, or could not be retrieved.
        """
        status_func = "SYSTEM$GET_SERVICE_STATUS"
        try:
            rows = self.session.sql(f"CALL {status_func}('{resource_name}');").collect()
        except Exception:
            # Silent fail as SPCS status call is not guaranteed to return in time. Will rely on caller to retry.
            return None

        if not rows:
            return None
        containers = json.loads(rows[0][status_func])
        if not containers:
            # An empty status list carries no information yet; let the caller poll again.
            return None

        resource_metadata = containers[0]
        logger.debug(f"Resource status metadata: {resource_metadata}")
        status = resource_metadata.get("status") if resource_metadata else None
        if not status:
            return None
        try:
            return constants.ResourceStatus(status)
        except ValueError:
            logger.warning(f"Unknown status returned: {status}")
            return None
class SnowServiceClientTest(absltest.TestCase):
    """Unit tests for SnowServiceClient, run against a MockSession with scripted SQL results."""

    def setUp(self) -> None:
        super().setUp()
        self.m_session = mock_session.MockSession(conn=None, test_case=self)
        self.client = SnowServiceClient(cast(session.Session, self.m_session))
        self.m_service_name = "mock_service_name"

    def _add_service_log_mock(self, container_name: str, collect_result: list) -> None:
        """Script one SYSTEM$GET_SERVICE_LOGS call for the mock service."""
        self.m_session.add_mock_sql(
            query=f"CALL SYSTEM$GET_SERVICE_LOGS('{self.m_service_name}', '0'," f"'{container_name}')",
            result=mock_data_frame.MockDataFrame(collect_result=collect_result),
        )

    def _add_drop_service_mock(self) -> None:
        """Script the DROP SERVICE statement issued when a service fails to come up."""
        self.m_session.add_mock_sql(
            query=f"DROP SERVICE IF EXISTS {self.m_service_name}",
            result=mock_data_frame.MockDataFrame(collect_result=[]),
        )

    def _add_service_status_mock(self, status: str) -> None:
        """Script one SYSTEM$GET_SERVICE_STATUS call reporting the given status string."""
        row = snowpark.Row(
            **{
                "SYSTEM$GET_SERVICE_STATUS": json.dumps(
                    [
                        {
                            "status": status,
                            "message": "Running",
                            "containerName": "inference-server",
                            "instanceId": "0",
                            "serviceName": "SERVICE_DFC46DE9CEC441B2A3185266C11E79BA",
                            "image": "image",
                            "restartCount": 0,
                        }
                    ]
                )
            }
        )
        self.m_session.add_mock_sql(
            query="call system$GET_SERVICE_STATUS('mock_service_name');",
            result=mock_data_frame.MockDataFrame(collect_result=[row]),
        )

    def test_create_or_replace_service(self) -> None:
        m_min_instances = 1
        m_max_instances = 2
        m_compute_pool = "mock_compute_pool"
        m_stage = "@mock_spec_stage"
        m_stage_path = "a/hello.yaml"
        m_spec_stage_location = f"{m_stage}/{m_stage_path}"

        self.m_session.add_mock_sql(
            query="drop service if exists mock_service_name", result=mock_data_frame.MockDataFrame(collect_result=[])
        )

        self.m_session.add_mock_sql(
            query=f"""
             CREATE SERVICE {self.m_service_name}
                IN COMPUTE POOL {m_compute_pool}
                FROM {m_stage}
                SPEC = '{m_stage_path}'
                MIN_INSTANCES={m_min_instances}
                MAX_INSTANCES={m_max_instances}
                EXTERNAL_ACCESS_INTEGRATIONS=(eai_a, eai_b)
            """,
            result=mock_data_frame.MockDataFrame(collect_result=[]),
        )

        self.client.create_or_replace_service(
            service_name=self.m_service_name,
            min_instances=m_min_instances,
            max_instances=m_max_instances,
            compute_pool=m_compute_pool,
            spec_stage_location=m_spec_stage_location,
            external_access_integrations=["eai_a", "eai_b"],
        )

    def test_create_job_successfully(self) -> None:
        with mock.patch.object(
            self.client, "get_resource_status", return_value=constants.ResourceStatus.DONE
        ) as mock_get_resource_status:
            m_compute_pool = "mock_compute_pool"
            m_stage = "@mock_spec_stage"
            m_stage_path = "a/hello.yaml"
            m_spec_stage_location = f"{m_stage}/{m_stage_path}"
            m_job_name = "abcd"
            self.m_session.add_mock_sql(
                query=f"""
                    EXECUTE JOB SERVICE
                    IN COMPUTE POOL {m_compute_pool}
                    FROM {m_stage}
                    SPECIFICATION_FILE = '{m_stage_path}'
                    NAME = {m_job_name}
                    EXTERNAL_ACCESS_INTEGRATIONS = (eai_a, eai_b)""",
                result=mock_data_frame.MockDataFrame(collect_result=[]),
            )
            self.client.create_job(
                job_name=m_job_name,
                compute_pool=m_compute_pool,
                spec_stage_location=m_spec_stage_location,
                external_access_integrations=["eai_a", "eai_b"],
            )
            mock_get_resource_status.assert_called_once_with(resource_name=m_job_name)

    def test_create_job_failed(self) -> None:
        test_log = "Job fails because of xyz."
        m_compute_pool = "mock_compute_pool"
        m_stage = "@mock_spec_stage"
        m_stage_path = "a/hello.yaml"
        m_spec_stage_location = f"{m_stage}/{m_stage_path}"
        m_job_name = "abcd"

        with self.assertLogs(level="INFO") as cm:
            with mock.patch.object(self.client, "get_resource_status", return_value=constants.ResourceStatus.FAILED):
                with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError):
                    self.m_session.add_mock_sql(
                        query=f"""
                            EXECUTE JOB SERVICE
                            IN COMPUTE POOL {m_compute_pool}
                            FROM {m_stage}
                            SPECIFICATION_FILE = '{m_stage_path}'
                            NAME = {m_job_name}
                            EXTERNAL_ACCESS_INTEGRATIONS = (eai_a, eai_b)""",
                        result=mock_data_frame.MockDataFrame(collect_result=[]),
                    )

                    self.m_session.add_mock_sql(
                        query=f"CALL SYSTEM$GET_SERVICE_LOGS('{m_job_name}', '0', '{constants.KANIKO_CONTAINER_NAME}')",
                        result=mock_data_frame.MockDataFrame(
                            collect_result=[snowpark.Row(**{"SYSTEM$GET_SERVICE_LOGS": test_log})]
                        ),
                    )

                    self.client.create_job(
                        job_name=m_job_name,
                        compute_pool=m_compute_pool,
                        spec_stage_location=m_spec_stage_location,
                        external_access_integrations=["eai_a", "eai_b"],
                    )

        # assertTrue(cm.output, test_log) would treat test_log as the failure message and only check that
        # something was logged; verify the job log text itself reached the captured records.
        # NOTE(review): relies on the log stream processor emitting the fetched text at INFO - confirm.
        self.assertIn(test_log, "\n".join(cm.output))

    def test_create_service_function(self) -> None:
        m_service_func_name = "mock_service_func_name"
        m_service_name = "mock_service_name"
        m_endpoint_name = "mock_endpoint_name"
        m_path_at_endpoint = "mock_route"

        m_sql = f"""
            CREATE OR REPLACE FUNCTION {m_service_func_name}(input OBJECT)
                RETURNS OBJECT
                SERVICE={m_service_name}
                ENDPOINT={m_endpoint_name}
                AS '/{m_path_at_endpoint}'
            """

        self.m_session.add_mock_sql(
            query=m_sql,
            result=mock_data_frame.MockDataFrame(collect_result=[]),
        )

        self.client.create_or_replace_service_function(
            service_func_name=m_service_func_name,
            service_name=m_service_name,
            endpoint_name=m_endpoint_name,
            path_at_service_endpoint=m_path_at_endpoint,
        )

    def test_create_service_function_max_batch_rows(self) -> None:
        m_service_func_name = "mock_service_func_name"
        m_service_name = "mock_service_name"
        m_endpoint_name = "mock_endpoint_name"
        m_path_at_endpoint = "mock_route"
        m_max_batch_rows = 1

        m_sql = f"""
            CREATE OR REPLACE FUNCTION {m_service_func_name}(input OBJECT)
                RETURNS OBJECT
                SERVICE={m_service_name}
                ENDPOINT={m_endpoint_name}
                MAX_BATCH_ROWS={m_max_batch_rows}
                AS '/{m_path_at_endpoint}'
            """

        self.m_session.add_mock_sql(
            query=m_sql,
            result=mock_data_frame.MockDataFrame(collect_result=[]),
        )

        self.client.create_or_replace_service_function(
            service_func_name=m_service_func_name,
            service_name=m_service_name,
            endpoint_name=m_endpoint_name,
            path_at_service_endpoint=m_path_at_endpoint,
            max_batch_rows=m_max_batch_rows,
        )

    def test_get_service_status(self) -> None:
        # An empty status string must be reported as None rather than as an enum member.
        cases = (
            ("READY", constants.ResourceStatus("READY")),
            ("FAILED", constants.ResourceStatus("FAILED")),
            ("", None),
        )
        for raw_status, expected in cases:
            with self.subTest(status=raw_status):
                self._add_service_status_mock(raw_status)
                self.assertEqual(self.client.get_resource_status(self.m_service_name), expected)

    def test_block_until_service_is_ready_happy_path(self) -> None:
        with mock.patch.object(self.client, "get_resource_status", return_value=constants.ResourceStatus("READY")):
            self.client.block_until_resource_is_ready(
                self.m_service_name, constants.ResourceType.SERVICE, max_retries=1, retry_interval_secs=1
            )

    def test_block_until_service_is_ready_timeout(self) -> None:
        test_log = "service fails because of xyz."

        self._add_service_log_mock(
            constants.INFERENCE_SERVER_CONTAINER,
            [snowpark.Row(**{"SYSTEM$GET_SERVICE_LOGS": test_log})],
        )
        self._add_drop_service_mock()

        with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError):
            with mock.patch.object(self.client, "get_resource_status", side_effect=[None, None, None, "READY"]):
                self.client.block_until_resource_is_ready(
                    self.m_service_name, constants.ResourceType.SERVICE, max_retries=1, retry_interval_secs=1
                )

    def test_block_until_service_is_ready_retries_and_ready(self) -> None:
        # Service becomes ready on 2nd retry.
        with mock.patch.object(
            self.client, "get_resource_status", side_effect=[None, constants.ResourceStatus("READY")]
        ):
            self.client.block_until_resource_is_ready(
                self.m_service_name, constants.ResourceType.SERVICE, max_retries=2, retry_interval_secs=1
            )

    def test_block_until_service_is_ready_retries_and_fail(self) -> None:
        test_log = "service fails because of abc."
        # First status call return None; first get_log is empty; second status call return failed state
        self._add_service_log_mock(constants.INFERENCE_SERVER_CONTAINER, [])
        self._add_service_log_mock(
            constants.INFERENCE_SERVER_CONTAINER,
            [snowpark.Row(**{"SYSTEM$GET_SERVICE_LOGS": test_log})],
        )
        self._add_drop_service_mock()

        # Service show failure status on 2nd retry.
        with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError):
            with mock.patch.object(
                self.client, "get_resource_status", side_effect=[None, constants.ResourceStatus("FAILED")]
            ):
                self.client.block_until_resource_is_ready(
                    self.m_service_name, constants.ResourceType.SERVICE, max_retries=2, retry_interval_secs=1
                )


if __name__ == "__main__":
    absltest.main()
import copy
import logging
import posixpath
import tempfile
import textwrap
from types import ModuleType
from typing import IO, List, Optional, Tuple, TypedDict, Union

from typing_extensions import Unpack

from snowflake.ml._internal import env_utils, file_utils
from snowflake.ml._internal.exceptions import (
    error_codes,
    exceptions as snowml_exceptions,
)
from snowflake.ml.model import type_hints as model_types
from snowflake.ml.model._deploy_client.warehouse import infer_template
from snowflake.ml.model._packager.model_meta import model_meta
from snowflake.snowpark import session as snowpark_session, types as st

logger = logging.getLogger(__name__)


def _deploy_to_warehouse(
    session: snowpark_session.Session,
    *,
    model_stage_file_path: str,
    model_meta: model_meta.ModelMetadata,
    udf_name: str,
    target_method: str,
    **kwargs: Unpack[model_types.WarehouseDeployOptions],
) -> None:
    """Register the staged model as a warehouse UDF.

    The UDF source is generated into a temporary ``.py`` file that is left on disk, then registered
    through ``session.udf.register_from_file`` with the staged model file as its only import.

    Args:
        session: Snowpark session.
        model_stage_file_path: Path to the stored model zip file in the stage.
        model_meta: Model Metadata.
        udf_name: Name of the UDF.
        target_method: The name of the target method to be deployed.
        **kwargs: Deploy options. ``relax_version``, ``permanent_udf_stage_location`` and ``replace_udf``
            are read here.

    Raises:
        SnowflakeMLException: Raised when the model file name cannot be encoded using ASCII.
        SnowflakeMLException: Raised when target method does not exist in model.
        SnowflakeMLException: Raised when the permanent UDF stage location does not start with "@".
    """
    # TODO(SNOW-862576): Should remove check on ASCII encoding after SNOW-862576 fixed.
    model_stage_file_name = posixpath.basename(model_stage_file_path)
    if not file_utils._able_ascii_encode(model_stage_file_name):
        raise snowml_exceptions.SnowflakeMLException(
            error_code=error_codes.INVALID_ARGUMENT,
            original_exception=ValueError(
                f"Model file name {model_stage_file_name} cannot be encoded using ASCII. Please rename."
            ),
        )

    relax_version = kwargs.get("relax_version", False)

    if target_method not in model_meta.signatures:
        raise snowml_exceptions.SnowflakeMLException(
            error_code=error_codes.INVALID_ARGUMENT,
            original_exception=ValueError(f"Target method {target_method} does not exist in model."),
        )

    final_packages = _get_model_final_packages(model_meta, session, relax_version=relax_version)

    stage_location = kwargs.get("permanent_udf_stage_location", None)
    if stage_location:
        stage_location = posixpath.normpath(stage_location.strip())
        if not stage_location.startswith("@"):
            raise snowml_exceptions.SnowflakeMLException(
                error_code=error_codes.INVALID_ARGUMENT,
                original_exception=ValueError(f"Invalid stage location {stage_location}."),
            )

    # delete=False: the generated file outlives this call, its path is logged below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as udf_file:
        _write_UDF_py_file(
            udf_file.file, model_stage_file_name=model_stage_file_name, target_method=target_method, **kwargs
        )
        logger.info(f"Generated UDF file is persisted at: {udf_file.name}")

        class _UDFParams(TypedDict):
            file_path: str
            func_name: str
            name: str
            input_types: List[st.DataType]
            return_type: st.DataType
            imports: List[Union[str, Tuple[str, str]]]
            packages: List[Union[str, ModuleType]]

        params = _UDFParams(
            file_path=udf_file.name,
            func_name="infer",
            name=udf_name,
            return_type=st.PandasSeriesType(st.MapType(st.StringType(), st.VariantType())),
            input_types=[st.PandasDataFrameType([st.MapType()])],
            imports=[model_stage_file_path],
            packages=list(final_packages),
        )

        if stage_location is None:
            # Temporary UDF
            registration_options = {"replace": True}
        else:
            # Permanent UDF
            registration_options = {
                "replace": kwargs.get("replace_udf", False),
                "is_permanent": True,
                "stage_location": stage_location,
            }
        session.udf.register_from_file(**params, **registration_options)

    logger.info(f"{udf_name} is deployed to warehouse.")


def _write_UDF_py_file(
    f: IO[str],
    model_stage_file_name: str,
    target_method: str,
    **kwargs: Unpack[model_types.WarehouseDeployOptions],
) -> None:
    """Render the UDF code template and write it to ``f``.

    Args:
        f: File descriptor to write the python code.
        model_stage_file_name: Model zip file name.
        target_method: The name of the target method to be deployed.
        **kwargs: Deploy options; accepted for signature parity, none are read here.
    """
    template_fields = {
        "model_stage_file_name": model_stage_file_name,
        "_KEEP_ORDER_COL_NAME": infer_template._KEEP_ORDER_COL_NAME,
        "target_method": target_method,
        "code_dir_name": model_meta.MODEL_CODE_DIR,
    }
    f.write(infer_template._UDF_CODE_TEMPLATE.format(**template_fields))
    # Flush so the content is on disk for whoever reads the file by path next.
    f.flush()


def _get_model_final_packages(
    meta: model_meta.ModelMetadata,
    session: snowpark_session.Session,
    relax_version: Optional[bool] = False,
) -> List[str]:
    """Resolve the package list to hand to the Snowpark UDF registration.

    Args:
        meta: Model metadata to get dependency information.
        session: Snowpark connection session.
        relax_version: Whether to relax the version constraints (on a copy of the model environment)
            before checking availability. Defaults to False.

    Raises:
        SnowflakeMLException: Raised when PIP requirements and dependencies from non-Snowflake anaconda channel found.
        SnowflakeMLException: Raised when not all packages are available in snowflake conda channel.

    Returns:
        Sorted list of requirement strings accepted by the Snowpark register UDF call.
    """
    has_foreign_channel = any(
        channel.lower() != env_utils.DEFAULT_CHANNEL_NAME for channel in meta.env._conda_dependencies
    )
    if has_foreign_channel or meta.env.pip_requirements:
        raise snowml_exceptions.SnowflakeMLException(
            error_code=error_codes.DEPENDENCY_VERSION_ERROR,
            original_exception=RuntimeError(
                "PIP requirements and dependencies from non-Snowflake anaconda channel is not supported."
            ),
        )

    # Relaxation mutates the environment, so it is applied to a deep copy and meta.env stays untouched.
    source_env = meta.env
    if relax_version:
        source_env = copy.deepcopy(meta.env)
        source_env.relax_version()
    required_packages = source_env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME]

    package_availability_dict = env_utils.get_matched_package_versions_in_information_schema(
        session, required_packages, python_version=meta.env.python_version
    )
    no_version_available_packages = [
        name for name, versions in package_availability_dict.items() if len(versions) == 0
    ]
    unavailable_packages = [req.name for req in required_packages if req.name not in package_availability_dict]

    if not (no_version_available_packages or unavailable_packages):
        return sorted(str(req) for req in required_packages)

    relax_version_info_str = "" if relax_version else "Try to set relax_version as True in the options. "
    required_package_str = " ".join(f'"{req}"' for req in required_packages)
    raise snowml_exceptions.SnowflakeMLException(
        error_code=error_codes.DEPENDENCY_VERSION_ERROR,
        original_exception=RuntimeError(
            textwrap.dedent(
                f"""
                The model's dependencies are not available in Snowflake Anaconda Channel. {relax_version_info_str}
                Required packages are: {required_package_str}
                Required Python version is: {meta.env.python_version}
                Packages that are not available are: {unavailable_packages}
                Packages that cannot meet your requirements are: {no_version_available_packages}
                Package availability information of those you requested is: {package_availability_dict}
                """
            ),
        ),
    )
class TestFinalPackagesWithoutConda(absltest.TestCase):
    """Tests for deploy._get_model_final_packages against a mocked information_schema.packages query."""

    @classmethod
    def setUpClass(cls) -> None:
        cls.m_session = mock_session.MockSession(conn=None, test_case=None)

    def setUp(self) -> None:
        # Every test starts with the core packaging dependencies scripted as available.
        self.add_packages(self._core_package_versions())

    @classmethod
    def tearDownClass(cls) -> None:
        pass

    @staticmethod
    def _core_package_versions() -> Dict[str, List[str]]:
        """Map each core dependency to its locally installed version; the SnowML package uses snowml_env.VERSION."""
        versions = {
            dep.name: [importlib_metadata.version(dep.name)]
            for dep in _BASIC_DEPENDENCIES_FINAL_PACKAGES
            if dep.name != env_utils.SNOWPARK_ML_PKG_NAME
        }
        versions[env_utils.SNOWPARK_ML_PKG_NAME] = [snowml_env.VERSION]
        return versions

    def add_packages(self, packages_dicts: Dict[str, List[str]]) -> None:
        """Script the package-availability query so it returns the given name -> versions mapping."""
        pkg_names_str = " OR ".join(f"package_name = '{pkg}'" for pkg in sorted(packages_dicts.keys()))
        py_major, py_minor = platform.python_version_tuple()[:2]
        query = textwrap.dedent(
            f"""
            SELECT PACKAGE_NAME, VERSION
            FROM information_schema.packages
            WHERE ({pkg_names_str})
            AND language = 'python'
            AND (runtime_version = '{py_major}.{py_minor}'
                OR runtime_version is null);
            """
        )
        sql_result = []
        for pkg, pkg_vers in packages_dicts.items():
            for pkg_ver in pkg_vers:
                sql_result.append(row.Row(PACKAGE_NAME=pkg, VERSION=pkg_ver))
        if not sql_result:
            sql_result = [row.Row()]

        self.m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result))

    def test_get_model_final_packages(self) -> None:
        with tempfile.TemporaryDirectory() as workdir:
            env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {}
            with model_meta.create_model_metadata(
                model_dir_path=workdir,
                name="model1",
                model_type="custom",
                signatures=_DUMMY_SIG,
                _legacy_save=True,
            ) as meta:
                meta.models["model1"] = _DUMMY_BLOB

            typed_session = cast(session.Session, self.m_session)
            resolved = deploy._get_model_final_packages(meta, typed_session)
            expected = [str(env_utils.relax_requirement_version(dep)) for dep in _BASIC_DEPENDENCIES_FINAL_PACKAGES]
            self.assertListEqual(resolved, expected)

    def test_get_model_final_packages_no_relax(self) -> None:
        with tempfile.TemporaryDirectory() as workdir:
            env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {}
            with model_meta.create_model_metadata(
                model_dir_path=workdir,
                name="model1",
                model_type="custom",
                signatures=_DUMMY_SIG,
                conda_dependencies=["pandas==1.0.*"],
                _legacy_save=True,
                relax_version=False,
            ) as meta:
                meta.models["model1"] = _DUMMY_BLOB

            typed_session = cast(session.Session, self.m_session)
            with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError):
                deploy._get_model_final_packages(meta, typed_session)

    def test_get_model_final_packages_relax(self) -> None:
        with tempfile.TemporaryDirectory() as workdir:
            env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {}
            with model_meta.create_model_metadata(
                model_dir_path=workdir,
                name="model1",
                model_type="custom",
                signatures=_DUMMY_SIG,
                _legacy_save=True,
                relax_version=True,
            ) as meta:
                meta.models["model1"] = _DUMMY_BLOB

            typed_session = cast(session.Session, self.m_session)
            resolved = deploy._get_model_final_packages(meta, typed_session, relax_version=True)
            expected = [str(env_utils.relax_requirement_version(dep)) for dep in _BASIC_DEPENDENCIES_FINAL_PACKAGES]
            self.assertListEqual(resolved, expected)

    def test_get_model_final_packages_with_pip(self) -> None:
        with tempfile.TemporaryDirectory() as workdir:
            env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {}
            with model_meta.create_model_metadata(
                model_dir_path=workdir,
                name="model1",
                model_type="custom",
                signatures=_DUMMY_SIG,
                pip_requirements=["python-package"],
                _legacy_save=True,
                relax_version=False,
            ) as meta:
                meta.models["model1"] = _DUMMY_BLOB

            typed_session = cast(session.Session, self.m_session)
            with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError):
                deploy._get_model_final_packages(meta, typed_session)

    def test_get_model_final_packages_with_other_channel(self) -> None:
        with tempfile.TemporaryDirectory() as workdir:
            env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {}
            with model_meta.create_model_metadata(
                model_dir_path=workdir,
                name="model1",
                model_type="custom",
                signatures=_DUMMY_SIG,
                conda_dependencies=["conda-forge::python_package"],
                _legacy_save=True,
                relax_version=False,
            ) as meta:
                meta.models["model1"] = _DUMMY_BLOB

            typed_session = cast(session.Session, self.m_session)
            with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError):
                deploy._get_model_final_packages(meta, typed_session)

    def test_get_model_final_packages_with_non_exist_package(self) -> None:
        with tempfile.TemporaryDirectory() as workdir:
            env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {}
            # Same core set as setUp, plus one requested package that has no version at all.
            scripted_versions = self._core_package_versions()
            scripted_versions["python-package"] = []
            self.m_session = mock_session.MockSession(conn=None, test_case=self)
            self.add_packages(scripted_versions)
            with model_meta.create_model_metadata(
                model_dir_path=workdir,
                name="model1",
                model_type="custom",
                signatures=_DUMMY_SIG,
                conda_dependencies=["python-package"],
                _legacy_save=True,
                relax_version=False,
            ) as meta:
                meta.models["model1"] = _DUMMY_BLOB

            typed_session = cast(session.Session, self.m_session)
            with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError):
                deploy._get_model_final_packages(meta, typed_session)


if __name__ == "__main__":
    absltest.main()
-import os -import sys -import threading -import zipfile -from types import TracebackType -from typing import Optional, Type - -import anyio -import pandas as pd -from _snowflake import vectorized - - -class FileLock: - def __enter__(self) -> None: - self._lock = threading.Lock() - self._lock.acquire() - self._fd = open("/tmp/lockfile.LOCK", "w+") - fcntl.lockf(self._fd, fcntl.LOCK_EX) - - def __exit__( - self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] - ) -> None: - self._fd.close() - self._lock.release() - - -# User-defined parameters -MODEL_FILE_NAME = "{model_stage_file_name}" -TARGET_METHOD = "{target_method}" -MAX_BATCH_SIZE = None - - -# Retrieve the model -IMPORT_DIRECTORY_NAME = "snowflake_import_directory" -import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME] - -model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0] -zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME) -extracted = "/tmp/models" -extracted_model_dir_path = os.path.join(extracted, model_dir_name) - -with FileLock(): - if not os.path.isdir(extracted_model_dir_path): - with zipfile.ZipFile(zip_model_path, "r") as myzip: - myzip.extractall(extracted_model_dir_path) - -sys.path.insert(0, os.path.join(extracted_model_dir_path, "{code_dir_name}")) - -# Load the model -try: - from snowflake.ml.model._packager import model_packager - pk = model_packager.ModelPackager(extracted_model_dir_path) - pk.load(as_custom_model=True) - assert pk.model, "model is not loaded" - assert pk.meta, "model metadata is not loaded" - - model = pk.model - meta = pk.meta -except ImportError as e: - if e.name and not e.name.startswith("snowflake.ml"): - raise e - # Support Legacy model - from snowflake.ml.model import _model - # Backward for <= 1.0.5 - if hasattr(_model, "_load_model_for_deploy"): - model, meta = _model._load_model_for_deploy(extracted_model_dir_path) - else: - model, meta = _model._load(local_dir_path=extracted_model_dir_path, 
as_custom_model=True) - -# Determine the actual runner -func = getattr(model, TARGET_METHOD) -if inspect.iscoroutinefunction(func): - runner = functools.partial(anyio.run, func) -else: - runner = functools.partial(func) - -# Determine preprocess parameters -features = meta.signatures[TARGET_METHOD].inputs -input_cols = [feature.name for feature in features] -dtype_map = {{feature.name: feature.as_dtype() for feature in features}} - - -# Actual handler -@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE) -def infer(df: pd.DataFrame) -> dict: - input_df = pd.json_normalize(df[0]).astype(dtype=dtype_map) - predictions_df = runner(input_df[input_cols]) - - if "{_KEEP_ORDER_COL_NAME}" in input_df.columns: - predictions_df["{_KEEP_ORDER_COL_NAME}"] = input_df["{_KEEP_ORDER_COL_NAME}"] - - return predictions_df.to_dict("records") -""" diff --git a/snowflake/ml/model/_model_composer/model_composer.py b/snowflake/ml/model/_model_composer/model_composer.py index 2946deea..d92a8981 100644 --- a/snowflake/ml/model/_model_composer/model_composer.py +++ b/snowflake/ml/model/_model_composer/model_composer.py @@ -1,14 +1,11 @@ -import glob import pathlib import tempfile import uuid -import zipfile from types import ModuleType from typing import Any, Dict, List, Optional from absl import logging from packaging import requirements -from typing_extensions import deprecated from snowflake import snowpark from snowflake.ml._internal import env as snowml_env, env_utils, file_utils @@ -92,7 +89,7 @@ def save( python_version: Optional[str] = None, ext_modules: Optional[List[ModuleType]] = None, code_paths: Optional[List[str]] = None, - model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, + task: model_types.Task = model_types.Task.UNKNOWN, options: Optional[model_types.ModelSaveOption] = None, ) -> model_meta.ModelMetadata: if not options: @@ -121,25 +118,20 @@ def save( python_version=python_version, ext_modules=ext_modules, code_paths=code_paths, 
- model_objective=model_objective, + task=task, options=options, ) assert self.packager.meta is not None - if not options.get("_legacy_save", False): - # Keep both loose files and zipped file. - # TODO(SNOW-726678): Remove once import a directory is possible. - file_utils.copytree( - str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH) - ) - self.manifest.save( - model_meta=self.packager.meta, - model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH), - options=options, - data_sources=self._get_data_sources(model, sample_input_data), - ) - else: - file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path)) + file_utils.copytree( + str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH) + ) + self.manifest.save( + model_meta=self.packager.meta, + model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH), + options=options, + data_sources=self._get_data_sources(model, sample_input_data), + ) file_utils.upload_directory_to_stage( self.session, @@ -149,28 +141,6 @@ def save( ) return model_metadata - @deprecated("Only used by PrPr model registry. Use static method version of load instead.") - def legacy_load( - self, - *, - meta_only: bool = False, - options: Optional[model_types.ModelLoadOption] = None, - ) -> None: - file_utils.download_directory_from_stage( - self.session, - stage_path=self.stage_path, - local_path=self.workspace_path, - statement_params=self._statement_params, - ) - - # TODO (Server-side Model Rollout): Remove this section. 
- model_zip_path = pathlib.Path(glob.glob(str(self.workspace_path / "*.zip"))[0]) - self.model_file_rel_path = str(model_zip_path.relative_to(self.workspace_path)) - - with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf: - zf.extractall(path=self._packager_workspace_path) - self.packager.load(meta_only=meta_only, options=options) - @staticmethod def load( workspace_path: pathlib.Path, diff --git a/snowflake/ml/model/_model_composer/model_composer_test.py b/snowflake/ml/model/_model_composer/model_composer_test.py index 6359160a..41ced3bb 100644 --- a/snowflake/ml/model/_model_composer/model_composer_test.py +++ b/snowflake/ml/model/_model_composer/model_composer_test.py @@ -77,7 +77,7 @@ def test_save_interface(self) -> None: name="model1", model=linear_model.LinearRegression(), sample_input_data=d, - model_objective=model_types.ModelObjective.REGRESSION, + task=model_types.Task.TABULAR_REGRESSION, ) mock_upload_directory_to_stage.assert_called_once_with( diff --git a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel index e333e4c8..264a250b 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel @@ -23,6 +23,7 @@ py_library( "//snowflake/ml/model/_model_composer/model_method", "//snowflake/ml/model/_model_composer/model_method:function_generator", "//snowflake/ml/model/_packager/model_meta", + "//snowflake/ml/model/_packager/model_runtime", ], ) @@ -48,5 +49,6 @@ py_test( "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_packager/model_meta:model_blob_meta", + "//snowflake/ml/model/_packager/model_runtime", ], ) diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py index 7074818c..9a6ca525 100644 --- 
a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py @@ -1,6 +1,7 @@ import collections -import copy +import logging import pathlib +import warnings from typing import List, Optional, cast import yaml @@ -17,6 +18,9 @@ model_meta as model_meta_api, model_meta_schema, ) +from snowflake.ml.model._packager.model_runtime import model_runtime + +logger = logging.getLogger(__name__) class ModelManifest: @@ -44,9 +48,30 @@ def save( if options is None: options = {} - runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"]) - runtime_to_use.name = self._DEFAULT_RUNTIME_NAME - runtime_to_use.imports.append(str(model_rel_path) + "/") + if "relax_version" not in options: + warnings.warn( + ( + "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed" + " from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, " + "reproducibility, etc., set `options={'relax_version': False}` when logging the model." 
+ ), + category=UserWarning, + stacklevel=2, + ) + relax_version = options.get("relax_version", True) + + runtime_to_use = model_runtime.ModelRuntime( + name=self._DEFAULT_RUNTIME_NAME, + env=model_meta.env, + imports=[str(model_rel_path) + "/"], + is_gpu=False, + is_warehouse=True, + ) + if relax_version: + runtime_to_use.runtime_env.relax_version() + logger.info("Relaxing version constraints for dependencies in the model.") + logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}") + logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}") runtime_dict = runtime_to_use.save( self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL ) diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py index 2ee07e5c..0c076972 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py @@ -5,6 +5,7 @@ import importlib_resources import yaml from absl.testing import absltest +from packaging import requirements from snowflake.ml._internal import env_utils from snowflake.ml.model import model_signature, type_hints @@ -14,6 +15,7 @@ model_meta, model_meta_schema, ) +from snowflake.ml.model._packager.model_runtime import model_runtime _DUMMY_SIG = { "predict": model_signature.ModelSignature( @@ -37,6 +39,57 @@ name="model1", model_type="custom", path="mock_path", handler_version="version_0" ) +_PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML = ( + list( + sorted( + map( + lambda x: str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x))), + model_meta._PACKAGING_REQUIREMENTS, + ) + ) + ) + + model_runtime._SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES +) + +_PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML_RELAXED = ( + list( + sorted( + map( + lambda x: str( + 
env_utils.relax_requirement_version( + env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) + ) + ), + model_meta._PACKAGING_REQUIREMENTS, + ) + ) + ) + + model_runtime._SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES +) + +_PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML = list( + sorted( + map( + lambda x: str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x))), + model_meta._PACKAGING_REQUIREMENTS + [env_utils.SNOWPARK_ML_PKG_NAME], + ) + ) +) + + +_PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED = list( + sorted( + map( + lambda x: str( + env_utils.relax_requirement_version( + env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) + ) + ), + model_meta._PACKAGING_REQUIREMENTS + [env_utils.SNOWPARK_ML_PKG_NAME], + ) + ) +) + class ModelManifestTest(absltest.TestCase): def test_model_manifest_1(self) -> None: @@ -52,7 +105,17 @@ def test_model_manifest_1(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - mm.save(meta, pathlib.PurePosixPath("model")) + with self.assertWarnsRegex(UserWarning, "`relax_version` is not set and therefore defaulted to True."): + mm.save(meta, pathlib.PurePosixPath("model")) + with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: + self.assertDictEqual( + yaml.safe_load(f), + { + "channels": [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"], + "dependencies": ["python==3.8.*"] + _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED, + "name": "snow-env", + }, + ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: self.assertEqual( ( @@ -74,6 +137,36 @@ def test_model_manifest_1(self) -> None: f.read(), ) + def test_model_manifest_1_relax_version(self) -> None: + with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: + mm = model_manifest.ModelManifest(pathlib.Path(workspace)) + with model_meta.create_model_metadata( + 
model_dir_path=tmpdir, + name="model1", + model_type="custom", + signatures=_DUMMY_SIG, + python_version="3.8", + embed_local_ml_library=False, + ) as meta: + meta.models["model1"] = _DUMMY_BLOB + + mm.save( + meta, + pathlib.PurePosixPath("model"), + options=type_hints.BaseModelSaveOption( + relax_version=False, + ), + ) + with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: + self.assertDictEqual( + yaml.safe_load(f), + { + "channels": [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"], + "dependencies": ["python==3.8.*"] + _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML, + "name": "snow-env", + }, + ) + def test_model_manifest_2(self) -> None: with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: mm = model_manifest.ModelManifest(pathlib.Path(workspace)) @@ -91,9 +184,19 @@ def test_model_manifest_2(self) -> None: meta, pathlib.PurePosixPath("model"), options=type_hints.BaseModelSaveOption( - method_options={"__call__": type_hints.ModelMethodSaveOptions(max_batch_size=10)} + method_options={"__call__": type_hints.ModelMethodSaveOptions(max_batch_size=10)}, + relax_version=False, ), ) + with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: + self.assertDictEqual( + yaml.safe_load(f), + { + "channels": [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"], + "dependencies": ["python==3.8.*"] + _PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML, + "name": "snow-env", + }, + ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: self.assertEqual( ( @@ -115,6 +218,37 @@ def test_model_manifest_2(self) -> None: f.read(), ) + def test_model_manifest_2_relax_version(self) -> None: + with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: + mm = model_manifest.ModelManifest(pathlib.Path(workspace)) + with model_meta.create_model_metadata( + model_dir_path=tmpdir, + 
name="model1", + model_type="custom", + signatures={"__call__": _DUMMY_SIG["predict"]}, + python_version="3.8", + embed_local_ml_library=True, + ) as meta: + meta.models["model1"] = _DUMMY_BLOB + + mm.save( + meta, + pathlib.PurePosixPath("model"), + options=type_hints.BaseModelSaveOption( + method_options={"__call__": type_hints.ModelMethodSaveOptions(max_batch_size=10)}, + relax_version=True, + ), + ) + with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: + self.assertDictEqual( + yaml.safe_load(f), + { + "channels": [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"], + "dependencies": ["python==3.8.*"] + _PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML_RELAXED, + "name": "snow-env", + }, + ) + def test_model_manifest_mix(self) -> None: with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: mm = model_manifest.ModelManifest(pathlib.Path(workspace)) @@ -149,8 +283,13 @@ def test_model_manifest_mix(self) -> None: f.read(), ) with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: - self.assertListEqual( - yaml.safe_load(f)["channels"], [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"] + self.assertDictEqual( + yaml.safe_load(f), + { + "channels": [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"], + "dependencies": ["python==3.8.*"] + _PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML_RELAXED, + "name": "snow-env", + }, ) with open(pathlib.Path(workspace, "functions", "predict.py"), encoding="utf-8") as f: self.assertEqual( diff --git a/snowflake/ml/model/_packager/BUILD.bazel b/snowflake/ml/model/_packager/BUILD.bazel index 1be32a3a..ae2045cc 100644 --- a/snowflake/ml/model/_packager/BUILD.bazel +++ b/snowflake/ml/model/_packager/BUILD.bazel @@ -24,7 +24,6 @@ py_library( "//snowflake/ml/model/_packager/model_handlers:custom", "//snowflake/ml/model/_packager/model_handlers:huggingface_pipeline", 
"//snowflake/ml/model/_packager/model_handlers:lightgbm", - "//snowflake/ml/model/_packager/model_handlers:llm", "//snowflake/ml/model/_packager/model_handlers:mlflow", "//snowflake/ml/model/_packager/model_handlers:pytorch", "//snowflake/ml/model/_packager/model_handlers:sentence_transformers", diff --git a/snowflake/ml/model/_packager/model_env/model_env.py b/snowflake/ml/model/_packager/model_env/model_env.py index a002142d..83fe3479 100644 --- a/snowflake/ml/model/_packager/model_env/model_env.py +++ b/snowflake/ml/model/_packager/model_env/model_env.py @@ -21,7 +21,7 @@ # The default CUDA version is chosen based on the driver availability in SPCS. # If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to # make sure they are compatible. -DEFAULT_CUDA_VERSION = "11.7" +DEFAULT_CUDA_VERSION = "11.8" class ModelEnv: @@ -199,50 +199,16 @@ def generate_env_for_cuda(self) -> None: ) if xgboost_spec: self.include_if_absent( - [ - ModelDependency( - requirement=f"conda-forge::py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost" - ) - ], + [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")], check_local_version=False, ) - pytorch_spec = env_utils.find_dep_spec( - self._conda_dependencies, - self._pip_requirements, - conda_pkg_name="pytorch", - pip_pkg_name="torch", - remove_spec=True, - ) - pytorch_cuda_spec = env_utils.find_dep_spec( - self._conda_dependencies, - self._pip_requirements, - conda_pkg_name="pytorch-cuda", - remove_spec=False, - ) - if pytorch_cuda_spec and not pytorch_cuda_spec.specifier.contains(self.cuda_version): - raise ValueError( - "The Pytorch-CUDA requirement you specified in your conda dependencies or pip requirements is" - " conflicting with CUDA version required. Please do not specify Pytorch-CUDA dependency using conda" - " dependencies or pip requirements." 
- ) - if pytorch_spec: - self.include_if_absent( - [ModelDependency(requirement=f"pytorch::pytorch{pytorch_spec.specifier}", pip_name="torch")], - check_local_version=False, - ) - if not pytorch_cuda_spec: - self.include_if_absent( - [ModelDependency(requirement=f"pytorch::pytorch-cuda=={self.cuda_version}.*", pip_name="torch")], - check_local_version=False, - ) - tf_spec = env_utils.find_dep_spec( self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True ) if tf_spec: self.include_if_absent( - [ModelDependency(requirement=f"conda-forge::tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")], + [ModelDependency(requirement=f"tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")], check_local_version=False, ) @@ -252,7 +218,7 @@ def generate_env_for_cuda(self) -> None: if transformers_spec: self.include_if_absent( [ - ModelDependency(requirement="conda-forge::accelerate>=0.22.0", pip_name="accelerate"), + ModelDependency(requirement="accelerate>=0.22.0", pip_name="accelerate"), ModelDependency(requirement="scipy>=1.9", pip_name="scipy"), ], check_local_version=False, diff --git a/snowflake/ml/model/_packager/model_env/model_env_test.py b/snowflake/ml/model/_packager/model_env/model_env_test.py index 46ca0d44..527cdf75 100644 --- a/snowflake/ml/model/_packager/model_env/model_env_test.py +++ b/snowflake/ml/model/_packager/model_env/model_env_test.py @@ -29,25 +29,25 @@ def test_conda_dependencies(self) -> None: self.assertListEqual(env.conda_dependencies, ["package"]) env.conda_dependencies = ["some_package"] - self.assertListEqual(env.conda_dependencies, ["some-package"]) + self.assertListEqual(env.conda_dependencies, ["some_package"]) env.conda_dependencies = ["some_package==1.0.1"] - self.assertListEqual(env.conda_dependencies, ["some-package==1.0.1"]) + self.assertListEqual(env.conda_dependencies, ["some_package==1.0.1"]) env.conda_dependencies = ["some_package<1.2,>=1.0.1"] - 
self.assertListEqual(env.conda_dependencies, ["some-package<1.2,>=1.0.1"]) + self.assertListEqual(env.conda_dependencies, ["some_package<1.2,>=1.0.1"]) env.conda_dependencies = ["channel::some_package<1.2,>=1.0.1"] - self.assertListEqual(env.conda_dependencies, ["channel::some-package<1.2,>=1.0.1"]) + self.assertListEqual(env.conda_dependencies, ["channel::some_package<1.2,>=1.0.1"]) with self.assertRaisesRegex(ValueError, "Invalid package requirement _some_package<1.2,>=1.0.1 found."): env.conda_dependencies = ["channel::_some_package<1.2,>=1.0.1"] env.conda_dependencies = ["::some_package<1.2,>=1.0.1"] - self.assertListEqual(env.conda_dependencies, ["some-package<1.2,>=1.0.1"]) + self.assertListEqual(env.conda_dependencies, ["some_package<1.2,>=1.0.1"]) env.conda_dependencies = ["another==1.3", "channel::some_package<1.2,>=1.0.1"] - self.assertListEqual(env.conda_dependencies, ["another==1.3", "channel::some-package<1.2,>=1.0.1"]) + self.assertListEqual(env.conda_dependencies, ["another==1.3", "channel::some_package<1.2,>=1.0.1"]) def test_pip_requirements(self) -> None: env = model_env.ModelEnv() @@ -55,13 +55,13 @@ def test_pip_requirements(self) -> None: self.assertListEqual(env.pip_requirements, ["package"]) env.pip_requirements = ["some_package"] - self.assertListEqual(env.pip_requirements, ["some-package"]) + self.assertListEqual(env.pip_requirements, ["some_package"]) env.pip_requirements = ["some_package==1.0.1"] - self.assertListEqual(env.pip_requirements, ["some-package==1.0.1"]) + self.assertListEqual(env.pip_requirements, ["some_package==1.0.1"]) env.pip_requirements = ["some_package<1.2,>=1.0.1"] - self.assertListEqual(env.pip_requirements, ["some-package<1.2,>=1.0.1"]) + self.assertListEqual(env.pip_requirements, ["some_package<1.2,>=1.0.1"]) with self.assertRaisesRegex(ValueError, "Invalid package requirement channel::some_package<1.2,>=1.0.1 found."): env.pip_requirements = ["channel::some_package<1.2,>=1.0.1"] @@ -487,7 +487,7 @@ def 
test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, [ - "another_channel::another-package==1.0.0", + "another_channel::another_package==1.0.0", "nvidia::cuda==11.7.*", "somepackage==1.0.0", ], @@ -506,7 +506,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, [ - "another_channel::another-package==1.0.0", + "another_channel::another_package==1.0.0", "nvidia::cuda>=11.7", "somepackage==1.0.0", ], @@ -535,7 +535,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["nvidia::cuda==11.7.*", "pytorch::pytorch-cuda==11.7.*", "pytorch::pytorch==1.0.0"], + ["nvidia::cuda==11.7.*", "pytorch==1.0.0"], ) env = model_env.ModelEnv() @@ -546,7 +546,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["nvidia::cuda==11.7.*", "pytorch::pytorch-cuda==11.7.*", "pytorch::pytorch>=1.0.0"], + ["nvidia::cuda==11.7.*", "pytorch>=1.0.0"], ) env = model_env.ModelEnv() @@ -557,21 +557,9 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["nvidia::cuda==11.7.*", "pytorch::pytorch-cuda>=11.7", "pytorch::pytorch>=1.0.0"], + ["nvidia::cuda==11.7.*", "pytorch::pytorch-cuda>=11.7", "pytorch>=1.0.0"], ) - env = model_env.ModelEnv() - env.conda_dependencies = ["pytorch>=1.0.0", "pytorch::pytorch-cuda==11.8.*"] - env.cuda_version = "11.7" - - with self.assertRaisesRegex( - ValueError, - "The Pytorch-CUDA requirement you specified in your conda dependencies or pip requirements is" - " conflicting with CUDA version required. 
Please do not specify Pytorch-CUDA dependency using conda" - " dependencies or pip requirements.", - ): - env.generate_env_for_cuda() - env = model_env.ModelEnv() env.conda_dependencies = ["pytorch::pytorch>=1.1.0", "pytorch::pytorch-cuda==11.7.*"] env.cuda_version = "11.7" @@ -591,9 +579,8 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["nvidia::cuda==11.7.*", "pytorch::pytorch-cuda==11.7.*", "pytorch::pytorch==1.0.0"], + ["conda-forge::pytorch==1.0.0", "nvidia::cuda==11.7.*"], ) - self.assertIn("conda-forge", env._conda_dependencies) env = model_env.ModelEnv() env.pip_requirements = ["torch==1.0.0"] @@ -603,9 +590,9 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["nvidia::cuda==11.7.*", "pytorch::pytorch-cuda==11.7.*", "pytorch::pytorch==1.0.0"], + ["nvidia::cuda==11.7.*"], ) - self.assertListEqual(env.pip_requirements, []) + self.assertListEqual(env.pip_requirements, ["torch==1.0.0"]) env = model_env.ModelEnv() env.conda_dependencies = ["tensorflow==1.0.0"] @@ -615,7 +602,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["conda-forge::tensorflow-gpu==1.0.0", "nvidia::cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "tensorflow-gpu==1.0.0"], ) env = model_env.ModelEnv() @@ -626,7 +613,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["conda-forge::tensorflow-gpu>=1.0.0", "nvidia::cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "tensorflow-gpu>=1.0.0"], ) env = model_env.ModelEnv() @@ -649,7 +636,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["conda-forge::tensorflow-gpu==1.0.0", "nvidia::cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "tensorflow-gpu==1.0.0"], ) self.assertListEqual(env.pip_requirements, []) @@ -661,7 +648,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( 
env.conda_dependencies, - ["conda-forge::py-xgboost-gpu==1.0.0", "nvidia::cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "py-xgboost-gpu==1.0.0"], ) env = model_env.ModelEnv() @@ -672,7 +659,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["conda-forge::py-xgboost-gpu>=1.0.0", "nvidia::cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "py-xgboost-gpu>=1.0.0"], ) env = model_env.ModelEnv() @@ -695,7 +682,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["conda-forge::py-xgboost-gpu>=1.0.0", "nvidia::cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "py-xgboost-gpu>=1.0.0"], ) env = model_env.ModelEnv() @@ -706,7 +693,7 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, - ["conda-forge::py-xgboost-gpu>=1.0.0", "nvidia::cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "py-xgboost-gpu>=1.0.0"], ) self.assertListEqual(env.pip_requirements, []) @@ -719,10 +706,9 @@ def test_generate_conda_env_for_cuda(self) -> None: self.assertListEqual( env.conda_dependencies, [ - "conda-forge::accelerate>=0.22.0", + "accelerate>=0.22.0", "nvidia::cuda==11.7.*", - "pytorch::pytorch-cuda==11.7.*", - "pytorch::pytorch==1.0.0", + "pytorch==1.0.0", "scipy>=1.9", "transformers==1.0.0", ], @@ -989,7 +975,7 @@ def check_env_equality(this: model_env.ModelEnv, that: model_env.ModelEnv) -> bo conda_yml, { "channels": ["conda-forge", "channel", "nodefaults"], - "dependencies": ["python==3.10.*", "another==1.3", "channel::some-package<1.2,>=1.0.1"], + "dependencies": ["python==3.10.*", "another==1.3", "channel::some_package<1.2,>=1.0.1"], "name": "snow-env", }, ) @@ -1016,7 +1002,7 @@ def test_validate_with_local_env(self) -> None: mock_validate_local_installed_version_of_pip_package.assert_has_calls( [ mock.call(requirements.Requirement("torch==1.3")), - mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + 
mock.call(requirements.Requirement("some_package<1.2,>=1.0.1")), mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), ] ) @@ -1040,7 +1026,7 @@ def test_validate_with_local_env(self) -> None: mock_validate_local_installed_version_of_pip_package.assert_has_calls( [ mock.call(requirements.Requirement("torch==1.3")), - mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + mock.call(requirements.Requirement("some_package<1.2,>=1.0.1")), mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), ] ) @@ -1062,7 +1048,7 @@ def test_validate_with_local_env(self) -> None: mock_validate_local_installed_version_of_pip_package.assert_has_calls( [ mock.call(requirements.Requirement("torch==1.3")), - mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + mock.call(requirements.Requirement("some_package<1.2,>=1.0.1")), mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), ] ) @@ -1086,7 +1072,7 @@ def test_validate_with_local_env(self) -> None: mock_validate_local_installed_version_of_pip_package.assert_has_calls( [ mock.call(requirements.Requirement("torch==1.3")), - mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + mock.call(requirements.Requirement("some_package<1.2,>=1.0.1")), mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), ] ) diff --git a/snowflake/ml/model/_packager/model_handlers/BUILD.bazel b/snowflake/ml/model/_packager/model_handlers/BUILD.bazel index 091a0269..6f197622 100644 --- a/snowflake/ml/model/_packager/model_handlers/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_handlers/BUILD.bazel @@ -21,6 +21,7 @@ py_library( "//snowflake/ml/model/_model_composer/model_method", "//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_signatures:snowpark_handler", + "//snowflake/ml/model/_signatures:utils", ], ) @@ -29,6 +30,7 @@ py_library( srcs = ["model_objective_utils.py"], deps = [ ":_utils", + "//snowflake/ml/_internal:type_utils", ], ) @@ -83,7 +85,6 @@ py_library( 
"//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_packager/model_meta:model_blob_meta", "//snowflake/ml/model/_signatures:numpy_handler", - "//snowflake/ml/model/_signatures:utils", ], ) @@ -104,6 +105,7 @@ py_library( "//snowflake/ml/model/_signatures:numpy_handler", "//snowflake/ml/model/_signatures:utils", "//snowflake/ml/modeling/framework", + "//snowflake/ml/modeling/pipeline", ], ) @@ -265,18 +267,3 @@ py_library( "//snowflake/ml/model/_signatures:utils", ], ) - -py_library( - name = "llm", - srcs = ["llm.py"], - deps = [ - ":_base", - "//snowflake/ml/_internal:file_utils", - "//snowflake/ml/model:custom_model", - "//snowflake/ml/model:model_signature", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_packager/model_env", - "//snowflake/ml/model/_packager/model_meta", - "//snowflake/ml/model/models:llm_model", - ], -) diff --git a/snowflake/ml/model/_packager/model_handlers/_utils.py b/snowflake/ml/model/_packager/model_handlers/_utils.py index 445e3e78..a96047f8 100644 --- a/snowflake/ml/model/_packager/model_handlers/_utils.py +++ b/snowflake/ml/model/_packager/model_handlers/_utils.py @@ -1,17 +1,26 @@ import json +import os import warnings -from typing import Any, Callable, Iterable, Optional, Sequence, cast +from typing import Any, Callable, Iterable, List, Optional, Sequence, cast import numpy as np import numpy.typing as npt import pandas as pd from absl import logging +import snowflake.snowpark.dataframe as sp_df +from snowflake.ml._internal.utils import identifier from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager.model_meta import model_meta -from snowflake.ml.model._signatures import snowpark_handler +from snowflake.ml.model._signatures import ( + core, + snowpark_handler, + utils as model_signature_utils, +) from snowflake.snowpark import DataFrame as SnowparkDataFrame +EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT = 1000 + class NumpyEncoder(json.JSONEncoder): def 
default(self, obj: Any) -> Any: @@ -28,6 +37,18 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo return callable(getattr(model, method_name, None)) +def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType: + trunc_sample_input = model_signature._truncate_data(sample_input_data) + local_sample_input: model_types.SupportedLocalDataType = None + if isinstance(sample_input_data, SnowparkDataFrame): + # Added because of Any from missing stubs. + trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input) + local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input) + else: + local_sample_input = trunc_sample_input + return local_sample_input + + def validate_signature( model: model_types.SupportedRequireSignatureModelType, model_meta: model_meta.ModelMetadata, @@ -37,19 +58,23 @@ def validate_signature( ) -> model_meta.ModelMetadata: if model_meta.signatures: validate_target_methods(model, list(model_meta.signatures.keys())) + if sample_input_data is not None: + local_sample_input = get_truncated_sample_data(sample_input_data) + for target_method in model_meta.signatures.keys(): + + model_signature_inst = model_meta.signatures.get(target_method) + if model_signature_inst is not None: + # strict validation the input signature + model_signature._convert_and_validate_local_data( + local_sample_input, model_signature_inst._inputs, True + ) return model_meta # In this case sample_input_data should be available, because of the check in save_model. assert ( sample_input_data is not None ), "Model signature and sample input are None at the same time. This should not happen with local model." - trunc_sample_input = model_signature._truncate_data(sample_input_data) - if isinstance(sample_input_data, SnowparkDataFrame): - # Added because of Any from missing stubs. 
- trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input) - local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input) - else: - local_sample_input = trunc_sample_input + local_sample_input = get_truncated_sample_data(sample_input_data) for target_method in target_methods: predictions_df = get_prediction_fn(target_method, local_sample_input) sig = model_signature.infer_signature(local_sample_input, predictions_df) @@ -58,24 +83,55 @@ def validate_signature( return model_meta +def get_input_signature( + model_meta: model_meta.ModelMetadata, target_method: Optional[str] +) -> Sequence[core.BaseFeatureSpec]: + if target_method is None or target_method not in model_meta.signatures: + raise ValueError(f"Signature for target method {target_method} is missing or no method to explain.") + input_sig = model_meta.signatures[target_method].inputs + return input_sig + + def add_explain_method_signature( model_meta: model_meta.ModelMetadata, explain_method: str, - target_method: str, + target_method: Optional[str], output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE, ) -> model_meta.ModelMetadata: - if target_method not in model_meta.signatures: - raise ValueError(f"Signature for target method {target_method} is missing") - inputs = model_meta.signatures[target_method].inputs + inputs = get_input_signature(model_meta, target_method) + if model_meta.model_type == "snowml": + output_feature_names = [identifier.concat_names([spec.name, "_explanation"]) for spec in inputs] + else: + output_feature_names = [f"{spec.name}_explanation" for spec in inputs] model_meta.signatures[explain_method] = model_signature.ModelSignature( inputs=inputs, outputs=[ - model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs + model_signature.FeatureSpec(dtype=output_return_type, name=output_name) + for output_name in output_feature_names ], ) return model_meta +def 
get_explainability_supported_background( + sample_input_data: Optional[model_types.SupportedDataType], + meta: model_meta.ModelMetadata, + explain_target_method: Optional[str], +) -> pd.DataFrame: + if sample_input_data is None: + return None + + if isinstance(sample_input_data, pd.DataFrame): + return sample_input_data + if isinstance(sample_input_data, sp_df.DataFrame): + return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data) + + df = model_signature._convert_local_data_to_df(sample_input_data) + input_signature_for_explain = get_input_signature(meta, explain_target_method) + df_with_named_cols = model_signature_utils.rename_pandas_df(df, input_signature_for_explain) + return df_with_named_cols + + def get_target_methods( model: model_types.SupportedModelType, target_methods: Optional[Sequence[str]], @@ -88,6 +144,23 @@ def get_target_methods( return target_methods +def save_background_data( + model_blobs_dir_path: str, + explain_artifact_dir: str, + bg_data_file_suffix: str, + model_name: str, + background_data: pd.DataFrame, +) -> None: + data_blob_path = os.path.join(model_blobs_dir_path, explain_artifact_dir) + os.makedirs(data_blob_path, exist_ok=True) + with open(os.path.join(data_blob_path, model_name + bg_data_file_suffix), "wb") as f: + # saving only the truncated data + trunc_background_data = background_data.head( + min(len(background_data.index), EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT) + ) + trunc_background_data.to_parquet(f) + + def validate_target_methods(model: model_types.SupportedModelType, target_methods: Iterable[str]) -> None: for method_name in target_methods: if not _is_callable(model, method_name): @@ -123,25 +196,26 @@ def row_to_dict(row: npt.NDArray[Any]) -> npt.NDArray[Any]: return pd.DataFrame(exp_2d) -def validate_model_objective( - passed_model_objective: model_types.ModelObjective, inferred_model_objective: model_types.ModelObjective -) -> model_types.ModelObjective: - if ( - passed_model_objective != 
model_types.ModelObjective.UNKNOWN - and inferred_model_objective != model_types.ModelObjective.UNKNOWN - ): - if passed_model_objective != inferred_model_objective: +def validate_model_task(passed_model_task: model_types.Task, inferred_model_task: model_types.Task) -> model_types.Task: + if passed_model_task != model_types.Task.UNKNOWN and inferred_model_task != model_types.Task.UNKNOWN: + if passed_model_task != inferred_model_task: warnings.warn( - f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model " - f"version and passed argument ModelObjective: {passed_model_objective.name} is ignored", + f"Inferred Task: {inferred_model_task.name} is used as task for this model " + f"version and passed argument Task: {passed_model_task.name} is ignored", category=UserWarning, stacklevel=1, ) - return inferred_model_objective - elif inferred_model_objective != model_types.ModelObjective.UNKNOWN: - logging.info( - f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model " - f"version" - ) - return inferred_model_objective - return passed_model_objective + return inferred_model_task + elif inferred_model_task != model_types.Task.UNKNOWN: + logging.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model " f"version") + return inferred_model_task + return passed_model_task + + +def get_explain_target_method( + model_metadata: model_meta.ModelMetadata, target_methods_list: List[str] +) -> Optional[str]: + for method in model_metadata.signatures.keys(): + if method in target_methods_list: + return method + return None diff --git a/snowflake/ml/model/_packager/model_handlers/catboost.py b/snowflake/ml/model/_packager/model_handlers/catboost.py index badf1df6..0120a3e0 100644 --- a/snowflake/ml/model/_packager/model_handlers/catboost.py +++ b/snowflake/ml/model/_packager/model_handlers/catboost.py @@ -1,4 +1,5 @@ import os +import warnings from typing import 
TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final import numpy as np @@ -8,7 +9,11 @@ from snowflake.ml._internal import type_utils from snowflake.ml.model import custom_model, model_signature, type_hints as model_types from snowflake.ml.model._packager.model_env import model_env -from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils +from snowflake.ml.model._packager.model_handlers import ( + _base, + _utils as handlers_utils, + model_objective_utils, +) from snowflake.ml.model._packager.model_handlers_migrator import base_migrator from snowflake.ml.model._packager.model_meta import ( model_blob_meta, @@ -32,22 +37,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]): MODEL_BLOB_FILE_OR_DIR = "model.bin" DEFAULT_TARGET_METHODS = ["predict", "predict_proba"] - - @classmethod - def get_model_objective_and_output_type(cls, model: "catboost.CatBoost") -> model_types.ModelObjective: - import catboost - - if isinstance(model, catboost.CatBoostClassifier): - num_classes = handlers_utils.get_num_classes_if_exists(model) - if num_classes == 2: - return model_types.ModelObjective.BINARY_CLASSIFICATION - return model_types.ModelObjective.MULTI_CLASSIFICATION - if isinstance(model, catboost.CatBoostRanker): - return model_types.ModelObjective.RANKING - if isinstance(model, catboost.CatBoostRegressor): - return model_types.ModelObjective.REGRESSION - # TODO: Find out model type from the generic Catboost Model - return model_types.ModelObjective.UNKNOWN + EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"] @classmethod def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]: @@ -107,25 +97,34 @@ def get_prediction( sample_input_data=sample_input_data, get_prediction_fn=get_prediction, ) - inferred_model_objective = cls.get_model_objective_and_output_type(model) - model_meta.model_objective = handlers_utils.validate_model_objective( - model_meta.model_objective, 
inferred_model_objective - ) - model_objective = model_meta.model_objective + model_task_and_output = model_objective_utils.get_model_task_and_output_type(model) + model_meta.task = model_task_and_output.task if enable_explainability: - output_type = model_signature.DataType.DOUBLE - if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION: - output_type = model_signature.DataType.STRING + explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS) model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, explain_method="explain", - target_method="predict", - output_return_type=output_type, + target_method=explain_target_method, + output_return_type=model_task_and_output.output_type, ) model_meta.function_properties = { "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False} } + background_data = handlers_utils.get_explainability_supported_background( + sample_input_data, model_meta, explain_target_method + ) + if background_data is not None: + handlers_utils.save_background_data( + model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data + ) + else: + warnings.warn( + "sample_input_data should be provided for better explainability results", + category=UserWarning, + stacklevel=1, + ) + model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR) diff --git a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py index f96fc28e..c197dccc 100644 --- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py @@ -242,10 +242,10 @@ def save_model( task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model)) ) if framework is 
None or framework == "pt": - # Since we set default cuda version to be 11.7, to make sure it works with GPU, we need to have a default - # Pytorch version that works with CUDA 11.7 as well. This is required for huggingface pipelines only as + # Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default + # Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as # users are not required to install pytorch locally if they are using the wrapper. - pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch==2.0.1", pip_name="torch")) + pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch")) elif framework == "tf": pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow")) model_meta.env.include_if_absent( diff --git a/snowflake/ml/model/_packager/model_handlers/lightgbm.py b/snowflake/ml/model/_packager/model_handlers/lightgbm.py index 779944f0..8413304f 100644 --- a/snowflake/ml/model/_packager/model_handlers/lightgbm.py +++ b/snowflake/ml/model/_packager/model_handlers/lightgbm.py @@ -1,4 +1,5 @@ import os +import warnings from typing import ( TYPE_CHECKING, Any, @@ -47,6 +48,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb MODEL_BLOB_FILE_OR_DIR = "model.pkl" DEFAULT_TARGET_METHODS = ["predict", "predict_proba"] + EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"] @classmethod def can_handle( @@ -111,21 +113,34 @@ def get_prediction( sample_input_data=sample_input_data, get_prediction_fn=get_prediction, ) - model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model) - model_meta.model_objective = handlers_utils.validate_model_objective( - model_meta.model_objective, model_objective_and_output.objective - ) + model_task_and_output = model_objective_utils.get_model_task_and_output_type(model) + model_meta.task = 
handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task) if enable_explainability: + explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS) model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, explain_method="explain", - target_method="predict", - output_return_type=model_objective_and_output.output_type, + target_method=explain_target_method, + output_return_type=model_task_and_output.output_type, ) model_meta.function_properties = { "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False} } + background_data = handlers_utils.get_explainability_supported_background( + sample_input_data, model_meta, explain_target_method + ) + if background_data is not None: + handlers_utils.save_background_data( + model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data + ) + else: + warnings.warn( + "sample_input_data should be provided for better explainability results", + category=UserWarning, + stacklevel=1, + ) + model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) diff --git a/snowflake/ml/model/_packager/model_handlers/llm.py b/snowflake/ml/model/_packager/model_handlers/llm.py deleted file mode 100644 index 0082c760..00000000 --- a/snowflake/ml/model/_packager/model_handlers/llm.py +++ /dev/null @@ -1,269 +0,0 @@ -import logging -import os -from typing import Dict, Optional, Type, cast, final - -import cloudpickle -import pandas as pd -from typing_extensions import TypeGuard, Unpack - -from snowflake.ml._internal import file_utils -from snowflake.ml.model import custom_model, model_signature, type_hints as model_types -from snowflake.ml.model._packager.model_env import model_env -from snowflake.ml.model._packager.model_handlers import _base -from snowflake.ml.model._packager.model_handlers_migrator import base_migrator -from snowflake.ml.model._packager.model_meta 
import ( - model_blob_meta, - model_meta as model_meta_api, - model_meta_schema, -) -from snowflake.ml.model.models import llm - -logger = logging.getLogger(__name__) - - -@final -class LLMHandler(_base.BaseModelHandler[llm.LLM]): - HANDLER_TYPE = "llm" - HANDLER_VERSION = "2023-12-01" - _MIN_SNOWPARK_ML_VERSION = "1.0.12" - _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {} - - MODEL_BLOB_FILE_OR_DIR = "model" - LLM_META = "llm_meta" - IS_AUTO_SIGNATURE = True - - @classmethod - def can_handle( - cls, - model: model_types.SupportedModelType, - ) -> TypeGuard[llm.LLM]: - return isinstance(model, llm.LLM) - - @classmethod - def cast_model( - cls, - model: model_types.SupportedModelType, - ) -> llm.LLM: - assert isinstance(model, llm.LLM) - return cast(llm.LLM, model) - - @classmethod - def save_model( - cls, - name: str, - model: llm.LLM, - model_meta: model_meta_api.ModelMetadata, - model_blobs_dir_path: str, - sample_input_data: Optional[model_types.SupportedDataType] = None, - is_sub_model: Optional[bool] = False, - **kwargs: Unpack[model_types.LLMSaveOptions], - ) -> None: - assert not is_sub_model, "LLM can not be sub-model." 
- enable_explainability = kwargs.get("enable_explainability", False) - if enable_explainability: - raise NotImplementedError("Explainability is not supported for llm model.") - model_blob_path = os.path.join(model_blobs_dir_path, name) - os.makedirs(model_blob_path, exist_ok=True) - model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR) - - sig = model_signature.ModelSignature( - inputs=[ - model_signature.FeatureSpec(name="input", dtype=model_signature.DataType.STRING), - ], - outputs=[ - model_signature.FeatureSpec(name="generated_text", dtype=model_signature.DataType.STRING), - ], - ) - model_meta.signatures = {"infer": sig} - if os.path.isdir(model.model_id_or_path): - file_utils.copytree(model.model_id_or_path, model_blob_dir_path) - - os.makedirs(model_blob_dir_path, exist_ok=True) - with open( - os.path.join(model_blob_dir_path, cls.LLM_META), - "wb", - ) as f: - cloudpickle.dump(model, f) - - base_meta = model_blob_meta.ModelBlobMeta( - name=name, - model_type=cls.HANDLER_TYPE, - handler_version=cls.HANDLER_VERSION, - path=cls.MODEL_BLOB_FILE_OR_DIR, - options=model_meta_schema.LLMModelBlobOptions( - { - "batch_size": model.max_batch_size, - } - ), - ) - model_meta.models[name] = base_meta - model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION - - pkgs_requirements = [ - model_env.ModelDependency(requirement="transformers>=4.32.1", pip_name="transformers"), - model_env.ModelDependency(requirement="pytorch==2.0.1", pip_name="torch"), - ] - if model.model_type == llm.SupportedLLMType.LLAMA_MODEL_TYPE.value: - pkgs_requirements = [ - model_env.ModelDependency(requirement="sentencepiece", pip_name="sentencepiece"), - model_env.ModelDependency(requirement="protobuf", pip_name="protobuf"), - *pkgs_requirements, - ] - model_meta.env.include_if_absent(pkgs_requirements, check_local_version=True) - # Recent peft versions are only available in PYPI. 
- model_meta.env.include_if_absent_pip(["peft==0.5.0", "vllm==0.2.1.post1"]) - - model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION) - - @classmethod - def load_model( - cls, - name: str, - model_meta: model_meta_api.ModelMetadata, - model_blobs_dir_path: str, - **kwargs: Unpack[model_types.LLMLoadOptions], - ) -> llm.LLM: - model_blob_path = os.path.join(model_blobs_dir_path, name) - if not hasattr(model_meta, "models"): - raise ValueError("Ill model metadata found.") - model_blobs_metadata = model_meta.models - if name not in model_blobs_metadata: - raise ValueError(f"Blob of model {name} does not exist.") - model_blob_metadata = model_blobs_metadata[name] - model_blob_filename = model_blob_metadata.path - model_blob_dir_path = os.path.join(model_blob_path, model_blob_filename) - assert model_blob_dir_path, "It must be a directory." - with open(os.path.join(model_blob_dir_path, cls.LLM_META), "rb") as f: - m = cloudpickle.load(f) - assert isinstance(m, llm.LLM) - if m.mode == llm.LLM.Mode.LOCAL_LORA: - # Switch to local path - m.model_id_or_path = model_blob_dir_path - return m - - @classmethod - def convert_as_custom_model( - cls, - raw_model: llm.LLM, - model_meta: model_meta_api.ModelMetadata, - background_data: Optional[pd.DataFrame] = None, - **kwargs: Unpack[model_types.LLMLoadOptions], - ) -> custom_model.CustomModel: - import gc - import tempfile - - import torch - import transformers - import vllm - - assert torch.cuda.is_available(), "LLM inference only works on GPUs." 
- device_count = torch.cuda.device_count() - logger.warning(f"There's total {device_count} GPUs visible to use.") - - class _LLMCustomModel(custom_model.CustomModel): - def _memory_stats(self, msg: str) -> None: - logger.warning(msg) - logger.warning(f"Torch VRAM {torch.cuda.memory_allocated()/1024**2} MB allocated.") - logger.warning(f"Torch VRAM {torch.cuda.memory_reserved()/1024**2} MB reserved.") - - def _prepare_for_pretrain(self) -> None: - hub_kwargs = { - "revision": raw_model.revision, - "token": raw_model.token, - } - model_dir_path = raw_model.model_id_or_path - tokenizer = transformers.AutoTokenizer.from_pretrained( - model_dir_path, - padding_side="right", - use_fast=False, - **hub_kwargs, - ) - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - tokenizer.save_pretrained(self.local_model_dir) - hf_model = transformers.AutoModelForCausalLM.from_pretrained( - model_dir_path, - device_map="auto", - torch_dtype="auto", - **hub_kwargs, - ) - hf_model.eval() - hf_model.save_pretrained(self.local_model_dir) - logger.warning(f"Model state is saved to {self.local_model_dir}.") - del tokenizer - del hf_model - gc.collect() - torch.cuda.empty_cache() - self._memory_stats("After GC on model.") - - def _prepare_for_lora(self) -> None: - self._memory_stats("Before model load & merge.") - import peft - - hub_kwargs = { - "revision": raw_model.revision, - "token": raw_model.token, - } - model_dir_path = raw_model.model_id_or_path - peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined] - model_dir_path - ) - base_model_path = peft_config.base_model_name_or_path - tokenizer = transformers.AutoTokenizer.from_pretrained( - base_model_path, - padding_side="right", - use_fast=False, - **hub_kwargs, - ) - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - tokenizer.save_pretrained(self.local_model_dir) - logger.warning(f"Tokenizer state is saved to {self.local_model_dir}.") - hf_model = 
peft.AutoPeftModelForCausalLM.from_pretrained( # type: ignore[attr-defined] - model_dir_path, - device_map="auto", - torch_dtype="auto", - **hub_kwargs, # type: ignore[arg-type] - ) - hf_model.eval() - hf_model = hf_model.merge_and_unload() - hf_model.save_pretrained(self.local_model_dir) - logger.warning(f"Merged model state is saved to {self.local_model_dir}.") - self._memory_stats("After model load & merge.") - del hf_model - gc.collect() - torch.cuda.empty_cache() - self._memory_stats("After GC on model.") - - def __init__(self, context: custom_model.ModelContext) -> None: - self.local_tmp_holder = tempfile.TemporaryDirectory() - self.local_model_dir = self.local_tmp_holder.name - if raw_model.mode == llm.LLM.Mode.LOCAL_LORA: - self._prepare_for_lora() - elif raw_model.mode == llm.LLM.Mode.REMOTE_PRETRAIN: - self._prepare_for_pretrain() - self.sampling_params = vllm.SamplingParams( - temperature=raw_model.temperature, - top_p=raw_model.top_p, - max_tokens=raw_model.max_tokens, - ) - self._init_engine() - - # This has to have same lifetime as main thread - # in order to avoid pre-maturely terminate ray. 
- def _init_engine(self) -> None: - tp_size = torch.cuda.device_count() if raw_model.enable_tp else 1 - self.llm_engine = vllm.LLM( - model=self.local_model_dir, - tensor_parallel_size=tp_size, - ) - - @custom_model.inference_api - def infer(self, X: pd.DataFrame) -> pd.DataFrame: - input_data = X.to_dict("list")["input"] - res = self.llm_engine.generate(input_data, self.sampling_params) - return pd.DataFrame({"generated_text": [o.outputs[0].text for o in res]}) - - llm_custom = _LLMCustomModel(custom_model.ModelContext()) - - return llm_custom diff --git a/snowflake/ml/model/_packager/model_handlers/mlflow.py b/snowflake/ml/model/_packager/model_handlers/mlflow.py index 57118f71..4620fefa 100644 --- a/snowflake/ml/model/_packager/model_handlers/mlflow.py +++ b/snowflake/ml/model/_packager/model_handlers/mlflow.py @@ -168,11 +168,6 @@ def load_model( ) -> "mlflow.pyfunc.PyFuncModel": import mlflow - if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] - # We need to redirect the mlruns folder to a writable location in the sandbox. - tmpdir = tempfile.TemporaryDirectory(dir="/tmp") - mlflow.set_tracking_uri(f"file://{tmpdir}") - model_blob_path = os.path.join(model_blobs_dir_path, name) model_blobs_metadata = model_meta.models model_blob_metadata = model_blobs_metadata[name] @@ -183,6 +178,9 @@ def load_model( model_artifact_path = model_blob_options["artifact_path"] model_blob_filename = model_blob_metadata.path + if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] + return mlflow.pyfunc.load_model(os.path.join(model_blob_path, model_blob_filename, model_artifact_path)) + # This is to make sure the loaded model can be saved again. 
with mlflow.start_run() as run: mlflow.log_artifacts( diff --git a/snowflake/ml/model/_packager/model_handlers/model_objective_utils.py b/snowflake/ml/model/_packager/model_handlers/model_objective_utils.py index ad20c45c..95572a0f 100644 --- a/snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +++ b/snowflake/ml/model/_packager/model_handlers/model_objective_utils.py @@ -2,23 +2,67 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Union +from snowflake.ml._internal import type_utils from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils if TYPE_CHECKING: + import catboost import lightgbm + import sklearn + import sklearn.pipeline import xgboost @dataclass -class ModelObjectiveAndOutputType: - objective: type_hints.ModelObjective +class ModelTaskAndOutputType: + task: type_hints.Task output_type: model_signature.DataType -def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.ModelObjective: +def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]) -> type_hints.Task: + from sklearn.base import is_classifier, is_regressor + + if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model): + return type_hints.Task.UNKNOWN + if is_regressor(model): + return type_hints.Task.TABULAR_REGRESSION + if is_classifier(model): + classes_list = getattr(model, "classes_", []) + num_classes = getattr(model, "n_classes_", None) or len(classes_list) + if isinstance(num_classes, int): + if num_classes > 2: + return type_hints.Task.TABULAR_MULTI_CLASSIFICATION + return type_hints.Task.TABULAR_BINARY_CLASSIFICATION + return type_hints.Task.UNKNOWN + return type_hints.Task.UNKNOWN + + +def get_model_task_catboost(model: "catboost.CatBoost") -> type_hints.Task: + loss_function = None + if type_utils.LazyType("catboost.CatBoost").isinstance(model): + loss_function = 
model.get_all_params()["loss_function"] # type: ignore[attr-defined] + + if (type_utils.LazyType("catboost.CatBoostClassifier").isinstance(model)) or model._is_classification_objective( + loss_function + ): + num_classes = handlers_utils.get_num_classes_if_exists(model) + if num_classes == 0: + return type_hints.Task.UNKNOWN + if num_classes <= 2: + return type_hints.Task.TABULAR_BINARY_CLASSIFICATION + return type_hints.Task.TABULAR_MULTI_CLASSIFICATION + if (type_utils.LazyType("catboost.CatBoostRanker").isinstance(model)) or model._is_ranking_objective(loss_function): + return type_hints.Task.TABULAR_RANKING + if (type_utils.LazyType("catboost.CatBoostRegressor").isinstance(model)) or model._is_regression_objective( + loss_function + ): + return type_hints.Task.TABULAR_REGRESSION - import lightgbm + return type_hints.Task.UNKNOWN + + +def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.Task: _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"] _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"] @@ -36,81 +80,90 @@ def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBM ] # does not account for cross-entropy and custom - if isinstance(model, lightgbm.LGBMClassifier): - num_classes = handlers_utils.get_num_classes_if_exists(model) - if num_classes == 2: - return type_hints.ModelObjective.BINARY_CLASSIFICATION - return type_hints.ModelObjective.MULTI_CLASSIFICATION - if isinstance(model, lightgbm.LGBMRanker): - return type_hints.ModelObjective.RANKING - if isinstance(model, lightgbm.LGBMRegressor): - return type_hints.ModelObjective.REGRESSION - model_objective = model.params["objective"] - if model_objective in _BINARY_CLASSIFICATION_OBJECTIVES: - return type_hints.ModelObjective.BINARY_CLASSIFICATION - if model_objective in _MULTI_CLASSIFICATION_OBJECTIVES: - return type_hints.ModelObjective.MULTI_CLASSIFICATION - if model_objective in _RANKING_OBJECTIVES: - return 
type_hints.ModelObjective.RANKING - if model_objective in _REGRESSION_OBJECTIVES: - return type_hints.ModelObjective.REGRESSION - return type_hints.ModelObjective.UNKNOWN - - -def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective: - - import xgboost + model_task = "" + if type_utils.LazyType("lightgbm.Booster").isinstance(model): + model_task = model.params["objective"] # type: ignore[attr-defined] + elif hasattr(model, "objective_"): + model_task = model.objective_ + if model_task in _BINARY_CLASSIFICATION_OBJECTIVES: + return type_hints.Task.TABULAR_BINARY_CLASSIFICATION + if model_task in _MULTI_CLASSIFICATION_OBJECTIVES: + return type_hints.Task.TABULAR_MULTI_CLASSIFICATION + if model_task in _RANKING_OBJECTIVES: + return type_hints.Task.TABULAR_RANKING + if model_task in _REGRESSION_OBJECTIVES: + return type_hints.Task.TABULAR_REGRESSION + return type_hints.Task.UNKNOWN + + +def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.Task: _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"] _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"] _RANKING_OBJECTIVE_PREFIX = ["rank:"] _REGRESSION_OBJECTIVE_PREFIX = ["reg:"] - model_objective = "" - if isinstance(model, xgboost.Booster): - model_params = json.loads(model.save_config()) - model_objective = model_params.get("learner", {}).get("objective", "") + model_task = "" + if type_utils.LazyType("xgboost.Booster").isinstance(model): + model_params = json.loads(model.save_config()) # type: ignore[attr-defined] + model_task = model_params.get("learner", {}).get("objective", "") else: if hasattr(model, "get_params"): - model_objective = model.get_params().get("objective", "") + model_task = model.get_params().get("objective", "") - if isinstance(model_objective, dict): - model_objective = model_objective.get("name", "") + if isinstance(model_task, dict): + model_task = model_task.get("name", "") for classification_objective in 
_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX: - if classification_objective in model_objective: - return type_hints.ModelObjective.BINARY_CLASSIFICATION + if classification_objective in model_task: + return type_hints.Task.TABULAR_BINARY_CLASSIFICATION for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX: - if classification_objective in model_objective: - return type_hints.ModelObjective.MULTI_CLASSIFICATION + if classification_objective in model_task: + return type_hints.Task.TABULAR_MULTI_CLASSIFICATION for ranking_objective in _RANKING_OBJECTIVE_PREFIX: - if ranking_objective in model_objective: - return type_hints.ModelObjective.RANKING + if ranking_objective in model_task: + return type_hints.Task.TABULAR_RANKING for regression_objective in _REGRESSION_OBJECTIVE_PREFIX: - if regression_objective in model_objective: - return type_hints.ModelObjective.REGRESSION - return type_hints.ModelObjective.UNKNOWN + if regression_objective in model_task: + return type_hints.Task.TABULAR_REGRESSION + return type_hints.Task.UNKNOWN -def get_model_objective_and_output_type(model: Any) -> ModelObjectiveAndOutputType: - import xgboost +def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType: + if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance( + model + ): + task = get_model_task_xgb(model) + output_type = model_signature.DataType.DOUBLE + if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION: + output_type = model_signature.DataType.STRING + return ModelTaskAndOutputType(task=task, output_type=output_type) - if isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel): - model_objective = get_model_objective_xgb(model) + if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType( + "lightgbm.LGBMModel" + ).isinstance(model): + task = get_model_task_lightgbm(model) output_type = model_signature.DataType.DOUBLE - if model_objective == 
type_hints.ModelObjective.MULTI_CLASSIFICATION: + if task in [ + type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + type_hints.Task.TABULAR_MULTI_CLASSIFICATION, + ]: output_type = model_signature.DataType.STRING - return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type) + return ModelTaskAndOutputType(task=task, output_type=output_type) - import lightgbm + if type_utils.LazyType("catboost.CatBoost").isinstance(model): + task = get_model_task_catboost(model) + output_type = model_signature.DataType.DOUBLE + if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION: + output_type = model_signature.DataType.STRING + return ModelTaskAndOutputType(task=task, output_type=output_type) - if isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel): - model_objective = get_model_objective_lightgbm(model) + if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType( + "sklearn.pipeline.Pipeline" + ).isinstance(model): + task = get_task_skl(model) output_type = model_signature.DataType.DOUBLE - if model_objective in [ - type_hints.ModelObjective.BINARY_CLASSIFICATION, - type_hints.ModelObjective.MULTI_CLASSIFICATION, - ]: + if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION: output_type = model_signature.DataType.STRING - return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type) + return ModelTaskAndOutputType(task=task, output_type=output_type) raise ValueError(f"Model type {type(model)} is not supported") diff --git a/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py b/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py index aa3f6348..43212ab4 100644 --- a/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +++ b/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py @@ -2,7 +2,6 @@ import os from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final -import cloudpickle 
import pandas as pd from typing_extensions import TypeGuard, Unpack @@ -120,9 +119,21 @@ def get_prediction( model_meta.env.include_if_absent( [ model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"), + model_env.ModelDependency(requirement="transformers", pip_name="transformers"), + model_env.ModelDependency(requirement="pytorch", pip_name="torch"), ], check_local_version=True, ) + model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION) + + @staticmethod + def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]: + if kwargs.get("device", None) is not None: + return kwargs["device"] + elif kwargs.get("use_gpu", False): + return "cuda" + + return None @classmethod def load_model( @@ -144,13 +155,9 @@ def load_model( model_blob_filename = model_blob_metadata.path model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename) - if os.path.isdir(model_blob_file_or_dir_path): # if the saved model is a directory - model = sentence_transformers.SentenceTransformer(model_blob_file_or_dir_path) - else: - assert os.path.isfile(model_blob_file_or_dir_path) # if the saved model is a file - with open(model_blob_file_or_dir_path, "rb") as f: - model = cloudpickle.load(f) - assert isinstance(model, sentence_transformers.SentenceTransformer) + model = sentence_transformers.SentenceTransformer( + model_blob_file_or_dir_path, device=cls._get_device_config(**kwargs) + ) return model @classmethod diff --git a/snowflake/ml/model/_packager/model_handlers/sklearn.py b/snowflake/ml/model/_packager/model_handlers/sklearn.py index 1c9e2f11..70416372 100644 --- a/snowflake/ml/model/_packager/model_handlers/sklearn.py +++ b/snowflake/ml/model/_packager/model_handlers/sklearn.py @@ -1,4 +1,5 @@ import os +import warnings from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final import cloudpickle @@ -6,22 +7,21 @@ import pandas 
as pd from typing_extensions import TypeGuard, Unpack -import snowflake.snowpark.dataframe as sp_df from snowflake.ml._internal import type_utils from snowflake.ml.model import custom_model, model_signature, type_hints as model_types from snowflake.ml.model._packager.model_env import model_env -from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils +from snowflake.ml.model._packager.model_handlers import ( + _base, + _utils as handlers_utils, + model_objective_utils, +) from snowflake.ml.model._packager.model_handlers_migrator import base_migrator from snowflake.ml.model._packager.model_meta import ( model_blob_meta, model_meta as model_meta_api, model_meta_schema, ) -from snowflake.ml.model._signatures import ( - numpy_handler, - snowpark_handler, - utils as model_signature_utils, -) +from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils if TYPE_CHECKING: import sklearn.base @@ -40,28 +40,14 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", _MIN_SNOWPARK_ML_VERSION = "1.0.12" _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {} - DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"] - - @classmethod - def get_model_objective( - cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"] - ) -> model_types.ModelObjective: - import sklearn.pipeline - from sklearn.base import is_classifier, is_regressor - - if isinstance(model, sklearn.pipeline.Pipeline): - return model_types.ModelObjective.UNKNOWN - if is_regressor(model): - return model_types.ModelObjective.REGRESSION - if is_classifier(model): - classes_list = getattr(model, "classes_", []) - num_classes = getattr(model, "n_classes_", None) or len(classes_list) - if isinstance(num_classes, int): - if num_classes > 2: - return model_types.ModelObjective.MULTI_CLASSIFICATION - return 
model_types.ModelObjective.BINARY_CLASSIFICATION - return model_types.ModelObjective.UNKNOWN - return model_types.ModelObjective.UNKNOWN + DEFAULT_TARGET_METHODS = [ + "predict", + "transform", + "predict_proba", + "predict_log_proba", + "decision_function", + ] + EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"] @classmethod def can_handle( @@ -95,18 +81,6 @@ def cast_model( return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model) - @staticmethod - def get_explainability_supported_background( - sample_input_data: Optional[model_types.SupportedDataType] = None, - ) -> Optional[pd.DataFrame]: - if isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame): - return ( - sample_input_data - if isinstance(sample_input_data, pd.DataFrame) - else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data) - ) - return None - @classmethod def save_model( cls, @@ -125,23 +99,10 @@ def save_model( import sklearn.pipeline assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline) - - background_data = cls.get_explainability_supported_background(sample_input_data) - - # if users did not ask then we enable if we have background data - if enable_explainability is None and background_data is not None: - enable_explainability = True if enable_explainability: - # if users set it explicitly but no background data then error out - if background_data is None: - raise ValueError( - "Sample input data is required to enable explainability. Currently we only support this for " - + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`." 
- ) - data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR) - os.makedirs(data_blob_path, exist_ok=True) - with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f: - background_data.to_parquet(f) + # if users set it explicitly but no sample_input_data then error out + if sample_input_data is None: + raise ValueError("Sample input data is required to enable explainability.") if not is_sub_model: target_methods = handlers_utils.get_target_methods( @@ -151,7 +112,8 @@ def save_model( ) def get_prediction( - target_method_name: str, sample_input_data: model_types.SupportedLocalDataType + target_method_name: str, + sample_input_data: model_types.SupportedLocalDataType, ) -> model_types.SupportedLocalDataType: if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)): sample_input_data = model_signature._convert_local_data_to_df(sample_input_data) @@ -169,19 +131,40 @@ def get_prediction( get_prediction_fn=get_prediction, ) - model_objective = cls.get_model_objective(model) - model_meta.model_objective = model_objective + explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS) + background_data = handlers_utils.get_explainability_supported_background( + sample_input_data, model_meta, explain_target_method + ) + + model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(model) + model_meta.task = model_task_and_output_type.task + + # if users did not ask then we enable if we have background data + if enable_explainability is None: + if background_data is None: + warnings.warn( + "sample_input_data should be provided to enable explainability by default", + category=UserWarning, + stacklevel=1, + ) + enable_explainability = False + else: + enable_explainability = True if enable_explainability: - output_type = model_signature.DataType.DOUBLE + handlers_utils.save_background_data( + model_blobs_dir_path, + cls.EXPLAIN_ARTIFACTS_DIR, + 
cls.BG_DATA_FILE_SUFFIX, + name, + background_data, + ) - if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION: - output_type = model_signature.DataType.STRING model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, explain_method="explain", - target_method="predict", - output_return_type=output_type, + target_method=explain_target_method, + output_return_type=model_task_and_output_type.output_type, ) model_blob_path = os.path.join(model_blobs_dir_path, name) @@ -202,7 +185,8 @@ def get_prediction( model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP model_meta.env.include_if_absent( - [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True + [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], + check_local_version=True, ) @classmethod diff --git a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py index d6d51fb7..36c51365 100644 --- a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +++ b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py @@ -43,6 +43,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]): _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {} DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"] + EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"] + IS_AUTO_SIGNATURE = True @classmethod @@ -71,13 +73,14 @@ def cast_model( @classmethod def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]: - import importlib_metadata + from importlib import metadata as importlib_metadata + from packaging import version local_version = None try: - local_dist = importlib_metadata.distribution(pkg_name) # type: ignore[no-untyped-call] + local_dist = 
importlib_metadata.distribution(pkg_name) local_version = version.parse(local_dist.version) except importlib_metadata.PackageNotFoundError: pass @@ -104,7 +107,13 @@ def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool: def _get_supported_object_for_explainability( cls, estimator: "BaseEstimator", enable_explainability: Optional[bool] ) -> Any: - methods = ["to_xgboost", "to_lightgbm"] + from snowflake.ml.modeling import pipeline as snowml_pipeline + + # handle pipeline objects separately + if isinstance(estimator, snowml_pipeline.Pipeline): # type: ignore[attr-defined] + return None + + methods = ["to_xgboost", "to_lightgbm", "to_sklearn"] for method_name in methods: if hasattr(estimator, method_name): try: @@ -136,9 +145,9 @@ def save_model( # Pipeline is inherited from BaseEstimator, so no need to add one more check if not is_sub_model: - if sample_input_data is not None or model_meta.signatures: + if model_meta.signatures: warnings.warn( - "Inferring model signature from sample input or providing model signature for Snowpark ML " + "Providing model signature for Snowpark ML " + "Modeling model is not required. Model signature will automatically be inferred during fitting. ", UserWarning, stacklevel=2, @@ -162,22 +171,31 @@ def save_model( python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability) if python_base_obj is None: if enable_explainability: # if user set enable_explainability to True, throw error else silently skip - raise ValueError("Explain only support for xgboost or lightgbm Snowpark ML models.") + raise ValueError( + "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models." 
+ ) # set None to False so we don't include shap in the environment enable_explainability = False else: - model_objective_and_output_type = model_objective_utils.get_model_objective_and_output_type( - python_base_obj - ) - model_meta.model_objective = model_objective_and_output_type.objective + model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(python_base_obj) + model_meta.task = model_task_and_output_type.task + explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS) model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, explain_method="explain", - target_method="predict", - output_return_type=model_objective_and_output_type.output_type, + target_method=explain_target_method, + output_return_type=model_task_and_output_type.output_type, ) enable_explainability = True + background_data = handlers_utils.get_explainability_supported_background( + sample_input_data, model_meta, explain_target_method + ) + if background_data is not None: + handlers_utils.save_background_data( + model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data + ) + model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f: @@ -258,6 +276,7 @@ def fn_factory( raw_model: "BaseEstimator", signature: model_signature.ModelSignature, target_method: str, + background_data: Optional[pd.DataFrame] = None, ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]: @custom_model.inference_api def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: @@ -276,16 +295,16 @@ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: import shap - methods = ["to_xgboost", "to_lightgbm"] + methods = ["to_xgboost", 
"to_lightgbm", "to_sklearn"] for method_name in methods: try: base_model = getattr(raw_model, method_name)() - explainer = shap.TreeExplainer(base_model) - df = pd.DataFrame(explainer(X).values) + explainer = shap.Explainer(base_model, masker=background_data) + df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values) return model_signature_utils.rename_pandas_df(df, signature.outputs) except exceptions.SnowflakeMLException: pass # Do nothing and continue to the next method - raise ValueError("The model must be an xgboost or lightgbm estimator.") + raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.") if target_method == "explain": return explain_fn @@ -294,7 +313,7 @@ def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: type_method_dict = {} for target_method_name, sig in model_meta.signatures.items(): - type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name) + type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data) _SnowMLModel = type( "_SnowMLModel", diff --git a/snowflake/ml/model/_packager/model_handlers/xgboost.py b/snowflake/ml/model/_packager/model_handlers/xgboost.py index 145ddc9f..2ac006ba 100644 --- a/snowflake/ml/model/_packager/model_handlers/xgboost.py +++ b/snowflake/ml/model/_packager/model_handlers/xgboost.py @@ -1,6 +1,7 @@ # mypy: disable-error-code="import" import os import warnings +from importlib import metadata as importlib_metadata from typing import ( TYPE_CHECKING, Any, @@ -13,7 +14,6 @@ final, ) -import importlib_metadata import numpy as np import pandas as pd from packaging import version @@ -53,6 +53,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X MODEL_BLOB_FILE_OR_DIR = "model.ubj" DEFAULT_TARGET_METHODS = ["predict", "predict_proba"] + EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"] @classmethod def can_handle( @@ -96,7 
+97,7 @@ def save_model( local_xgb_version = None try: - local_dist = importlib_metadata.distribution("xgboost") # type: ignore[no-untyped-call] + local_dist = importlib_metadata.distribution("xgboost") local_xgb_version = version.parse(local_dist.version) except importlib_metadata.PackageNotFoundError: pass @@ -138,21 +139,35 @@ def get_prediction( sample_input_data=sample_input_data, get_prediction_fn=get_prediction, ) - model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model) - model_meta.model_objective = handlers_utils.validate_model_objective( - model_meta.model_objective, model_objective_and_output.objective - ) + model_task_and_output = model_objective_utils.get_model_task_and_output_type(model) + model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task) if enable_explainability: model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, explain_method="explain", target_method="predict", - output_return_type=model_objective_and_output.output_type, + output_return_type=model_task_and_output.output_type, ) model_meta.function_properties = { "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False} } + explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS) + + background_data = handlers_utils.get_explainability_supported_background( + sample_input_data, model_meta, explain_target_method + ) + if background_data is not None: + handlers_utils.save_background_data( + model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data + ) + else: + warnings.warn( + "sample_input_data should be provided for better explainability results", + category=UserWarning, + stacklevel=1, + ) + model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)) diff --git 
a/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel b/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel index 1e8de101..54f1e6e1 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel @@ -18,6 +18,7 @@ py_test( "//snowflake/ml/model/_packager/model_handlers:_utils", "//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_signatures:snowpark_handler", + "//snowflake/ml/test_utils:exception_utils", ], ) diff --git a/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py b/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py index df627ab6..25886d62 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py @@ -1,6 +1,8 @@ import json +from typing import cast from unittest import mock +import catboost import numpy as np import pandas as pd from absl.testing import absltest @@ -9,6 +11,7 @@ from snowflake.ml.model._packager.model_env import model_env from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils from snowflake.ml.model._packager.model_meta import model_meta +from snowflake.ml.test_utils import exception_utils class UtilTest(absltest.TestCase): @@ -106,38 +109,140 @@ def test_convert_explanations_to_2D_df_multi_value_no_class_attr(self) -> None: ) pd.testing.assert_frame_equal(explanations_df, expected_df) - def test_validate_model_objective(self) -> None: + def test_validate_model_task(self) -> None: - model_objective_list = list(type_hints.ModelObjective) - for model_objective in model_objective_list: - for inferred_model_objective in model_objective_list: - expected_model_objective = ( - inferred_model_objective - if inferred_model_objective != type_hints.ModelObjective.UNKNOWN - else model_objective - ) + task_list = list(type_hints.Task) + for task in task_list: + for inferred_task in task_list: + 
expected_task = inferred_task if inferred_task != type_hints.Task.UNKNOWN else task self.assertEqual( - expected_model_objective, - handlers_utils.validate_model_objective(model_objective, inferred_model_objective), + expected_task, + handlers_utils.validate_model_task(task, inferred_task), ) - if inferred_model_objective != type_hints.ModelObjective.UNKNOWN: - if model_objective == type_hints.ModelObjective.UNKNOWN: + if inferred_task != type_hints.Task.UNKNOWN: + if task == type_hints.Task.UNKNOWN: with self.assertLogs(level="INFO") as cm: - handlers_utils.validate_model_objective(model_objective, inferred_model_objective) + handlers_utils.validate_model_task(task, inferred_task) assert len(cm.output) == 1, "expecting only 1 log" log = cm.output[0] self.assertEqual( - f"INFO:absl:Inferred ModelObjective: {inferred_model_objective.name} is used as model " - f"objective for this model version", + f"INFO:absl:Inferred Task: {inferred_task.name} is used as " + f"task for this model version", log, ) - elif inferred_model_objective != model_objective: + elif inferred_task != task: with self.assertWarnsRegex( UserWarning, - f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for " - f"this model version and passed argument ModelObjective: {model_objective.name} is ignored", + f"Inferred Task: {inferred_task.name} is used as task for " + f"this model version and passed argument Task: {task.name} is ignored", ): - handlers_utils.validate_model_objective(model_objective, inferred_model_objective) + handlers_utils.validate_model_task(task, inferred_task) + + def test_validate_signature_with_signature_and_sample_data(self) -> None: + predict_sig = model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(dtype=model_signature.DataType.DOUBLE, name="feature1"), + model_signature.FeatureSpec(dtype=model_signature.DataType.UINT16, name="feature2"), + ], + outputs=[model_signature.FeatureSpec(dtype=model_signature.DataType.UINT16, 
name="output1")], + ) + + meta = model_meta.ModelMetadata( + name="name", env=model_env.ModelEnv(), model_type="custom", signatures={"predict": predict_sig} + ) + sample_data = pd.DataFrame({"feature1": [10, 20, 30], "feature2": [10, 20, 30]}) + model = catboost.CatBoostRegressor() + predict_fun = mock.MagicMock() + + # check the function is called Once only for inputs + with mock.patch( + "snowflake.ml.model.model_signature._convert_and_validate_local_data" + ) as mock_validate_local_data: + meta = handlers_utils.validate_signature(model, meta, ["predict"], sample_data, predict_fun) + mock_validate_local_data.assert_called_once() + + # comparing unsigned signature against signed values throws error + with exception_utils.assert_snowml_exceptions( + self, + expected_original_error_type=ValueError, + expected_regex="Feature type [^\\s]* is not met by all elements", + ): + sample_data = pd.DataFrame({"feature1": [10, 20, 30], "feature2": [-10, 20, 30]}) + meta = handlers_utils.validate_signature(model, meta, ["predict"], sample_data, predict_fun) + + def test_validate_signature_with_only_signature(self) -> None: + predict_sig = model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(dtype=model_signature.DataType.DOUBLE, name="feature1"), + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="feature2"), + ], + outputs=[model_signature.FeatureSpec(dtype=model_signature.DataType.UINT16, name="output1")], + ) + predict_fun = mock.MagicMock() + meta = model_meta.ModelMetadata( + name="name", env=model_env.ModelEnv(), model_type="custom", signatures={"predict": predict_sig} + ) + model = catboost.CatBoostRegressor() + + # test with correct signature + with mock.patch( + "snowflake.ml.model._packager.model_handlers._utils.validate_target_methods" + ) as mock_validate_target_methods: + handlers_utils.validate_signature(model, meta, [], None, predict_fun) + mock_validate_target_methods.assert_called_once() + + # test with wrong signature. 
'predict_not_callable' is not callable from model + meta = model_meta.ModelMetadata( + name="name", env=model_env.ModelEnv(), model_type="custom", signatures={"predict_not_callable": predict_sig} + ) + with self.assertRaisesRegex( + ValueError, "Target method predict_not_callable is not callable or does not exist in the model." + ): + handlers_utils.validate_signature(model, meta, [], None, predict_fun) + + def test_validate_signature_with_only_sample_data(self) -> None: + # metadata with no signatures + meta = model_meta.ModelMetadata(name="name", env=model_env.ModelEnv(), model_type="custom") + + # sample input data + data = {"feature1": [10, 20, 30], "feature2": [10.0, 20.0, 30.0], "feature3": ["a", "b", "c"]} + sample_data = pd.DataFrame(data) + + model = catboost.CatBoostRegressor() + + # mocking predict function calls + predict_fun = mock.MagicMock() + predict_fun.side_effect = lambda x, y: pd.DataFrame({"output1": [1, 2, 3]}) + + meta = handlers_utils.validate_signature(model, meta, ["predict"], sample_data, predict_fun) + self.assertEqual( + 1, + len(meta.signatures.keys()), + ) + predict_sig = model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(dtype=model_signature.DataType.INT64, name="feature1"), + model_signature.FeatureSpec(dtype=model_signature.DataType.DOUBLE, name="feature2"), + model_signature.FeatureSpec(dtype=model_signature.DataType.STRING, name="feature3"), + ], + outputs=[model_signature.FeatureSpec(dtype=model_signature.DataType.INT64, name="output1")], + ) + self.assertEqual(predict_sig, meta.signatures.get("predict")) + + def test_get_truncated_sample_data(self) -> None: + # the data is truncated w.r.t SIG_INFER_ROWS_COUNT_LIMIT + + # when data_size > 10 rows + df = pd.DataFrame(np.random.randint(0, 100, size=(100, 3))) + self.assertEqual(10, cast(pd.DataFrame, handlers_utils.get_truncated_sample_data(df)).shape[0]) + + # when data_size = 10 rows + df = pd.DataFrame(np.random.randint(0, 100, size=(10, 3))) + 
self.assertEqual(10, cast(pd.DataFrame, handlers_utils.get_truncated_sample_data(df)).shape[0]) + + # when data_size < 10 rows + df = pd.DataFrame(np.random.randint(0, 100, size=(5, 3))) + self.assertEqual(5, cast(pd.DataFrame, handlers_utils.get_truncated_sample_data(df)).shape[0]) if __name__ == "__main__": diff --git a/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py b/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py index e0aae3b3..94fac2bc 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py @@ -1,6 +1,7 @@ import os import tempfile import warnings +from unittest import mock import catboost import numpy as np @@ -11,7 +12,6 @@ from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager import model_packager -from snowflake.ml.model._packager.model_handlers import catboost as catboost_handler from snowflake.ml.model._packager.model_handlers_test import test_utils @@ -107,12 +107,16 @@ def test_catboost_explainablity_enabled(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: s = {"predict": model_signature.infer_signature(cal_X_test, y_pred)} - model_packager.ModelPackager(os.path.join(tmpdir, "model1_default_explain")).save( - name="model1_default_explain", - model=classifier, - signatures=s, - metadata={"author": "halu", "version": "1"}, - ) + # check for warnings if sample_input_data is not provided while saving the model + with self.assertWarnsRegex( + UserWarning, "sample_input_data should be provided for better explainability results" + ): + model_packager.ModelPackager(os.path.join(tmpdir, "model1_default_explain")).save( + name="model1_default_explain", + model=classifier, + signatures=s, + metadata={"author": "halu", "version": "1"}, + ) with warnings.catch_warnings(): warnings.simplefilter("error") @@ -145,13 +149,18 @@ def test_catboost_explainablity_enabled(self) 
-> None: assert callable(explain_method) np.testing.assert_allclose(explain_method(cal_X_test), explanations) - model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig_explain_enabled")).save( - name="model1_no_sig_explain_enabled", - model=classifier, - sample_input_data=cal_X_test, - metadata={"author": "halu", "version": "1"}, - options=model_types.CatBoostModelSaveOptions(enable_explainability=True), - ) + # test calling saving background_data when sample_input_data is present + with mock.patch( + "snowflake.ml.model._packager.model_handlers._utils.save_background_data" + ) as save_background_data: + model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig_explain_enabled")).save( + name="model1_no_sig_explain_enabled", + model=classifier, + sample_input_data=cal_X_test, + metadata={"author": "halu", "version": "1"}, + options=model_types.CatBoostModelSaveOptions(enable_explainability=True), + ) + save_background_data.assert_called_once() pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig_explain_enabled")) pk.load(as_custom_model=True) explain_method = getattr(pk.model, "explain", None) @@ -215,40 +224,6 @@ def test_catboost_multiclass_explainablity_enabled(self) -> None: test_utils.convert2D_json_to_3D(explain_method(cal_X_test).to_numpy()), explanations ) - def test_model_objective_catboost_binary_classifier(self) -> None: - cal_data = datasets.load_breast_cancer() - cal_X = pd.DataFrame(cal_data.data, columns=cal_data.feature_names) - cal_y = pd.Series(cal_data.target) - catboost_binary_classifier = catboost.CatBoostClassifier() - catboost_binary_classifier.fit(cal_X, cal_y) - self.assertEqual( - model_types.ModelObjective.BINARY_CLASSIFICATION, - catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost_binary_classifier), - ) - - def test_model_objective_catboost_multi_classifier(self) -> None: - cal_data = datasets.load_iris() - cal_X = pd.DataFrame(cal_data.data, 
columns=cal_data.feature_names) - cal_y = pd.Series(cal_data.target) - catboost_multi_classifier = catboost.CatBoostClassifier() - catboost_multi_classifier.fit(cal_X, cal_y) - self.assertEqual( - model_types.ModelObjective.MULTI_CLASSIFICATION, - catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost_multi_classifier), - ) - - def test_model_objective_catboost_ranking(self) -> None: - self.assertEqual( - model_types.ModelObjective.RANKING, - catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost.CatBoostRanker()), - ) - - def test_model_objective_catboost_regressor(self) -> None: - self.assertEqual( - model_types.ModelObjective.REGRESSION, - catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost.CatBoostRegressor()), - ) - if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py b/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py index a3ad3ba8..c8fbdecb 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py @@ -1,6 +1,7 @@ import os import tempfile import warnings +from unittest import mock import lightgbm import numpy as np @@ -98,12 +99,16 @@ def test_lightgbm_booster_explainablity_enabled(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: s = {"predict": model_signature.infer_signature(cal_X_test, y_pred)} - model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( - name="model1", - model=regressor, - signatures=s, - metadata={"author": "halu", "version": "1"}, - ) + # check for warnings if sample_input_data is not provided while saving the model + with self.assertWarnsRegex( + UserWarning, "sample_input_data should be provided for better explainability results" + ): + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1", + model=regressor, + 
signatures=s, + metadata={"author": "halu", "version": "1"}, + ) with warnings.catch_warnings(): warnings.simplefilter("error") @@ -127,12 +132,17 @@ def test_lightgbm_booster_explainablity_enabled(self) -> None: test_utils.convert2D_json_to_3D(explain_method(cal_X_test).to_numpy()), explanations ) - model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( - name="model1_no_sig", - model=regressor, - sample_input_data=cal_X_test, - metadata={"author": "halu", "version": "1"}, - ) + # test calling saving background_data when sample_input_data is present + with mock.patch( + "snowflake.ml.model._packager.model_handlers._utils.save_background_data" + ) as save_background_data: + model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( + name="model1_no_sig", + model=regressor, + sample_input_data=cal_X_test, + metadata={"author": "halu", "version": "1"}, + ) + save_background_data.assert_called_once() pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) pk.load(as_custom_model=True) diff --git a/snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py b/snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py index d06b6fd0..324fb7e9 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py @@ -1,5 +1,7 @@ +from itertools import groupby from typing import Any +import catboost import lightgbm import numpy as np import pandas as pd @@ -13,113 +15,180 @@ binary_dataset = datasets.load_breast_cancer() binary_data_X = pd.DataFrame(binary_dataset.data, columns=binary_dataset.feature_names) binary_data_y = pd.Series(binary_dataset.target) +single_class_y = pd.Series([0] * len(binary_dataset.target)) multiclass_data = datasets.load_iris() multiclass_data_X = pd.DataFrame(multiclass_data.data, columns=multiclass_data.feature_names) multiclass_data_y = 
pd.Series(multiclass_data.target) +# Make a synthetic ranking dataset for demonstration +seed = 1994 +ranking_X, ranking_y = datasets.make_classification(random_state=seed) +rng = np.random.default_rng(seed) +n_query_groups = 3 +ranking_qid = rng.integers(0, n_query_groups, size=ranking_X.shape[0]) + +# Sort the inputs based on query index +sorted_idx = np.argsort(ranking_qid) +ranking_X = ranking_X[sorted_idx, :] +ranking_y = ranking_y[sorted_idx] +ranking_qid = ranking_qid[sorted_idx] + class ModelObjectiveUtilsTest(absltest.TestCase): - def _validate_model_objective_and_output( + def _validate_model_task_and_output( self, model: Any, - expected_objective: type_hints.ModelObjective, + expected_task: type_hints.Task, expected_output: model_signature.DataType, ) -> None: - model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model) - self.assertEqual(expected_objective, model_objective_and_output.objective) - self.assertEqual(expected_output, model_objective_and_output.output_type) + model_task_and_output = model_objective_utils.get_model_task_and_output_type(model) + self.assertEqual(expected_task, model_task_and_output.task) + self.assertEqual(expected_output, model_task_and_output.output_type) - def test_model_objective_and_output_xgb_binary_classifier(self) -> None: + def test_model_task_and_output_xgb_binary_classifier(self) -> None: classifier = xgboost.XGBClassifier(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) classifier.fit(binary_data_X, binary_data_y) - self._validate_model_objective_and_output( - classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE ) - def test_model_objective_and_output_xgb_for_single_class(self) -> None: - single_class_y = pd.Series([0] * len(binary_dataset.target)) + def 
test_model_task_and_output_xgb_for_single_class(self) -> None: # without objective classifier = xgboost.XGBClassifier() classifier.fit(binary_data_X, single_class_y) - self._validate_model_objective_and_output( - classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE ) # with binary objective classifier = xgboost.XGBClassifier(objective="binary:logistic") classifier.fit(binary_data_X, single_class_y) - self._validate_model_objective_and_output( - classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE ) # with multiclass objective params = {"objective": "multi:softmax", "num_class": 3} classifier = xgboost.XGBClassifier(**params) classifier.fit(binary_data_X, single_class_y) - self._validate_model_objective_and_output( - classifier, type_hints.ModelObjective.MULTI_CLASSIFICATION, model_signature.DataType.STRING + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_MULTI_CLASSIFICATION, model_signature.DataType.STRING ) - def test_model_objective_and_output_xgb_multiclass_classifier(self) -> None: + def test_model_task_and_output_xgb_multiclass_classifier(self) -> None: classifier = xgboost.XGBClassifier() classifier.fit(multiclass_data_X, multiclass_data_y) - self._validate_model_objective_and_output( - classifier, type_hints.ModelObjective.MULTI_CLASSIFICATION, model_signature.DataType.STRING + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_MULTI_CLASSIFICATION, model_signature.DataType.STRING ) - def test_model_objective_and_output_xgb_regressor(self) -> None: + def test_model_task_and_output_xgb_regressor(self) -> None: regressor = xgboost.XGBRegressor() 
regressor.fit(multiclass_data_X, multiclass_data_y) - self._validate_model_objective_and_output( - regressor, type_hints.ModelObjective.REGRESSION, model_signature.DataType.DOUBLE + self._validate_model_task_and_output( + regressor, type_hints.Task.TABULAR_REGRESSION, model_signature.DataType.DOUBLE ) - def test_model_objective_and_output_xgb_booster(self) -> None: + def test_model_task_and_output_xgb_booster(self) -> None: params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") booster = xgboost.train(params, xgboost.DMatrix(data=binary_data_X, label=binary_data_y)) - self._validate_model_objective_and_output( - booster, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE - ) - - def test_model_objective_and_output_xgb_ranker(self) -> None: - # Make a synthetic ranking dataset for demonstration - seed = 1994 - X, y = datasets.make_classification(random_state=seed) - rng = np.random.default_rng(seed) - n_query_groups = 3 - qid = rng.integers(0, n_query_groups, size=X.shape[0]) - - # Sort the inputs based on query index - sorted_idx = np.argsort(qid) - X = X[sorted_idx, :] - y = y[sorted_idx] - qid = qid[sorted_idx] + self._validate_model_task_and_output( + booster, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + ) + + def test_model_task_and_output_xgb_ranker(self) -> None: ranker = xgboost.XGBRanker( tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk" ) - ranker.fit(X, y, qid=qid) - self._validate_model_objective_and_output( - ranker, type_hints.ModelObjective.RANKING, model_signature.DataType.DOUBLE - ) + ranker.fit(ranking_X, ranking_y, qid=ranking_qid) + self._validate_model_task_and_output(ranker, type_hints.Task.TABULAR_RANKING, model_signature.DataType.DOUBLE) - def test_model_objective_and_output_lightgbm_classifier(self) -> None: + def test_model_task_and_output_lightgbm_classifier(self) -> None: 
classifier = lightgbm.LGBMClassifier() classifier.fit(binary_data_X, binary_data_y) - self._validate_model_objective_and_output( - classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.STRING + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.STRING + ) + + def test_model_task_and_output_lightgbm_for_single_class(self) -> None: + # without objective + classifier = lightgbm.LGBMClassifier() + classifier.fit(binary_data_X, single_class_y) + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.STRING + ) + # with binary objective + classifier = lightgbm.LGBMClassifier(objective="binary") + classifier.fit(binary_data_X, single_class_y) + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.STRING + ) + # with multiclass objective + classifier = lightgbm.LGBMClassifier(objective="multiclass", num_classes=3) + classifier.fit(binary_data_X, single_class_y) + self._validate_model_task_and_output( + classifier, type_hints.Task.TABULAR_MULTI_CLASSIFICATION, model_signature.DataType.STRING ) - def test_model_objective_and_output_lightgbm_booster(self) -> None: + def test_model_task_and_output_lightgbm_booster(self) -> None: booster = lightgbm.train({"objective": "binary"}, lightgbm.Dataset(binary_data_X, label=binary_data_y)) - self._validate_model_objective_and_output( - booster, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.STRING + self._validate_model_task_and_output( + booster, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.STRING + ) + + def test_model_task_and_output_lightgbm_regressor(self) -> None: + regressor = lightgbm.LGBMRegressor() + regressor.fit(multiclass_data_X, multiclass_data_y) + self._validate_model_task_and_output( + regressor, 
type_hints.Task.TABULAR_REGRESSION, model_signature.DataType.DOUBLE + ) + + def test_model_task_and_output_lightgbm_ranker(self) -> None: + ranker = lightgbm.LGBMRanker() + ranker.fit(ranking_X, ranking_y, group=[len(list(group)) for _, group in groupby(ranking_qid)]) + self._validate_model_task_and_output(ranker, type_hints.Task.TABULAR_RANKING, model_signature.DataType.DOUBLE) + + def test_model_task_catboost_binary_classifier(self) -> None: + classifier = catboost.CatBoostClassifier() + classifier.fit(binary_data_X, binary_data_y) + self._validate_model_task_and_output( + classifier, + type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + model_signature.DataType.DOUBLE, + ) + + def test_model_task_catboost_multi_classifier(self) -> None: + classifier = catboost.CatBoostClassifier() + classifier.fit(multiclass_data_X, multiclass_data_y) + self._validate_model_task_and_output( + classifier, + type_hints.Task.TABULAR_MULTI_CLASSIFICATION, + model_signature.DataType.STRING, + ) + + def test_model_task_catboost_ranking(self) -> None: + ranker = catboost.CatBoostRanker() + ranker.fit(ranking_X, ranking_y, group_id=ranking_qid) + self._validate_model_task_and_output( + ranker, + type_hints.Task.TABULAR_RANKING, + model_signature.DataType.DOUBLE, + ) + + def test_model_task_catboost_regressor(self) -> None: + regressor = catboost.CatBoostRegressor() + regressor.fit(multiclass_data_X, multiclass_data_y) + self._validate_model_task_and_output( + regressor, + type_hints.Task.TABULAR_REGRESSION, + model_signature.DataType.DOUBLE, ) - def test_model_objective_and_output_unknown_model(self) -> None: + def test_model_task_and_output_unknown_model(self) -> None: def unknown_model(x: int) -> int: return x + 1 with self.assertRaises(ValueError) as e: - model_objective_utils.get_model_objective_and_output_type(unknown_model) + model_objective_utils.get_model_task_and_output_type(unknown_model) self.assertEqual(str(e.exception), "Model type is not supported") diff --git 
a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py index 73c33bc9..a59f381d 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py @@ -1,6 +1,7 @@ import os import tempfile import warnings +from unittest import mock import numpy as np import pandas as pd @@ -115,8 +116,7 @@ def test_skl_unsupported_explain(self) -> None: s = {"predict_proba": model_signature.infer_signature(iris_X_df, model.predict_proba(iris_X_df))} with self.assertRaisesRegex( ValueError, - "Sample input data is required to enable explainability. Currently we only support this for " - + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`.", + "Sample input data is required to enable explainability.", ): model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", @@ -238,13 +238,13 @@ def test_skl_explain(self) -> None: regr = linear_model.LinearRegression() iris_X_df = pd.DataFrame(iris_X, columns=["c1", "c2", "c3", "c4"]) regr.fit(iris_X_df, iris_y) + explanations = shap.Explainer(regr, iris_X_df)(iris_X_df).values with tempfile.TemporaryDirectory() as tmpdir: s = {"predict": model_signature.infer_signature(iris_X_df, regr.predict(iris_X_df))} with self.assertRaisesRegex( ValueError, - "Sample input data is required to enable explainability. 
Currently we only support this for " - + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`.", + "Sample input data is required to enable explainability.", ): model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", @@ -254,6 +254,18 @@ def test_skl_explain(self) -> None: options=model_types.SKLModelSaveOptions(enable_explainability=True), ) + # test calling saving background_data when sample_input_data is present + with mock.patch( + "snowflake.ml.model._packager.model_handlers._utils.save_background_data" + ) as save_background_data: + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1_no_sig", + model=regr, + sample_input_data=iris_X_df, + metadata={"author": "halu", "version": "1"}, + ) + save_background_data.assert_called_once() + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1_no_sig", model=regr, @@ -275,6 +287,36 @@ def test_skl_explain(self) -> None: np.testing.assert_allclose(np.array([[-0.08254936]]), predict_method(iris_X_df[:1])) np.testing.assert_allclose(explain_method(iris_X_df), explanations) + def test_skl_explain_with_np(self) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + regr = linear_model.LinearRegression() + iris_X_df = pd.DataFrame(iris_X, columns=["c1", "c2", "c3", "c4"]) + regr.fit(iris_X_df, iris_y) + + explanations = shap.Explainer(regr, iris_X_df)(iris_X_df).values + with tempfile.TemporaryDirectory() as tmpdir: + + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1_no_sig", + model=regr, + sample_input_data=iris_X_df.values, + metadata={"author": "halu", "version": "1"}, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + + pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) + pk.load(as_custom_model=True) + assert pk.model + assert pk.meta + predict_method = getattr(pk.model, "predict", None) + explain_method = getattr(pk.model, "explain", None) + 
assert callable(predict_method) + assert callable(explain_method) + np.testing.assert_allclose(np.array([[-0.08254936]]), predict_method(iris_X_df[:1])) + np.testing.assert_allclose(explain_method(iris_X_df), explanations) + def test_skl_no_default_explain_without_background_data(self) -> None: iris_X, iris_y = datasets.load_iris(return_X_y=True) regr = linear_model.LinearRegression() diff --git a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py index 93b7a4e6..c0434346 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py @@ -33,13 +33,6 @@ def test_snowml_all_input_no_explain(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: s = {"predict": model_signature.infer_signature(df[INPUT_COLUMNS], regr.predict(df)[[OUTPUT_COLUMNS]])} - with self.assertRaisesRegex(ValueError, "Explain only support for xgboost or lightgbm Snowpark ML models."): - model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( - name="model1", - model=regr, - metadata={"author": "halu", "version": "1"}, - options={"enable_explainability": True}, - ) with self.assertWarnsRegex(UserWarning, "Model signature will automatically be inferred during fitting"): model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( @@ -49,14 +42,6 @@ def test_snowml_all_input_no_explain(self) -> None: metadata={"author": "halu", "version": "1"}, options={"enable_explainability": False}, ) - with self.assertWarnsRegex(UserWarning, "Model signature will automatically be inferred during fitting"): - model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( - name="model1_no_sig", - model=regr, - sample_input_data=df[INPUT_COLUMNS], - metadata={"author": "halu", "version": "1"}, - options={"enable_explainability": False}, - ) with tempfile.TemporaryDirectory() as tmpdir: 
model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( @@ -201,6 +186,43 @@ def test_snowml_xgboost_explain_default(self) -> None: np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) np.testing.assert_allclose(explanations, explain_method(df[INPUT_COLUMNS]).values) + def test_snowml_all_input_with_explain(self) -> None: + iris = datasets.load_iris() + + df = pd.DataFrame(data=np.c_[iris["data"], iris["target"]], columns=iris["feature_names"] + ["target"]) + df.columns = [s.replace(" (CM)", "").replace(" ", "") for s in df.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = LinearRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + regr.fit(df) + + predictions = regr.predict(df[:1])[[OUTPUT_COLUMNS]] + explanations = shap.Explainer(regr.to_sklearn(), df[INPUT_COLUMNS])(df[INPUT_COLUMNS]).values + + with tempfile.TemporaryDirectory() as tmpdir: + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1", + model=regr, + sample_input_data=df[INPUT_COLUMNS], + metadata={"author": "halu", "version": "1"}, + options={"enable_explainability": True}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + + pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) + pk.load(as_custom_model=True) + assert pk.model + assert pk.meta + predict_method = getattr(pk.model, "predict", None) + assert callable(predict_method) + np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) + explain_method = getattr(pk.model, "explain", None) + assert callable(explain_method) + np.testing.assert_allclose(explanations, explain_method(df[INPUT_COLUMNS]).values) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py 
b/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py index 49ad9363..fa8560a8 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py @@ -1,6 +1,7 @@ import os import tempfile import warnings +from unittest import mock import numpy as np import pandas as pd @@ -167,12 +168,16 @@ def test_xgb_explainablity_enabled(self) -> None: explanations = shap.TreeExplainer(classifier)(cal_X_test).values with tempfile.TemporaryDirectory() as tmpdir: - model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( - name="model1", - model=classifier, - signatures={"predict": model_signature.infer_signature(cal_X_test, y_pred)}, - metadata={"author": "halu", "version": "1"}, - ) + # check for warnings if sample_input_data is not provided while saving the model + with self.assertWarnsRegex( + UserWarning, "sample_input_data should be provided for better explainability results" + ): + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1", + model=classifier, + signatures={"predict": model_signature.infer_signature(cal_X_test, y_pred)}, + metadata={"author": "halu", "version": "1"}, + ) with warnings.catch_warnings(): warnings.simplefilter("error") @@ -186,12 +191,16 @@ def test_xgb_explainablity_enabled(self) -> None: np.testing.assert_allclose(predict_method(cal_X_test), np.expand_dims(y_pred, axis=1)) np.testing.assert_allclose(explain_method(cal_X_test), explanations) - model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( - name="model1_no_sig", - model=classifier, - sample_input_data=cal_X_test, - metadata={"author": "halu", "version": "1"}, - ) + with mock.patch( + "snowflake.ml.model._packager.model_handlers._utils.save_background_data" + ) as save_background_data: + model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( + name="model1_no_sig", + model=classifier, + sample_input_data=cal_X_test, + 
metadata={"author": "halu", "version": "1"}, + ) + save_background_data.assert_called_once() pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) pk.load(as_custom_model=True) diff --git a/snowflake/ml/model/_packager/model_meta/BUILD.bazel b/snowflake/ml/model/_packager/model_meta/BUILD.bazel index ab420923..823c07c4 100644 --- a/snowflake/ml/model/_packager/model_meta/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_meta/BUILD.bazel @@ -2,24 +2,6 @@ load("//bazel:py_rules.bzl", "py_genrule", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) -GEN_CORE_REQ_CMD = "$(location //bazel/requirements:parse_and_generate_requirements) $(location //:requirements.yml) --schema $(location //bazel/requirements:requirements.schema.json) --mode version_requirements --format python --filter_by_tag deployment_core > $@" - -py_genrule( - name = "gen_core_requirements", - srcs = [ - "//:requirements.yml", - "//bazel/requirements:requirements.schema.json", - ], - outs = ["_core_requirements.py"], - cmd = GEN_CORE_REQ_CMD, - tools = ["//bazel/requirements:parse_and_generate_requirements"], -) - -py_library( - name = "_core_requirements", - srcs = [":gen_core_requirements"], -) - GEN_PACKAGING_REQ_CMD = "$(location //bazel/requirements:parse_and_generate_requirements) $(location //:requirements.yml) --schema $(location //bazel/requirements:requirements.schema.json) --mode version_requirements --format python --filter_by_tag model_packaging > $@" py_genrule( @@ -58,7 +40,6 @@ py_library( name = "model_meta", srcs = ["model_meta.py"], deps = [ - ":_core_requirements", ":_packaging_requirements", ":model_blob_meta", ":model_meta_schema", diff --git a/snowflake/ml/model/_packager/model_meta/model_meta.py b/snowflake/ml/model/_packager/model_meta/model_meta.py index d4e7460b..48e2d515 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta.py @@ -2,7 +2,6 @@ import 
pathlib import sys import tempfile -import warnings import zipfile from contextlib import contextmanager from datetime import datetime @@ -18,7 +17,6 @@ from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager.model_env import model_env from snowflake.ml.model._packager.model_meta import ( - _core_requirements, _packaging_requirements, model_blob_meta, model_meta_schema, @@ -29,14 +27,10 @@ MODEL_METADATA_FILE = "model.yaml" MODEL_CODE_DIR = "code" -_PACKAGING_CORE_DEPENDENCIES = [ - str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r))) - for r in _core_requirements.REQUIREMENTS -] # Legacy Model only _PACKAGING_REQUIREMENTS = [ str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r))) for r in _packaging_requirements.REQUIREMENTS -] # New Model only +] _SNOWFLAKE_PKG_NAME = "snowflake" _SNOWFLAKE_ML_PKG_NAME = f"{_SNOWFLAKE_PKG_NAME}.ml" @@ -55,7 +49,7 @@ def create_model_metadata( conda_dependencies: Optional[List[str]] = None, pip_requirements: Optional[List[str]] = None, python_version: Optional[str] = None, - model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, + task: model_types.Task = model_types.Task.UNKNOWN, **kwargs: Any, ) -> Generator["ModelMetadata", None, None]: """Create a generator for model metadata object. Use generator to ensure correct register and unregister for @@ -75,9 +69,9 @@ def create_model_metadata( pip_requirements: List of pip Python packages requirements for running the model. Defaults to None. python_version: A string of python version where model is run. Used for user override. If specified as None, current version would be captured. Defaults to None. - model_objective: The objective of the Model Version. It is an enum class ModelObjective with values REGRESSION, - BINARY_CLASSIFICATION, MULTI_CLASSIFICATION, RANKING, or UNKNOWN. 
By default it is set to - ModelObjective.UNKNOWN and may be overridden by inferring from the Model Object. + task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION, + TABULAR_BINARY_CLASSIFICATION, TABULAR_MULTI_CLASSIFICATION, TABULAR_RANKING, or UNKNOWN. By default, + it is set to Task.UNKNOWN and may be overridden by inferring from the Model Object. **kwargs: Dict of attributes and values of the metadata. Used when loading from file. Raises: @@ -88,18 +82,6 @@ def create_model_metadata( """ model_dir_path = os.path.normpath(model_dir_path) embed_local_ml_library = kwargs.pop("embed_local_ml_library", False) - legacy_save = kwargs.pop("_legacy_save", False) - if "relax_version" not in kwargs: - warnings.warn( - ( - "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed " - "from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, " - "reproducibility, etc., set `options={'relax_version': False}` when logging the model." - ), - category=UserWarning, - stacklevel=2, - ) - relax_version = kwargs.pop("relax_version", True) if embed_local_ml_library: # Use the last one which is loaded first, that is mean, it is loaded from site-packages. 
@@ -122,7 +104,6 @@ def create_model_metadata( pip_requirements=pip_requirements, python_version=python_version, embed_local_ml_library=embed_local_ml_library, - legacy_save=legacy_save, ) if embed_local_ml_library: @@ -135,18 +116,13 @@ def create_model_metadata( model_type=model_type, signatures=signatures, function_properties=function_properties, - model_objective=model_objective, + task=task, ) code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR) - if (embed_local_ml_library and legacy_save) or code_paths: + if code_paths: os.makedirs(code_dir_path, exist_ok=True) - if embed_local_ml_library and legacy_save: - snowml_path_in_code = os.path.join(code_dir_path, _SNOWFLAKE_PKG_NAME) - os.makedirs(snowml_path_in_code, exist_ok=True) - file_utils.copy_file_or_tree(path_to_copy, snowml_path_in_code) - if code_paths: for code_path in code_paths: # This part is to prevent users from providing code following our naming and overwrite our code. @@ -165,8 +141,6 @@ def create_model_metadata( cloudpickle.register_pickle_by_value(mod) imported_modules.append(mod) yield model_meta - if relax_version: - model_meta.env.relax_version() model_meta.save(model_dir_path) finally: for mod in imported_modules: @@ -179,7 +153,6 @@ def _create_env_for_model_metadata( pip_requirements: Optional[List[str]] = None, python_version: Optional[str] = None, embed_local_ml_library: bool = False, - legacy_save: bool = False, ) -> model_env.ModelEnv: env = model_env.ModelEnv() @@ -189,7 +162,7 @@ def _create_env_for_model_metadata( env.python_version = python_version # type: ignore[assignment] env.snowpark_ml_version = snowml_env.VERSION - requirements_to_add = _PACKAGING_CORE_DEPENDENCIES if legacy_save else _PACKAGING_REQUIREMENTS + requirements_to_add = _PACKAGING_REQUIREMENTS if embed_local_ml_library: env.include_if_absent( @@ -242,7 +215,7 @@ class ModelMetadata: function_properties: A dict mapping function names to dict mapping function property key to value. 
metadata: User provided key-value metadata of the model. Defaults to None. creation_timestamp: Unix timestamp when the model metadata is created. - model_objective: Model objective like regression, classification etc. + task: Model task like TABULAR_REGRESSION, tabular_classification, timeseries_forecasting etc. """ def telemetry_metadata(self) -> ModelMetadataTelemetryDict: @@ -266,7 +239,7 @@ def __init__( min_snowpark_ml_version: Optional[str] = None, models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None, original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION, - model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, + task: model_types.Task = model_types.Task.UNKNOWN, explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None, ) -> None: self.name = name @@ -292,7 +265,7 @@ def __init__( self.original_metadata_version = original_metadata_version - self.model_objective: model_types.ModelObjective = model_objective + self.task: model_types.Task = task self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm @property @@ -309,10 +282,10 @@ def runtimes(self) -> Dict[str, model_runtime.ModelRuntime]: if self._runtimes and "cpu" in self._runtimes: return self._runtimes runtimes = { - "cpu": model_runtime.ModelRuntime("cpu", self.env), + "cpu": model_runtime.ModelRuntime("cpu", self.env, is_warehouse=False), } if self.env.cuda_version: - runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True)}) + runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_warehouse=False, is_gpu=True)}) return runtimes def save(self, model_dir_path: str) -> None: @@ -346,7 +319,7 @@ def save(self, model_dir_path: str) -> None: "signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()}, "version": model_meta_schema.MODEL_METADATA_VERSION, "min_snowpark_ml_version": self.min_snowpark_ml_version, - 
"model_objective": self.model_objective.value, + "task": self.task.value, "explainability": ( model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value) if self.explain_algorithm @@ -390,7 +363,7 @@ def _validate_model_metadata(loaded_meta: Any) -> model_meta_schema.ModelMetadat signatures=loaded_meta["signatures"], version=original_loaded_meta_version, min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version, - model_objective=loaded_meta.get("model_objective", model_types.ModelObjective.UNKNOWN.value), + task=loaded_meta.get("task", model_types.Task.UNKNOWN.value), explainability=loaded_meta.get("explainability", None), function_properties=loaded_meta.get("function_properties", {}), ) @@ -445,9 +418,7 @@ def load(cls, model_dir_path: str) -> "ModelMetadata": min_snowpark_ml_version=model_dict["min_snowpark_ml_version"], models=models, original_metadata_version=model_dict["version"], - model_objective=model_types.ModelObjective( - model_dict.get("model_objective", model_types.ModelObjective.UNKNOWN.value) - ), + task=model_types.Task(model_dict.get("task", model_types.Task.UNKNOWN.value)), explain_algorithm=explanation_algorithm, function_properties=model_dict.get("function_properties", {}), ) diff --git a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py index 8b46a19b..fc534c6e 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py @@ -50,10 +50,6 @@ class LightGBMModelBlobOptions(BaseModelBlobOptions): lightgbm_estimator_type: Required[str] -class LLMModelBlobOptions(BaseModelBlobOptions): - batch_size: Required[int] - - class MLFlowModelBlobOptions(BaseModelBlobOptions): artifact_path: Required[str] @@ -65,7 +61,6 @@ class XgboostModelBlobOptions(BaseModelBlobOptions): ModelBlobOptions = Union[ BaseModelBlobOptions, HuggingFacePipelineModelBlobOptions, - LLMModelBlobOptions, 
MLFlowModelBlobOptions, XgboostModelBlobOptions, ] @@ -96,7 +91,7 @@ class ModelMetadataDict(TypedDict): signatures: Required[Dict[str, Dict[str, Any]]] version: Required[str] min_snowpark_ml_version: Required[str] - model_objective: Required[str] + task: Required[str] explainability: NotRequired[Optional[ExplainabilityMetadataDict]] function_properties: NotRequired[Dict[str, Dict[str, Any]]] diff --git a/snowflake/ml/model/_packager/model_meta/model_meta_test.py b/snowflake/ml/model/_packager/model_meta/model_meta_test.py index 0b2b7f4e..8e21b0b5 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta_test.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta_test.py @@ -28,50 +28,6 @@ name="model1", model_type="custom", path="mock_path", handler_version="version_0" ) -_BASIC_DEPENDENCIES_TARGET = list( - sorted( - map( - lambda x: str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x))), - model_meta._PACKAGING_CORE_DEPENDENCIES, - ) - ) -) - -_BASIC_DEPENDENCIES_TARGET_RELAXED = list( - sorted( - map( - lambda x: str( - env_utils.relax_requirement_version( - env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) - ) - ), - model_meta._PACKAGING_CORE_DEPENDENCIES, - ) - ) -) - -_BASIC_DEPENDENCIES_TARGET_WITH_SNOWML = list( - sorted( - map( - lambda x: str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x))), - model_meta._PACKAGING_CORE_DEPENDENCIES + [env_utils.SNOWPARK_ML_PKG_NAME], - ) - ) -) - -_BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED = list( - sorted( - map( - lambda x: str( - env_utils.relax_requirement_version( - env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) - ) - ), - model_meta._PACKAGING_CORE_DEPENDENCIES + [env_utils.SNOWPARK_ML_PKG_NAME], - ) - ) -) - _PACKAGING_REQUIREMENTS_TARGET = list( sorted( map( @@ -81,18 +37,6 @@ ) ) -_PACKAGING_REQUIREMENTS_TARGET_RELAXED = list( - sorted( - map( - lambda x: str( - 
env_utils.relax_requirement_version( - env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) - ) - ), - model_meta._PACKAGING_REQUIREMENTS, - ) - ) -) _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML = list( sorted( @@ -103,291 +47,29 @@ ) ) -_PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED = list( - sorted( - map( - lambda x: str( - env_utils.relax_requirement_version( - env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) - ) - ), - model_meta._PACKAGING_REQUIREMENTS + [env_utils.SNOWPARK_ML_PKG_NAME], - ) - ) -) - - -class ModelMetaEnvLegacyTest(absltest.TestCase): - def test_model_meta_dependencies_no_packages(self) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG, _legacy_save=True - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED) - self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED) - self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) - - def test_model_meta_dependencies_no_packages_embedded_snowml_strict(self) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - embed_local_ml_library=True, - _legacy_save=True, - relax_version=False, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET) - 
self.assertIsNotNone(meta.env._snowpark_ml_version.local) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET) - - def test_model_meta_dependencies_no_packages_embedded_snowml(self) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - embed_local_ml_library=True, - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_RELAXED) - self.assertIsNotNone(meta.env._snowpark_ml_version.local) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_RELAXED) - self.assertIsNotNone(meta.env._snowpark_ml_version.local) - - def test_model_meta_dependencies_dup_basic_dep(self) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - conda_dependencies=["cloudpickle"], - relax_version=False, - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] - dep_target.remove(f"cloudpickle=={importlib_metadata.version('cloudpickle')}") - dep_target.append("cloudpickle") - dep_target.sort() - - self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, dep_target) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, dep_target) - - def 
test_model_meta_dependencies_dup_basic_dep_other_channel(self) -> None: - with self.assertWarns(UserWarning): - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - conda_dependencies=["conda-forge::cloudpickle"], - relax_version=False, - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] - dep_target.remove(f"cloudpickle=={importlib_metadata.version('cloudpickle')}") - dep_target.append("conda-forge::cloudpickle") - dep_target.sort() - - self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, dep_target) - - with self.assertWarns(UserWarning): - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, dep_target) - - def test_model_meta_dependencies_dup_basic_dep_pip(self) -> None: - with self.assertWarns(UserWarning): - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - pip_requirements=["cloudpickle"], - relax_version=False, - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] - dep_target.remove(f"cloudpickle=={importlib_metadata.version('cloudpickle')}") - dep_target.sort() - - self.assertListEqual(meta.env.pip_requirements, ["cloudpickle"]) - self.assertListEqual(meta.env.conda_dependencies, dep_target) - - with self.assertWarns(UserWarning): - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, ["cloudpickle"]) - self.assertListEqual(loaded_meta.env.conda_dependencies, dep_target) - - def test_model_meta_dependencies_conda(self) -> 
None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - conda_dependencies=["pytorch==2.0.1"], - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] - dep_target.append("pytorch<3,>=2.0") - dep_target.sort() - - self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, dep_target) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, dep_target) - - def test_model_meta_dependencies_conda_additional_package(self) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - meta.env.include_if_absent([model_env.ModelDependency("pytorch==2.0.1", "torch")]) - - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] - dep_target.append("pytorch<3,>=2.0") - dep_target.sort() - - self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, dep_target) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, dep_target) - - def test_model_meta_dependencies_pip(self) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - pip_requirements=["torch"], - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] - 
dep_target.sort() - - self.assertListEqual(meta.env.pip_requirements, ["torch"]) - self.assertListEqual(meta.env.conda_dependencies, dep_target) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, ["torch"]) - self.assertListEqual(loaded_meta.env.conda_dependencies, dep_target) - - def test_model_meta_dependencies_both(self) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - conda_dependencies=["pytorch"], - pip_requirements=["torch"], - _legacy_save=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] - dep_target.append("pytorch") - dep_target.sort() - - self.assertListEqual(meta.env.pip_requirements, ["torch"]) - self.assertListEqual(meta.env.conda_dependencies, dep_target) - - loaded_meta = model_meta.ModelMetadata.load(tmpdir) - - self.assertListEqual(loaded_meta.env.pip_requirements, ["torch"]) - self.assertListEqual(loaded_meta.env.conda_dependencies, dep_target) - class ModelMetaEnvTest(absltest.TestCase): def test_model_meta_dependencies(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: with model_meta.create_model_metadata( - model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG, relax_version=True + model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG ) as meta: meta.models["model1"] = _DUMMY_BLOB self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED) + self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML) self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) loaded_meta = model_meta.ModelMetadata.load(tmpdir) self.assertListEqual(loaded_meta.env.pip_requirements, []) - 
self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED) + self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML) self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) - with self.assertWarnsRegex(UserWarning, "`relax_version` is not set and therefore defaulted to True."): - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - def test_model_meta_dependencies_no_relax(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: with model_meta.create_model_metadata( - model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG, relax_version=False + model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -413,13 +95,13 @@ def test_model_meta_dependencies_no_packages_embedded_snowml(self) -> None: meta.models["model1"] = _DUMMY_BLOB self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_RELAXED) + self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET) self.assertIsNotNone(meta.env._snowpark_ml_version.local) loaded_meta = model_meta.ModelMetadata.load(tmpdir) self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_RELAXED) + self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET) self.assertIsNotNone(meta.env._snowpark_ml_version.local) def test_model_meta_dependencies_no_packages_embedded_snowml_strict(self) -> None: @@ -430,7 +112,6 @@ def test_model_meta_dependencies_no_packages_embedded_snowml_strict(self) -> Non model_type="custom", signatures=_DUMMY_SIG, embed_local_ml_library=True, - 
relax_version=False, ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -452,7 +133,6 @@ def test_model_meta_dependencies_dup_basic_dep(self) -> None: model_type="custom", signatures=_DUMMY_SIG, conda_dependencies=["cloudpickle"], - relax_version=False, ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -478,7 +158,6 @@ def test_model_meta_dependencies_dup_basic_dep_other_channel(self) -> None: model_type="custom", signatures=_DUMMY_SIG, conda_dependencies=["conda-forge::cloudpickle"], - relax_version=False, ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -505,7 +184,6 @@ def test_model_meta_dependencies_dup_basic_dep_pip(self) -> None: model_type="custom", signatures=_DUMMY_SIG, pip_requirements=["cloudpickle"], - relax_version=False, ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -533,8 +211,8 @@ def test_model_meta_dependencies_conda(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED[:] - dep_target.append("pytorch<3,>=2.0") + dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] + dep_target.append("pytorch==2.0.1") dep_target.sort() self.assertListEqual(meta.env.pip_requirements, []) @@ -556,8 +234,8 @@ def test_model_meta_dependencies_conda_additional_package(self) -> None: meta.models["model1"] = _DUMMY_BLOB meta.env.include_if_absent([model_env.ModelDependency("pytorch==2.0.1", "torch")]) - dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED[:] - dep_target.append("pytorch<3,>=2.0") + dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] + dep_target.append("pytorch==2.0.1") dep_target.sort() self.assertListEqual(meta.env.pip_requirements, []) @@ -579,7 +257,7 @@ def test_model_meta_dependencies_pip(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED[:] + dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] dep_target.sort() self.assertListEqual(meta.env.pip_requirements, 
["torch"]) @@ -602,7 +280,7 @@ def test_model_meta_dependencies_both(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED[:] + dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] dep_target.append("pytorch") dep_target.sort() @@ -645,7 +323,6 @@ def test_model_meta_metadata(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - self.assertEqual(meta.model_objective, type_hints.ModelObjective.UNKNOWN) self.assertEqual(meta.explain_algorithm, None) saved_meta = meta @@ -686,10 +363,10 @@ def test_model_meta_model_specified_objective(self) -> None: metadata={"foo": "bar"}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - meta.model_objective = type_hints.ModelObjective.REGRESSION + meta.task = type_hints.Task.TABULAR_REGRESSION loaded_meta = model_meta.ModelMetadata.load(tmpdir) - self.assertEqual(loaded_meta.model_objective, type_hints.ModelObjective.REGRESSION) + self.assertEqual(loaded_meta.task, type_hints.Task.TABULAR_REGRESSION) def test_model_meta_explain_algorithm(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: @@ -701,11 +378,11 @@ def test_model_meta_explain_algorithm(self) -> None: metadata={"foo": "bar"}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - meta.model_objective = type_hints.ModelObjective.REGRESSION + meta.task = type_hints.Task.TABULAR_REGRESSION meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP loaded_meta = model_meta.ModelMetadata.load(tmpdir) - self.assertEqual(loaded_meta.model_objective, type_hints.ModelObjective.REGRESSION) + self.assertEqual(loaded_meta.task, type_hints.Task.TABULAR_REGRESSION) self.assertEqual(loaded_meta.explain_algorithm, model_meta_schema.ModelExplainAlgorithm.SHAP) def test_model_meta_new_fields(self) -> None: @@ -797,16 +474,16 @@ def test_model_meta_runtimes_gpu(self) -> None: with open(os.path.join(tmpdir, "runtimes", "cpu", "env", "conda.yml"), encoding="utf-8") as f: 
self.assertListEqual(yaml.safe_load(f)["channels"], ["conda-forge", "nodefaults"]) self.assertContainsSubset( - ["nvidia::cuda==11.7.*", "pytorch::pytorch", "pytorch::pytorch-cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "pytorch"], meta.runtimes["gpu"].runtime_env.conda_dependencies, ) with open(os.path.join(tmpdir, "runtimes", "gpu", "env", "conda.yml"), encoding="utf-8") as f: - self.assertListEqual(yaml.safe_load(f)["channels"], ["conda-forge", "pytorch", "nvidia", "nodefaults"]) + self.assertListEqual(yaml.safe_load(f)["channels"], ["conda-forge", "nvidia", "nodefaults"]) loaded_meta = model_meta.ModelMetadata.load(tmpdir) self.assertContainsSubset(["pytorch"], loaded_meta.runtimes["cpu"].runtime_env.conda_dependencies) self.assertContainsSubset( - ["nvidia::cuda==11.7.*", "pytorch::pytorch", "pytorch::pytorch-cuda==11.7.*"], + ["nvidia::cuda==11.7.*", "pytorch"], loaded_meta.runtimes["gpu"].runtime_env.conda_dependencies, ) diff --git a/snowflake/ml/model/_packager/model_packager.py b/snowflake/ml/model/_packager/model_packager.py index 984381ef..eac41036 100644 --- a/snowflake/ml/model/_packager/model_packager.py +++ b/snowflake/ml/model/_packager/model_packager.py @@ -47,8 +47,9 @@ def save( ext_modules: Optional[List[ModuleType]] = None, code_paths: Optional[List[str]] = None, options: Optional[model_types.ModelSaveOption] = None, - model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, + task: model_types.Task = model_types.Task.UNKNOWN, ) -> model_meta.ModelMetadata: + if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_ARGUMENT, @@ -57,17 +58,20 @@ def save( ), ) - if (signatures is not None) and (sample_input_data is not None): - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - "Signatures and sample_input_data both cannot 
be specified at the same time." - ), - ) - if not options: options = model_types.BaseModelSaveOption() + # here handling the case of enable_explainability is False/None + enable_explainability = options.get("enable_explainability", None) + if enable_explainability is False or enable_explainability is None: + if (signatures is not None) and (sample_input_data is not None): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=ValueError( + "Signatures and sample_input_data both cannot be specified at the same time." + ), + ) + handler = model_handler.find_handler(model) if handler is None: raise snowml_exceptions.SnowflakeMLException( @@ -85,7 +89,7 @@ def save( conda_dependencies=conda_dependencies, pip_requirements=pip_requirements, python_version=python_version, - model_objective=model_objective, + task=task, **options, ) as meta: model_blobs_path = os.path.join(self.local_dir_path, ModelPackager.MODEL_BLOBS_DIR) diff --git a/snowflake/ml/model/_packager/model_packager_test.py b/snowflake/ml/model/_packager/model_packager_test.py index 9d83f4e5..feb75583 100644 --- a/snowflake/ml/model/_packager/model_packager_test.py +++ b/snowflake/ml/model/_packager/model_packager_test.py @@ -110,7 +110,9 @@ def test_zipimport_snowml(self) -> None: model=lm, sample_input_data=d, metadata={"author": "halu", "version": "1"}, - options={"embed_local_ml_library": True, "_legacy_save": True}, + options={ + "embed_local_ml_library": True, + }, ) self.assertTrue( os.path.exists( @@ -130,6 +132,7 @@ def test_save_validation_1(self) -> None: d = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) pk = model_packager.ModelPackager(os.path.join(workspace, "model1")) + # exception thrown when enable_explainability is not set with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, @@ -142,6 +145,20 @@ def test_save_validation_1(self) -> None: signatures={"predict": 
model_signature.ModelSignature(inputs=[], outputs=[])}, ) + # exception thrown when enable_explainability is set to False + with exception_utils.assert_snowml_exceptions( + self, + expected_original_error_type=ValueError, + expected_regex="Signatures and sample_input_data both cannot be specified at the same time.", + ): + pk.save( + name="model1", + model=linear_model.LinearRegression(), + sample_input_data=d, + signatures={"predict": model_signature.ModelSignature(inputs=[], outputs=[])}, + options={"enable_explainability": False}, + ) + with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, @@ -173,14 +190,14 @@ def test_save_validation_2(self) -> None: name="model1", model=regr, metadata={"author": "halu", "version": "1"}, - model_objective=type_hints.ModelObjective.REGRESSION, + task=type_hints.Task.TABULAR_REGRESSION, ) pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) pk.load() assert pk.model assert pk.meta - self.assertEqual(type_hints.ModelObjective.REGRESSION, pk.meta.model_objective) + self.assertEqual(type_hints.Task.TABULAR_REGRESSION, pk.meta.task) assert isinstance(pk.model, LinearRegression) np.testing.assert_allclose(predictions, desired=pk.model.predict(df[:1])[[OUTPUT_COLUMNS]]) diff --git a/snowflake/ml/model/_packager/model_runtime/model_runtime.py b/snowflake/ml/model/_packager/model_runtime/model_runtime.py index 81502522..98c7cc7e 100644 --- a/snowflake/ml/model/_packager/model_runtime/model_runtime.py +++ b/snowflake/ml/model/_packager/model_runtime/model_runtime.py @@ -36,6 +36,7 @@ def __init__( name: str, env: model_env.ModelEnv, imports: Optional[List[str]] = None, + is_warehouse: bool = False, is_gpu: bool = False, loading_from_file: bool = False, ) -> None: @@ -60,6 +61,16 @@ def __init__( ], ) + if not is_warehouse and self.embed_local_ml_library: + self.runtime_env.include_if_absent( + [ + model_env.ModelDependency( + requirement="pyarrow", + pip_name="pyarrow", + ) + ], + ) + 
if is_gpu: self.runtime_env.generate_env_for_cuda() diff --git a/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py b/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py index 03f55db8..f21707c1 100644 --- a/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py +++ b/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py @@ -125,6 +125,33 @@ def test_model_runtime_local_snowml(self) -> None: ) returned_dict = mr.save(pathlib.Path(workspace)) + self.assertDictEqual( + returned_dict, + { + "imports": ["runtimes/cpu/snowflake-ml-python.zip"], + "dependencies": { + "conda": "runtimes/cpu/env/conda.yml", + "pip": "runtimes/cpu/env/requirements.txt", + }, + }, + ) + with open(os.path.join(workspace, "runtimes/cpu/env/conda.yml"), encoding="utf-8") as f: + dependencies = yaml.safe_load(f) + + self.assertContainsSubset(_BASIC_DEPENDENCIES_TARGET_RELAXED + ["pyarrow"], dependencies["dependencies"]) + + def test_model_runtime_local_snowml_warehouse(self) -> None: + with tempfile.TemporaryDirectory() as workspace: + m_env = model_env.ModelEnv() + m_env.snowpark_ml_version = "1.0.0+abcdef" + + mr = model_runtime.ModelRuntime( + "cpu", + m_env, + is_warehouse=True, + ) + returned_dict = mr.save(pathlib.Path(workspace)) + self.assertDictEqual( returned_dict, { @@ -139,6 +166,7 @@ def test_model_runtime_local_snowml(self) -> None: dependencies = yaml.safe_load(f) self.assertContainsSubset(_BASIC_DEPENDENCIES_TARGET_RELAXED, dependencies["dependencies"]) + self.assertNotIn("pyarrow", dependencies["dependencies"]) def test_model_runtime_dup_basic_dep(self) -> None: with tempfile.TemporaryDirectory() as workspace: @@ -269,7 +297,7 @@ def test_model_runtime_gpu(self) -> None: dependencies = yaml.safe_load(f) self.assertContainsSubset( - ["nvidia::cuda==11.7.*", "pytorch::pytorch", "pytorch::pytorch-cuda==11.7.*"], + ["python==3.8.*", "pytorch", "snowflake-ml-python==1.0.0", "nvidia::cuda==11.7.*"], dependencies["dependencies"], ) diff 
--git a/snowflake/ml/model/_signatures/BUILD.bazel b/snowflake/ml/model/_signatures/BUILD.bazel index f8a58174..af07fbe0 100644 --- a/snowflake/ml/model/_signatures/BUILD.bazel +++ b/snowflake/ml/model/_signatures/BUILD.bazel @@ -168,7 +168,6 @@ py_library( "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_deploy_client/warehouse:infer_template", ], ) diff --git a/snowflake/ml/model/_signatures/snowpark_handler.py b/snowflake/ml/model/_signatures/snowpark_handler.py index 500644ea..4d85e6f4 100644 --- a/snowflake/ml/model/_signatures/snowpark_handler.py +++ b/snowflake/ml/model/_signatures/snowpark_handler.py @@ -14,9 +14,10 @@ ) from snowflake.ml._internal.utils import identifier from snowflake.ml.model import type_hints as model_types -from snowflake.ml.model._deploy_client.warehouse import infer_template from snowflake.ml.model._signatures import base_handler, core, pandas_handler +_KEEP_ORDER_COL_NAME = "_ID" + class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.DataFrame]): @staticmethod @@ -109,7 +110,7 @@ def convert_from_df( # Role will be no effect on the column index. That is to say, the feature name is the actual column name. 
if keep_order: df = df.reset_index(drop=True) - df[infer_template._KEEP_ORDER_COL_NAME] = df.index + df[_KEEP_ORDER_COL_NAME] = df.index sp_df = session.create_dataframe(df) column_names = [] columns = [] diff --git a/snowflake/ml/model/deploy_platforms.py b/snowflake/ml/model/deploy_platforms.py deleted file mode 100644 index 65bd6440..00000000 --- a/snowflake/ml/model/deploy_platforms.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class TargetPlatform(Enum): - WAREHOUSE = "warehouse" - SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES" diff --git a/snowflake/ml/model/models/BUILD.bazel b/snowflake/ml/model/models/BUILD.bazel index bbe59851..e452497b 100644 --- a/snowflake/ml/model/models/BUILD.bazel +++ b/snowflake/ml/model/models/BUILD.bazel @@ -7,20 +7,8 @@ py_library( srcs = ["huggingface_pipeline.py"], ) -py_library( - name = "llm_model", - srcs = ["llm.py"], -) - py_test( name = "huggingface_pipeline_test", srcs = ["huggingface_pipeline_test.py"], deps = [":huggingface_pipeline"], ) - -py_test( - name = "llm_test", - srcs = ["llm_test.py"], - compatible_with_snowpark = False, - deps = [":llm_model"], -) diff --git a/snowflake/ml/model/models/llm.py b/snowflake/ml/model/models/llm.py deleted file mode 100644 index 21aec7ba..00000000 --- a/snowflake/ml/model/models/llm.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -from dataclasses import dataclass, field -from enum import Enum -from typing import Optional, Set - -_PEFT_CONFIG_NAME = "adapter_config.json" - - -class SupportedLLMType(Enum): - LLAMA_MODEL_TYPE = "llama" - OPT_MODEL_TYPE = "opt" - - @classmethod - def valid_values(cls) -> Set[str]: - return {member.value for member in cls} - - -@dataclass(frozen=True) -class LLMOptions: - """ - This is the option class for LLM. - - Args: - revision: Revision of HF model. Defaults to None. - token: The token to use as HTTP bearer authorization for remote files. Defaults to None. - max_batch_size: Max batch size allowed for single inferenced. 
Defaults to 1. - """ - - revision: Optional[str] = field(default=None) - token: Optional[str] = field(default=None) - max_batch_size: int = field(default=1) - enable_tp: bool = field(default=False) - # TODO(halu): Below could be per query call param instead. - temperature: float = field(default=0.01) - top_p: float = field(default=1.0) - max_tokens: int = field(default=100) - - -class LLM: - class Mode(Enum): - LOCAL_LORA = "local_lora" - REMOTE_PRETRAIN = "remote_pretrain" - - def __init__( - self, - model_id_or_path: str, - *, - options: Optional[LLMOptions] = None, - ) -> None: - """ - - Args: - model_id_or_path: model_id or local dir to PEFT lora weights. - options: Options for LLM. Defaults to be None. - - Raises: - ValueError: When unsupported. - """ - if not options: - options = LLMOptions() - hub_kwargs = { - "revision": options.revision, - "token": options.token, - } - import transformers - - if os.path.isdir(model_id_or_path): - if not os.path.isfile(os.path.join(model_id_or_path, _PEFT_CONFIG_NAME)): - raise ValueError("Peft config is not found.") - - import peft - - peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined] - model_id_or_path, **hub_kwargs - ) - if peft_config.peft_type != peft.PeftType.LORA: # type: ignore[attr-defined] - raise ValueError("Only LORA is supported.") - if peft_config.task_type != peft.TaskType.CAUSAL_LM: # type: ignore[attr-defined] - raise ValueError("Only CAUSAL_LM is supported.") - base_model = peft_config.base_model_name_or_path - base_config = transformers.AutoConfig.from_pretrained(base_model, **hub_kwargs) - assert ( - base_config.model_type in SupportedLLMType.valid_values() - ), f"{base_config.model_type} is not supported." 
- self.mode = LLM.Mode.LOCAL_LORA - self.model_type = base_config.model_type - else: - # We support pre-train model as well - model_config = transformers.AutoConfig.from_pretrained( - model_id_or_path, - **hub_kwargs, - ) - assert ( - model_config.model_type in SupportedLLMType.valid_values() - ), f"{model_config.model_type} is not supported." - self.mode = LLM.Mode.REMOTE_PRETRAIN - self.model_type = model_config.model_type - - self.model_id_or_path = model_id_or_path - self.token = options.token - self.revision = options.revision - self.max_batch_size = options.max_batch_size - self.temperature = options.temperature - self.top_p = options.top_p - self.max_tokens = options.max_tokens - self.enable_tp = options.enable_tp diff --git a/snowflake/ml/model/models/llm_test.py b/snowflake/ml/model/models/llm_test.py deleted file mode 100644 index ae453518..00000000 --- a/snowflake/ml/model/models/llm_test.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import tempfile - -from absl.testing import absltest - -from snowflake.ml.model.models import llm - - -class LLMTest(absltest.TestCase): - @classmethod - def setUpClass(self) -> None: - self.cache_dir = tempfile.TemporaryDirectory() - self._original_hf_home = os.getenv("HF_HOME", None) - os.environ["HF_HOME"] = self.cache_dir.name - - @classmethod - def tearDownClass(self) -> None: - if self._original_hf_home: - os.environ["HF_HOME"] = self._original_hf_home - else: - del os.environ["HF_HOME"] - self.cache_dir.cleanup() - - def test_llm(self) -> None: - import peft - - ft_model = peft.AutoPeftModelForCausalLM.from_pretrained( # type: ignore[attr-defined] - "peft-internal-testing/tiny-OPTForCausalLM-lora", - device_map="auto", - ) - tmp_dir = self.create_tempdir().full_path - ft_model.save_pretrained(tmp_dir) - llm.LLM(model_id_or_path=tmp_dir) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/model/package_visibility_test.py b/snowflake/ml/model/package_visibility_test.py index fe74e208..9da143ef 
100644 --- a/snowflake/ml/model/package_visibility_test.py +++ b/snowflake/ml/model/package_visibility_test.py @@ -3,13 +3,7 @@ from absl.testing import absltest from snowflake.ml import model -from snowflake.ml.model import ( - _api, - custom_model, - deploy_platforms, - model_signature, - type_hints, -) +from snowflake.ml.model import custom_model, model_signature, type_hints class PackageVisibilityTest(absltest.TestCase): @@ -19,14 +13,10 @@ def test_class_visible(self) -> None: self.assertIsInstance(model.Model, type) self.assertIsInstance(model.ModelVersion, type) self.assertIsInstance(model.HuggingFacePipelineModel, type) - self.assertIsInstance(model.LLM, type) - self.assertIsInstance(model.LLMOptions, type) def test_module_visible(self) -> None: - self.assertIsInstance(_api, ModuleType) self.assertIsInstance(custom_model, ModuleType) self.assertIsInstance(model_signature, ModuleType) - self.assertIsInstance(deploy_platforms, ModuleType) self.assertIsInstance(type_hints, ModuleType) diff --git a/snowflake/ml/model/type_hints.py b/snowflake/ml/model/type_hints.py index c8873c6f..c62b94ff 100644 --- a/snowflake/ml/model/type_hints.py +++ b/snowflake/ml/model/type_hints.py @@ -1,23 +1,9 @@ # mypy: disable-error-code="import" from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - Dict, - List, - Literal, - Optional, - Sequence, - TypedDict, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Dict, Literal, Sequence, TypedDict, TypeVar, Union import numpy.typing as npt -from typing_extensions import NotRequired, Required - -from snowflake.ml.model import deploy_platforms -from snowflake.ml.model._signatures import core +from typing_extensions import NotRequired if TYPE_CHECKING: import catboost @@ -35,7 +21,6 @@ import snowflake.ml.model.custom_model import snowflake.ml.model.models.huggingface_pipeline - import snowflake.ml.model.models.llm import snowflake.snowpark from snowflake.ml.modeling.framework import base # noqa: F401 @@ -91,7 +76,6 
@@ "transformers.Pipeline", "sentence_transformers.SentenceTransformer", "snowflake.ml.model.models.huggingface_pipeline.HuggingFacePipelineModel", - "snowflake.ml.model.models.llm.LLM", ] SupportedModelType = Union[ @@ -134,86 +118,11 @@ "tensorflow", "torchscript", "xgboost", - "llm", ] _ModelType = TypeVar("_ModelType", bound=SupportedModelType) -class DeployOptions(TypedDict): - """Common Options for deploying to Snowflake.""" - - ... - - -class WarehouseDeployOptions(DeployOptions): - """Options for deploying to the Snowflake Warehouse. - - - permanent_udf_stage_location: A Snowflake stage option where the UDF should be persisted. If specified, the model - will be deployed as a permanent UDF, otherwise temporary. - relax_version: Whether or not relax the version constraints of the dependencies if unresolvable. It detects any - ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False. - replace_udf: Flag to indicate when deploying model as permanent UDF, whether overwriting existed UDF is allowed. - Default to False. - """ - - permanent_udf_stage_location: NotRequired[str] - relax_version: NotRequired[bool] - replace_udf: NotRequired[bool] - - -class SnowparkContainerServiceDeployOptions(DeployOptions): - """Deployment options for deploying to SnowService. - When type hint is updated, please ensure the concrete class is updated accordingly at: - //snowflake/ml/model/_deploy_client/snowservice/_deploy_options - - compute_pool[REQUIRED]: SnowService compute pool name. Please refer to official doc for how to create a - compute pool: https://docs.snowflake.com/LIMITEDACCESS/snowpark-containers/reference/compute-pool - image_repo: SnowService image repo path. e.g. "///". Default to auto - inferred based on session information. - min_instances: Minimum number of service replicas. Default to 1. - max_instances: Maximum number of service replicas. Default to 1. 
- prebuilt_snowflake_image: When provided, the image-building step is skipped, and the pre-built image from - Snowflake is used as is. This option is for users who consistently use the same image for multiple use - cases, allowing faster deployment. The snowflake image used for deployment is logged to the console for - future use. Default to None. - num_gpus: Number of GPUs to be used for the service. Default to 0. - num_workers: Number of workers used for model inference. Please ensure that the number of workers is set lower than - the total available memory divided by the size of model to prevent memory-related issues. Default is number of - CPU cores * 2 + 1. - enable_remote_image_build: When set to True, will enable image build on a remote SnowService job. Default is True. - force_image_build: When set to True, an image rebuild will occur. The default is False, which means the system - will automatically check whether a previously built image can be reused - model_in_image: When set to True, image would container full model weights. The default if False, which - means image without model weights and we do stage mount to access weights. - debug_mode: When set to True, deployment artifacts will be persisted in a local temp directory. - enable_ingress: When set to True, will expose HTTP endpoint for access to the predict method of the created - service. - external_access_integrations: External Access Integrations name used to build image and deploy the model. - Please refer to the doc for how to create an External Access Integrations: https://docs.snowflake.com/ - developer-guide/snowpark-container-services/additional-considerations-services-jobs - #configuring-network-capabilities . - To make sure your image could be built, access to the following endpoint must be allowed. 
- docker.com:80, docker.com:443, anaconda.com:80, anaconda.com:443, anaconda.org:80, anaconda.org:443, - pypi.org:80, pypi.org:443 - """ - - compute_pool: str - image_repo: NotRequired[str] - min_instances: NotRequired[int] - max_instances: NotRequired[int] - prebuilt_snowflake_image: NotRequired[str] - num_gpus: NotRequired[int] - num_workers: NotRequired[int] - enable_remote_image_build: NotRequired[bool] - force_image_build: NotRequired[bool] - model_in_image: NotRequired[bool] - debug_mode: NotRequired[bool] - enable_ingress: NotRequired[bool] - external_access_integrations: List[str] - - class ModelMethodSaveOptions(TypedDict): case_sensitive: NotRequired[bool] max_batch_size: NotRequired[int] @@ -224,13 +133,12 @@ class BaseModelSaveOption(TypedDict): """Options for saving the model. embed_local_ml_library: Embedding local SnowML into the code directory of the folder. - relax_version: Whether or not relax the version constraints of the dependencies if unresolvable. It detects any - ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False. + relax_version: Whether or not relax the version constraints of the dependencies if unresolvable in Warehouse. + It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True. 
""" embed_local_ml_library: NotRequired[bool] relax_version: NotRequired[bool] - _legacy_save: NotRequired[bool] function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]] method_options: NotRequired[Dict[str, ModelMethodSaveOptions]] enable_explainability: NotRequired[bool] @@ -293,10 +201,6 @@ class SentenceTransformersSaveOptions(BaseModelSaveOption): cuda_version: NotRequired[str] -class LLMSaveOptions(BaseModelSaveOption): - cuda_version: NotRequired[str] - - ModelSaveOption = Union[ BaseModelSaveOption, CatBoostModelSaveOptions, @@ -311,7 +215,6 @@ class LLMSaveOptions(BaseModelSaveOption): MLFlowSaveOptions, HuggingFaceSaveOptions, SentenceTransformersSaveOptions, - LLMSaveOptions, ] @@ -369,10 +272,7 @@ class HuggingFaceLoadOptions(BaseModelLoadOption): class SentenceTransformersLoadOptions(BaseModelLoadOption): use_gpu: NotRequired[bool] - - -class LLMLoadOptions(BaseModelLoadOption): - ... + device: NotRequired[str] ModelLoadOption = Union[ @@ -389,53 +289,12 @@ class LLMLoadOptions(BaseModelLoadOption): MLFlowLoadOptions, HuggingFaceLoadOptions, SentenceTransformersLoadOptions, - LLMLoadOptions, -] - - -class SnowparkContainerServiceDeployDetails(TypedDict): - """ - Attributes: - service_info: A snowpark row containing the result of "describe service" - service_function_sql: SQL for service function creation. - """ - - service_info: Optional[Dict[str, Any]] - service_function_sql: str - - -class WarehouseDeployDetails(TypedDict): - ... - - -DeployDetails = Union[ - SnowparkContainerServiceDeployDetails, - WarehouseDeployDetails, ] -class Deployment(TypedDict): - """Deployment information. - - Attributes: - name: Name of the deployment. - platform: Target platform to deploy the model. - target_method: Target method name. - signature: The signature of the model method. - options: Additional options when deploying the model. 
- """ - - name: Required[str] - platform: Required[deploy_platforms.TargetPlatform] - target_method: Required[str] - signature: core.ModelSignature - options: Required[DeployOptions] - details: NotRequired[DeployDetails] - - -class ModelObjective(Enum): - UNKNOWN = "unknown" - BINARY_CLASSIFICATION = "binary_classification" - MULTI_CLASSIFICATION = "multi_classification" - REGRESSION = "regression" - RANKING = "ranking" +class Task(Enum): + UNKNOWN = "UNKNOWN" + TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION" + TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION" + TABULAR_REGRESSION = "TABULAR_REGRESSION" + TABULAR_RANKING = "TABULAR_RANKING" diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py index d0e2d065..8a14aedc 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py @@ -377,7 +377,6 @@ def fit_search_snowpark( anonymous=True, imports=imports, # type: ignore[arg-type] statement_params=sproc_statement_params, - execute_as="caller", ) def _distributed_search( session: Session, @@ -783,7 +782,6 @@ def fit_search_snowpark_enable_efficient_memory_usage( anonymous=True, imports=imports, # type: ignore[arg-type] statement_params=sproc_statement_params, - execute_as="caller", ) def _distributed_search( session: Session, diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py index cc37c338..0546a0ba 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py @@ -230,7 +230,6 @@ def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bo 
replace=True, session=self.session, statement_params=statement_params, - execute_as="caller", anonymous=anonymous, ) return fit_wrapper_sproc @@ -461,9 +460,7 @@ def fit_transform_wrapper_function( session.write_pandas( transformed_pandas_df, fit_transform_result_name, - auto_create_table=True, - table_type="temp", - quote_identifiers=False, + overwrite=True, ) return str(os.path.basename(local_result_file_name)) @@ -488,7 +485,6 @@ def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anony session=self.session, statement_params=statement_params, anonymous=anonymous, - execute_as="caller", ) return fit_predict_wrapper_sproc @@ -510,7 +506,6 @@ def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], ano replace=True, session=self.session, statement_params=statement_params, - execute_as="caller", anonymous=anonymous, ) @@ -730,6 +725,22 @@ def train_fit_transform( fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE) + # Create a temp table in advance to store the output + # This would allow us to use the same table outside the stored procedure + df_one_line = dataset.limit(1).to_pandas(statement_params=statement_params) + df_one_line[ + expected_output_cols_list[0] + ] = "[0]" # Add one column as the output_col; this is a dummy value to represent the OBJECT type + if drop_input_cols: + self.session.write_pandas( + df_one_line[expected_output_cols_list[0]], + fit_transform_result_name, + auto_create_table=True, + table_type="temp", + ) + else: + self.session.write_pandas(df_one_line, fit_transform_result_name, auto_create_table=True, table_type="temp") + sproc_export_file_name: str = fit_transform_wrapper_sproc( self.session, queries, diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py index 10b47c6e..df3d9241 100644 --- 
a/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py @@ -303,7 +303,6 @@ def _get_xgb_external_memory_fit_wrapper_sproc( statement_params=statement_params, anonymous=True, imports=list(import_file_paths), - execute_as="caller", ) # type: ignore[misc] def fit_wrapper_sproc( session: Session, diff --git a/snowflake/ml/modeling/metrics/metrics_utils.py b/snowflake/ml/modeling/metrics/metrics_utils.py index 55ffd17c..69840702 100644 --- a/snowflake/ml/modeling/metrics/metrics_utils.py +++ b/snowflake/ml/modeling/metrics/metrics_utils.py @@ -59,7 +59,7 @@ def end_partition(self) -> Iterable[Tuple[bytes]]: ] ), input_types=[T.BinaryType()], - packages=["numpy", "cloudpickle"], + packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"], name=accumulator, is_permanent=False, replace=True, @@ -174,7 +174,7 @@ def accumulate_batch_sum_and_dot_prod(self) -> None: ] ), input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()], - packages=["numpy", "cloudpickle"], + packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"], name=sharded_dot_and_sum_computer, is_permanent=False, replace=True, diff --git a/snowflake/ml/modeling/metrics/ranking.py b/snowflake/ml/modeling/metrics/ranking.py index a820885d..540a94da 100644 --- a/snowflake/ml/modeling/metrics/ranking.py +++ b/snowflake/ml/modeling/metrics/ranking.py @@ -102,7 +102,6 @@ def precision_recall_curve( ], statement_params=statement_params, anonymous=True, - execute_as="caller", ) def precision_recall_curve_anon_sproc(session: snowpark.Session) -> bytes: for query in queries[:-1]: @@ -250,7 +249,6 @@ class scores must correspond to the order of ``labels``, ], statement_params=statement_params, anonymous=True, - execute_as="caller", ) def roc_auc_score_anon_sproc(session: snowpark.Session) -> bytes: for query in queries[:-1]: @@ -354,7 
+352,6 @@ def roc_curve( ], statement_params=statement_params, anonymous=True, - execute_as="caller", ) def roc_curve_anon_sproc(session: snowpark.Session) -> bytes: for query in queries[:-1]: diff --git a/snowflake/ml/modeling/metrics/regression.py b/snowflake/ml/modeling/metrics/regression.py index 91bfc7e5..28d80b5b 100644 --- a/snowflake/ml/modeling/metrics/regression.py +++ b/snowflake/ml/modeling/metrics/regression.py @@ -87,7 +87,6 @@ def d2_absolute_error_score( ], statement_params=statement_params, anonymous=True, - execute_as="caller", ) def d2_absolute_error_score_anon_sproc(session: snowpark.Session) -> bytes: for query in queries[:-1]: @@ -185,7 +184,6 @@ def d2_pinball_score( ], statement_params=statement_params, anonymous=True, - execute_as="caller", ) def d2_pinball_score_anon_sproc(session: snowpark.Session) -> bytes: for query in queries[:-1]: @@ -301,7 +299,6 @@ def explained_variance_score( ], statement_params=statement_params, anonymous=True, - execute_as="caller", ) def explained_variance_score_anon_sproc(session: snowpark.Session) -> bytes: for query in queries[:-1]: diff --git a/snowflake/ml/modeling/pipeline/pipeline.py b/snowflake/ml/modeling/pipeline/pipeline.py index 039b5e60..27aa3a9d 100644 --- a/snowflake/ml/modeling/pipeline/pipeline.py +++ b/snowflake/ml/modeling/pipeline/pipeline.py @@ -379,7 +379,6 @@ def pipeline_within_one_sproc( anonymous=True, imports=imports, # type: ignore[arg-type] statement_params=sproc_statement_params, - execute_as="caller", ) sproc_export_file_name: str = pipeline_within_one_sproc( diff --git a/snowflake/ml/monitoring/BUILD.bazel b/snowflake/ml/monitoring/BUILD.bazel index 1a9e8e18..dbd05879 100644 --- a/snowflake/ml/monitoring/BUILD.bazel +++ b/snowflake/ml/monitoring/BUILD.bazel @@ -1,4 +1,6 @@ -load("//bazel:py_rules.bzl", "py_library", "py_package", "py_wheel") +load("//bazel:py_rules.bzl", "py_library", "py_package") + +package(default_visibility = ["//visibility:public"]) package_group( name = 
"monitoring", @@ -7,16 +9,9 @@ package_group( ], ) -package(default_visibility = ["//visibility:public"]) - -exports_files([ - "pyproject.toml", -]) - py_library( - name = "monitoring_lib", + name = "shap_lib", srcs = [ - "monitor.py", "shap.py", ], deps = [ @@ -24,16 +19,19 @@ py_library( ], ) +py_library( + name = "model_monitor_impl", + deps = [ + "//snowflake/ml/monitoring/_client:model_monitor_lib", + "//snowflake/ml/monitoring/entities:entities_lib", + ], +) + py_package( name = "monitoring_pkg", packages = ["snowflake.ml"], deps = [ - ":monitoring_lib", + ":model_monitor_impl", + ":shap_lib", ], ) - -py_wheel( - name = "wheel", - pyproject_toml = ":pyproject.toml", - deps = ["//snowflake/ml/monitoring:monitoring_pkg"], -) diff --git a/snowflake/ml/monitoring/_client/BUILD.bazel b/snowflake/ml/monitoring/_client/BUILD.bazel new file mode 100644 index 00000000..c0a4b344 --- /dev/null +++ b/snowflake/ml/monitoring/_client/BUILD.bazel @@ -0,0 +1,82 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = [ + "//bazel:snowml_public_common", + "//snowflake/ml/monitoring", +]) + +filegroup( + name = "queries", + srcs = glob([ + "queries/*.sql", + "queries/*.ssql", + ]), +) + +# TODO(jfishbein): Move this to //snowflake/ml/model/_client/ops/ or somewhere similar +py_library( + name = "monitor_sql", + srcs = [ + "monitor_sql_client.py", + ], + data = [":queries"], + deps = [ + "//snowflake/ml/_internal/utils:db_utils", + "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/_internal/utils:table_manager", + "//snowflake/ml/dataset", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/monitoring/entities:entities_lib", + ], +) + +# TODO(jfishbein): Move this to //snowflake/ml/monitoring/_manager/ or somewhere similar +py_library( + name = "model_monitor_lib", + srcs = [ + "model_monitor.py", + "model_monitor_manager.py", + 
"model_monitor_version.py", + ], + deps = [ + ":monitor_sql", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/utils:db_utils", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/monitoring/entities:entities_lib", + ], +) + +py_test( + name = "monitor_sql_client_test", + srcs = [ + "monitor_sql_client_test.py", + ], + deps = [ + ":model_monitor_lib", + "//snowflake/ml/test_utils:mock_session", + ], +) + +py_test( + name = "model_monitor_manager_test", + srcs = [ + "model_monitor_manager_test.py", + ], + deps = [ + ":model_monitor_lib", + "//snowflake/ml/test_utils:mock_session", + ], +) + +py_test( + name = "model_monitor_test", + srcs = [ + "model_monitor_test.py", + ], + deps = [ + ":model_monitor_lib", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/monitoring/_client/model_monitor.py b/snowflake/ml/monitoring/_client/model_monitor.py new file mode 100644 index 00000000..1f8d49a0 --- /dev/null +++ b/snowflake/ml/monitoring/_client/model_monitor.py @@ -0,0 +1,126 @@ +from typing import List, Union + +import pandas as pd + +from snowflake import snowpark +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.monitoring._client import monitor_sql_client + + +class ModelMonitor: + """Class to manage instrumentation of Model Monitoring and Observability""" + + name: sql_identifier.SqlIdentifier + _model_monitor_client: monitor_sql_client._ModelMonitorSQLClient + _fully_qualified_model_name: str + _version_name: sql_identifier.SqlIdentifier + _function_name: sql_identifier.SqlIdentifier + _prediction_columns: List[sql_identifier.SqlIdentifier] + _label_columns: List[sql_identifier.SqlIdentifier] + + def __init__(self) -> None: + raise RuntimeError("ModelMonitor's initializer is not meant to be used.") + + @classmethod + def _ref( + cls, + model_monitor_client: monitor_sql_client._ModelMonitorSQLClient, + name: 
sql_identifier.SqlIdentifier, + *, + fully_qualified_model_name: str, + version_name: sql_identifier.SqlIdentifier, + function_name: sql_identifier.SqlIdentifier, + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], + ) -> "ModelMonitor": + self: "ModelMonitor" = object.__new__(cls) + self.name = name + self._model_monitor_client = model_monitor_client + self._fully_qualified_model_name = fully_qualified_model_name + self._version_name = version_name + self._function_name = function_name + self._prediction_columns = prediction_columns + self._label_columns = label_columns + return self + + @telemetry.send_api_usage_telemetry( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + def set_baseline(self, baseline_df: Union[pd.DataFrame, snowpark.DataFrame]) -> None: + """ + The baseline dataframe is compared with the monitored data once monitoring is enabled. + The columns of the dataframe should match the columns of the source table that the + ModelMonitor was configured with. Calling this method overwrites any existing baseline split data. + + Args: + baseline_df: Snowpark dataframe containing baseline data. 
+ + Raises: + ValueError: baseline_df does not contain prediction or label columns + """ + statement_params = telemetry.get_statement_params( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + + if isinstance(baseline_df, pd.DataFrame): + baseline_df = self._model_monitor_client._sql_client._session.create_dataframe(baseline_df) + + column_names_identifiers: List[sql_identifier.SqlIdentifier] = [ + sql_identifier.SqlIdentifier(column_name) for column_name in baseline_df.columns + ] + prediction_cols_not_found = any( + [prediction_col not in column_names_identifiers for prediction_col in self._prediction_columns] + ) + label_cols_not_found = any( + [label_col.identifier() not in column_names_identifiers for label_col in self._label_columns] + ) + + if prediction_cols_not_found: + raise ValueError( + "Specified prediction columns were not found in the baseline dataframe. " + f"Columns provided were: {column_names_identifiers}. " + f"Configured prediction columns were: {self._prediction_columns}." + ) + if label_cols_not_found: + raise ValueError( + "Specified label columns were not found in the baseline dataframe." + f"Columns provided in the baseline dataframe were: {column_names_identifiers}." + f"Configured label columns were: {self._label_columns}." 
+ ) + + # Create the table by materializing the df + self._model_monitor_client.materialize_baseline_dataframe( + baseline_df, + self._fully_qualified_model_name, + self._version_name, + statement_params=statement_params, + ) + + def suspend(self) -> None: + """Suspend pipeline for ModelMonitor""" + statement_params = telemetry.get_statement_params( + telemetry.TelemetryProject.MLOPS.value, + telemetry.TelemetrySubProject.MONITORING.value, + ) + _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name) + self._model_monitor_client.suspend_monitor_dynamic_tables( + model_name=model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + + def resume(self) -> None: + """Resume pipeline for ModelMonitor""" + statement_params = telemetry.get_statement_params( + telemetry.TelemetryProject.MLOPS.value, + telemetry.TelemetrySubProject.MONITORING.value, + ) + _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name) + self._model_monitor_client.resume_monitor_dynamic_tables( + model_name=model_name, + version_name=self._version_name, + statement_params=statement_params, + ) diff --git a/snowflake/ml/monitoring/_client/model_monitor_manager.py b/snowflake/ml/monitoring/_client/model_monitor_manager.py new file mode 100644 index 00000000..8877b095 --- /dev/null +++ b/snowflake/ml/monitoring/_client/model_monitor_manager.py @@ -0,0 +1,361 @@ +from typing import Any, Dict, List, Optional + +from snowflake import snowpark +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import db_utils, sql_identifier +from snowflake.ml.model import type_hints +from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.ml.monitoring._client import model_monitor, monitor_sql_client +from snowflake.ml.monitoring.entities import ( + model_monitor_config, + 
model_monitor_interval, +) +from snowflake.snowpark import session + + +def _validate_name_constraints(model_version: model_version_impl.ModelVersion) -> None: + system_table_prefixes = [ + monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX, + monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX, + ] + + max_allowed_model_name_and_version_length = ( + db_utils.MAX_IDENTIFIER_LENGTH - max(len(prefix) for prefix in system_table_prefixes) - 1 + ) # -1 includes '_' between model_name + model_version + if len(model_version.model_name) + len(model_version.version_name) > max_allowed_model_name_and_version_length: + error_msg = f"Model name and version name exceeds maximum length of {max_allowed_model_name_and_version_length}" + raise ValueError(error_msg) + + +class ModelMonitorManager: + """Class to manage internal operations for Model Monitor workflows.""" # TODO: Move to Registry. + + @staticmethod + def setup(session: session.Session, database_name: str, schema_name: str) -> None: + """Static method to set up schema for Model Monitoring resources. + + Args: + session: The Snowpark Session to connect with Snowflake. + database_name: The name of the database. If None, the current database of the session + will be used. Defaults to None. + schema_name: The name of the schema. If None, the current schema of the session + will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None. 
def _fetch_task_from_model_version(
    self,
    model_version: model_version_impl.ModelVersion,
) -> type_hints.Task:
    """Return the task recorded on a model version, rejecting an unknown task.

    Args:
        model_version: The model version whose ``get_model_task()`` is queried.

    Returns:
        The task reported by the model version.

    Raises:
        ValueError: If the reported task equals ``type_hints.Task.UNKNOWN``.
    """
    model_task = model_version.get_model_task()
    task_is_unknown = model_task == type_hints.Task.UNKNOWN
    if task_is_unknown:
        raise ValueError("Registry model must be logged with task in order to be monitored.")
    return model_task
+ """ + self._database_name = database_name + self._schema_name = schema_name + self.statement_params = statement_params + self._model_monitor_client = monitor_sql_client._ModelMonitorSQLClient( + session, + database_name=self._database_name, + schema_name=self._schema_name, + ) + if create_if_not_exists: + monitor_sql_client._ModelMonitorSQLClient.initialize_monitoring_schema( + session, self._database_name, self._schema_name, self.statement_params + ) + elif not self._model_monitor_client._validate_is_initialized(): + raise ValueError( + "Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup" + ) + + def _get_and_validate_model_function_from_model_version( + self, function: str, model_version: model_version_impl.ModelVersion + ) -> model_manifest_schema.ModelFunctionInfo: + functions = model_version.show_functions() + for f in functions: + if f["target_method"] == function: + return f + existing_target_methods = {f["target_method"] for f in functions} + raise ValueError( + f"Function with name {function} does not exist in the given model version. " + f"Found: {existing_target_methods}." + ) + + def _validate_monitor_config_or_raise( + self, + table_config: model_monitor_config.ModelMonitorTableConfig, + model_monitor_config: model_monitor_config.ModelMonitorConfig, + ) -> None: + """Validate provided config for model monitor. + + Args: + table_config: Config for model monitor tables. + model_monitor_config: Config for ModelMonitor. + + Raises: + ValueError: If warehouse provided does not exist. + """ + + # Validate naming will not exceed 255 chars + _validate_name_constraints(model_monitor_config.model_version) + + if len(table_config.prediction_columns) != len(table_config.label_columns): + raise ValueError("Prediction and Label column names must be of the same length.") + # output and ground cols are list to keep interface extensible. 
def add_monitor(
    self,
    name: str,
    table_config: model_monitor_config.ModelMonitorTableConfig,
    model_monitor_config: model_monitor_config.ModelMonitorConfig,
    *,
    add_dashboard_udtfs: bool = False,
) -> model_monitor.ModelMonitor:
    """Add a new Model Monitor.

    Steps run in this order: validate the configs, resolve the model function,
    validate the source table, fetch the model task and score type, insert the
    monitor metadata, create the dynamic tables, initialize the baseline table,
    and optionally add dashboard UDTFs. Nothing in this method undoes earlier
    steps if a later one raises.

    Args:
        name: Name of Model Monitor to create.
        table_config: Configuration options for the source table used in ModelMonitor.
        model_monitor_config: Configuration options of ModelMonitor.
        add_dashboard_udtfs: Keyword-only. Add UDTFs useful for creating a dashboard.

    Returns:
        The newly added ModelMonitor object.

    Raises:
        ValueError: If config validation fails, the requested model function is not
            found on the model version, or the model version's task is UNKNOWN.
    """
    # Validates configuration or raise.
    self._validate_monitor_config_or_raise(table_config, model_monitor_config)
    # Resolve the function entry up front; it is passed to source-table validation below.
    model_function = self._get_and_validate_model_function_from_model_version(
        model_monitor_config.model_function_name, model_monitor_config.model_version
    )
    monitor_refresh_interval = model_monitor_interval.ModelMonitorRefreshInterval(
        model_monitor_config.refresh_interval
    )
    # Wrap the user-supplied names as SqlIdentifier before handing them to the SQL client.
    name_id = sql_identifier.SqlIdentifier(name)
    source_table_name_id = sql_identifier.SqlIdentifier(table_config.source_table)
    prediction_columns = [
        sql_identifier.SqlIdentifier(column_name) for column_name in table_config.prediction_columns
    ]
    label_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.label_columns]
    id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.id_columns]
    ts_column = sql_identifier.SqlIdentifier(table_config.timestamp_column)

    # Validate source table
    self._model_monitor_client.validate_source_table(
        source_table_name=source_table_name_id,
        timestamp_column=ts_column,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
        id_columns=id_columns,
        model_function=model_function,
    )

    # The task is fetched only after the source table has been validated; an UNKNOWN task raises here.
    task = self._fetch_task_from_model_version(model_version=model_monitor_config.model_version)
    score_type = self._model_monitor_client.get_score_type(task, source_table_name_id, prediction_columns)

    # Insert monitoring metadata for new model version.
    self._model_monitor_client.create_monitor_on_model_version(
        monitor_name=name_id,
        source_table_name=source_table_name_id,
        fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
        version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
        function_name=model_monitor_config.model_function_name,
        timestamp_column=ts_column,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
        id_columns=id_columns,
        task=task,
        statement_params=self.statement_params,
    )

    # Create Dynamic tables for model monitor.
    # NOTE(review): no statement_params are forwarded on this call, unlike the metadata insert above.
    self._model_monitor_client.create_dynamic_tables_for_monitor(
        model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
        model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
        task=task,
        source_table_name=source_table_name_id,
        refresh_interval=monitor_refresh_interval,
        aggregation_window=model_monitor_config.aggregation_window,
        warehouse_name=sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name),
        timestamp_column=sql_identifier.SqlIdentifier(table_config.timestamp_column),
        id_columns=id_columns,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
        score_type=score_type,
    )

    # Initialize baseline table.
    # NOTE(review): the raw table_config.source_table string is passed here, not source_table_name_id.
    # The timestamp and id columns are listed as columns to drop.
    self._model_monitor_client.initialize_baseline_table(
        model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
        version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
        source_table_name=table_config.source_table,
        columns_to_drop=[ts_column, *id_columns],
        statement_params=self.statement_params,
    )

    # Add udtfs helpful for dashboard queries.
    # TODO(apgupta) Make this true by default.
    if add_dashboard_udtfs:
        self._model_monitor_client.add_dashboard_udtfs(
            name_id,
            model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
            model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
            task=task,
            score_type=score_type,
            output_columns=prediction_columns,
            ground_truth_columns=label_columns,
        )

    # Build the ModelMonitor handle from the same client and identifiers used above.
    return model_monitor.ModelMonitor._ref(
        model_monitor_client=self._model_monitor_client,
        name=name_id,
        fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
        version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
        function_name=sql_identifier.SqlIdentifier(model_monitor_config.model_function_name),
        prediction_columns=prediction_columns,
        label_columns=label_columns,
    )
def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
    """Fetch an existing Model Monitor by name.

    Args:
        name: Name of Model Monitor to retrieve.

    Raises:
        ValueError: If the SQL client reports that no monitor with this name exists.

    Returns:
        A ModelMonitor handle built from the stored monitor parameters.
    """
    monitor_id = sql_identifier.SqlIdentifier(name)
    client = self._model_monitor_client

    monitor_exists = client.validate_existence_by_name(
        monitor_name=monitor_id,
        statement_params=self.statement_params,
    )
    if not monitor_exists:
        raise ValueError(f"Unable to find model monitor '{name}'")

    stored_params: monitor_sql_client._ModelMonitorParams = client.get_model_monitor_by_name(
        monitor_id, statement_params=self.statement_params
    )

    # Version and function names are re-wrapped as identifiers; the column lists are passed through as stored.
    return model_monitor.ModelMonitor._ref(
        model_monitor_client=client,
        name=monitor_id,
        fully_qualified_model_name=stored_params["fully_qualified_model_name"],
        version_name=sql_identifier.SqlIdentifier(stored_params["version_name"]),
        function_name=sql_identifier.SqlIdentifier(stored_params["function_name"]),
        prediction_columns=stored_params["prediction_columns"],
        label_columns=stored_params["label_columns"],
    )
def _build_mock_model_version(
    fq_model_name: str,
    model_version_name: str,
    task: type_hints.Task = type_hints.Task.TABULAR_REGRESSION,
) -> mock.MagicMock:
    """Create a MagicMock standing in for a registry model version.

    The mock carries the given fully qualified name and version name, the bare
    model name parsed out of the fully qualified name, a ``get_model_task()``
    that returns ``task``, and a ``show_functions()`` that lists a single
    non-partitioned ``predict`` function with an empty signature.
    """
    fake_version = mock.MagicMock()
    fake_version.fully_qualified_model_name = fq_model_name
    fake_version.version_name = model_version_name

    # Only the last component of the fully qualified name is kept as the model name.
    _db, _schema, bare_model_name = sql_identifier.parse_fully_qualified_name(fq_model_name)
    fake_version.model_name = bare_model_name
    fake_version.get_model_task.return_value = task

    predict_function = model_manifest_schema.ModelFunctionInfo(
        name="PREDICT",
        target_method="predict",
        target_method_function_type="FUNCTION",
        signature=model_signature.ModelSignature(inputs=[], outputs=[]),
        is_partitioned=False,
    )
    fake_version.show_functions.return_value = [predict_function]
    return fake_version
def test_fetch_task(self) -> None:
    """A model version whose task is UNKNOWN is rejected when the task is fetched."""
    unknown_task_version = _build_mock_model_version(
        self.test_fq_model_name,
        self.test_version_name,
        task=type_hints.Task.UNKNOWN,
    )
    with self.assertRaisesRegex(
        ValueError,
        "Registry model must be logged with task in order to be monitored.",
    ):
        self.mm._fetch_task_from_model_version(unknown_task_version)
def test_get_monitor_by_model_version(self) -> None:
    """Looking up a monitor by model version queries the client and wraps its result."""
    sql_client = self.mock_model_monitor_sql_client
    stored_params = monitor_sql_client._ModelMonitorParams(
        monitor_name="TEST_MONITOR_NAME",
        fully_qualified_model_name=self.test_fq_model_name,
        version_name=self.test_model_version,
        function_name="PREDICT",
        prediction_columns=[],
        label_columns=[],
    )
    sql_client.validate_existence.return_value = True
    sql_client.get_model_monitor_by_model_version.return_value = stored_params

    fetched_monitor = self.mm.get_monitor_by_model_version(self.mv)

    # The existence check receives the fully qualified name, the version and no statement params.
    sql_client.validate_existence.assert_called_once_with(
        self.test_fq_model_name, self.test_model_version, None
    )
    # The parameter lookup receives the fully qualified name split into its three parts.
    sql_client.get_model_monitor_by_model_version.assert_called_once_with(
        model_db=self.test_db,
        model_schema=self.test_schema,
        model_name=self.test_model,
        version_name=self.test_model_version,
        statement_params=None,
    )
    self.assertEqual(fetched_monitor.name, "TEST_MONITOR_NAME")
    self.assertEqual(fetched_monitor._function_name, "PREDICT")
def setUp(self) -> None:
    """Build a ModelMonitorManager over a mock session, then swap in a mock SQL client."""
    self.m_session = mock_session.MockSession(conn=None, test_case=self)

    # Names shared by the tests in this class.
    self.test_warehouse = "TEST_WAREHOUSE"
    self.test_db = sql_identifier.SqlIdentifier("TEST_DB")
    self.test_schema = sql_identifier.SqlIdentifier("TEST_SCHEMA")
    self.test_model_version = "TEST_VERSION"
    self.test_model = "TEST_MODEL"
    self.test_fq_model_name = f"db1.schema1.{self.test_model}"
    self.test_source_table_name = "TEST_TABLE"

    self.mv = _build_mock_model_version(self.test_fq_model_name, self.test_model_version)

    self.test_table_config = model_monitor_config.ModelMonitorTableConfig(
        prediction_columns=["PREDICTION"],
        label_columns=["LABEL"],
        id_columns=["ID"],
        timestamp_column="TS",
        source_table=self.test_source_table_name,
    )
    self.test_monitor_config = model_monitor_config.ModelMonitorConfig(
        model_version=self.mv,
        model_function_name="predict",
        background_compute_warehouse_name=self.test_warehouse,
    )

    # The manager is built without create_if_not_exists, so the metadata-table lookup
    # is registered with a non-empty result before construction.
    metadata_table_row = Row(name="_SYSTEM_MONITORING_METADATA")
    self.m_session.add_mock_sql(
        query=f"""SHOW TABLES LIKE '_SYSTEM_MONITORING_METADATA' IN {self.test_db}.{self.test_schema}""",
        result=mock_data_frame.MockDataFrame([metadata_table_row]),
    )
    self.mm = model_monitor_manager.ModelMonitorManager(
        cast(Session, self.m_session), database_name=self.test_db, schema_name=self.test_schema
    )
    # Replace the real SQL client so individual tests can patch its methods.
    self.mm._model_monitor_client = mock.MagicMock()
def test_init_fails_not_initialized(self) -> None:
    """Constructing the manager fails when the metadata-table lookup returns no rows."""
    # An empty result for the SHOW TABLES lookup stands for a schema that was never set up.
    self.m_session.add_mock_sql(
        query=f"""SHOW TABLES LIKE '_SYSTEM_MONITORING_METADATA' IN {self.test_db}.{self.test_schema}""",
        result=mock_data_frame.MockDataFrame([]),
    )
    snowpark_session = cast(Session, self.m_session)

    with self.assertRaisesRegex(
        ValueError,
        "Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup",
    ):
        model_monitor_manager.ModelMonitorManager(
            snowpark_session,
            database_name=self.test_db,
            schema_name=self.test_schema,
            create_if_not_exists=False,
        )
label_columns=["LABEL"], + id_columns=["ID"], + task=type_hints.Task.TABULAR_REGRESSION, + statement_params=None, + ) + mock_create_dynamic_tables_for_monitor.assert_called_once_with( + model_name="TEST_MODEL", + model_version_name="TEST_VERSION", + task=type_hints.Task.TABULAR_REGRESSION, + source_table_name=self.test_source_table_name, + refresh_interval=model_monitor_interval.ModelMonitorRefreshInterval("1 days"), + aggregation_window=model_monitor_interval.ModelMonitorAggregationWindow.WINDOW_1_DAY, + warehouse_name="TEST_WAREHOUSE", + timestamp_column="TS", + id_columns=["ID"], + prediction_columns=["PREDICTION"], + label_columns=["LABEL"], + score_type=output_score_type.OutputScoreType.REGRESSION, + ) + mock_initialize_baseline_table.assert_called_once_with( + model_name="TEST_MODEL", + version_name="TEST_VERSION", + source_table_name=self.test_source_table_name, + columns_to_drop=[self.test_table_config.timestamp_column, *self.test_table_config.id_columns], + statement_params=None, + ) + + def test_add_monitor_fails_no_task(self) -> None: + with mock.patch.object( + self.mm._model_monitor_client, "validate_source_table" + ) as mock_validate_source_table, mock.patch.object( + self.mv, "get_model_task", return_value=type_hints.Task.UNKNOWN + ): + with self.assertRaisesRegex( + ValueError, "Registry model must be logged with task in order to be monitored." 
+ ): + self.mm.add_monitor("TEST", self.test_table_config, self.test_monitor_config) + mock_validate_source_table.assert_called_once() + + def test_add_monitor_fails_multiple_predictions(self) -> None: + bad_table_config = model_monitor_config.ModelMonitorTableConfig( + source_table=self.test_source_table_name, + prediction_columns=["PREDICTION1", "PREDICTION2"], + label_columns=["LABEL1", "LABEL2"], + id_columns=["ID"], + timestamp_column="TIMESTAMP", + ) + expected_error = "Multiple Output columns are not supported in monitoring" + with self.assertRaisesRegex(ValueError, expected_error): + self.mm.add_monitor("test", bad_table_config, self.test_monitor_config) + self.m_session.finalize() + + def test_add_monitor_fails_column_lengths_do_not_match(self) -> None: + bad_table_config = model_monitor_config.ModelMonitorTableConfig( + source_table=self.test_source_table_name, + prediction_columns=["PREDICTION"], + label_columns=["LABEL1", "LABEL2"], + id_columns=["ID"], + timestamp_column="TIMESTAMP", + ) + expected_msg = "Prediction and Label column names must be of the same length." 
def test_delete_monitor(self) -> None:
    """Deleting a monitor looks it up by name, then removes metadata, baseline and dynamic tables."""
    monitor = "TEST"
    model = "TEST"
    version = "V1"
    stored_params = monitor_sql_client._ModelMonitorParams(
        monitor_name=monitor,
        fully_qualified_model_name=f"TEST_DB.TEST_SCHEMA.{model}",
        version_name=version,
        function_name="predict",
        prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")],
        label_columns=[sql_identifier.SqlIdentifier("LABEL")],
    )
    client = self.mm._model_monitor_client

    with mock.patch.object(client, "get_model_monitor_by_name", return_value=stored_params) as lookup_by_name:
        with mock.patch.object(client, "delete_monitor_metadata") as drop_metadata:
            with mock.patch.object(client, "delete_baseline_table") as drop_baseline:
                with mock.patch.object(client, "delete_dynamic_tables") as drop_dynamic_tables:
                    self.mm.delete_monitor(monitor)
                    lookup_by_name.assert_called_once_with(monitor)
                    drop_metadata.assert_called_once_with(sql_identifier.SqlIdentifier(monitor))
                    drop_baseline.assert_called_once_with(model, version)
                    drop_dynamic_tables.assert_called_once_with(model, version)
def test_set_baseline(self) -> None:
    """set_baseline forwards a Snowpark dataframe to the client's materialize call."""
    baseline_rows = [
        Row(ID=1, TIMESTAMP=1, PREDICTION=0.5, LABEL=1),
        Row(ID=2, TIMESTAMP=2, PREDICTION=0.6, LABEL=0),
    ]
    baseline_df = mock_data_frame.MockDataFrame(
        baseline_rows,
        columns=["ID", "TIMESTAMP", "PREDICTION", "LABEL"],
    )
    with mock.patch.object(self.monitor_sql_client, "materialize_baseline_dataframe") as materialize_call:
        self.model_monitor.set_baseline(cast(DataFrame, baseline_df))
        # Statement params come from telemetry, so any value is accepted.
        materialize_call.assert_called_once_with(
            baseline_df,
            self.test_fq_model_name,
            self.test_model_version_name,
            statement_params=mock.ANY,
        )
test_set_baseline_pandas_df(self) -> None: + # Initialize a test pandas dataframe + pandas_baseline_df = pd.DataFrame( + { + "ID": [1, 2], + "TIMESTAMP": [1, 2], + "PREDICTION": [0.5, 0.6], + "LABEL": [1, 0], + } + ) + snowflake_baseline_df = mock_data_frame.MockDataFrame( + [ + Row( + ID=1, + TIMESTAMP=1, + PREDICTION=0.5, + LABEL=1, + ), + Row( + ID=2, + TIMESTAMP=2, + PREDICTION=0.6, + LABEL=0, + ), + ], + columns=[ + "ID", + "TIMESTAMP", + "PREDICTION", + "LABEL", + ], + ) + + with mock.patch.object( + self.monitor_sql_client, "materialize_baseline_dataframe" + ) as mock_materialize, mock.patch.object(self.monitor_sql_client._sql_client, "_session"), mock.patch.object( + self.monitor_sql_client._sql_client._session, "create_dataframe", return_value=snowflake_baseline_df + ) as mock_create_df: + self.model_monitor.set_baseline(pandas_baseline_df) + mock_materialize.assert_called_once_with( + snowflake_baseline_df, self.test_fq_model_name, self.test_model_version_name, statement_params=mock.ANY + ) + mock_create_df.assert_called_once_with(pandas_baseline_df) + + def test_set_baseline_missing_columns(self) -> None: + baseline_df = mock_data_frame.MockDataFrame( + [ + Row( + ID=1, + TIMESTAMP=1, + PREDICTION=0.5, + LABEL=1, + ), + Row( + ID=2, + TIMESTAMP=2, + PREDICTION=0.6, + LABEL=0, + ), + ], + columns=[ + "ID", + "TIMESTAMP", + "LABEL", + ], + ) + + expected_msg = "Specified prediction columns were not found in the baseline dataframe. 
def test_resume(self) -> None:
    """resume() asks the monitor client to resume the monitor's dynamic tables."""
    client = self.model_monitor._model_monitor_client
    with mock.patch.object(client, "resume_monitor_dynamic_tables") as mock_resume:
        self.model_monitor.resume()
        mock_resume.assert_called_once_with(
            model_name=self.test_model_name,
            version_name=self.test_model_version_name,
            statement_params=mock.ANY,
        )
+from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.ml.monitoring.entities import model_monitor_interval, output_score_type +from snowflake.ml.monitoring.entities.model_monitor_interval import ( + ModelMonitorAggregationWindow, + ModelMonitorRefreshInterval, +) +from snowflake.snowpark import DataFrame, exceptions, session, types +from snowflake.snowpark._internal import type_utils + +SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA" +_SNOWML_MONITORING_TABLE_NAME_PREFIX = "_SNOWML_OBS_MONITORING_" +_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX = "_SNOWML_OBS_ACCURACY_" + +MONITOR_NAME_COL_NAME = "MONITOR_NAME" +SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME" +FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME" +VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME" +FUNCTION_NAME_COL_NAME = "FUNCTION_NAME" +TASK_COL_NAME = "TASK" +MONITORING_ENABLED_COL_NAME = "IS_ENABLED" +TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME" +PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES" +LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES" +ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES" + +_DASHBOARD_UDTFS_COMMON_LIST = ["record_count"] +_DASHBOARD_UDTFS_REGRESSION_LIST = ["rmse"] + + +def _initialize_monitoring_metadata_tables( + session: session.Session, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, +) -> None: + """Create tables necessary for Model Monitoring in provided schema. + + Args: + session: Active Snowpark session. + database_name: The database in which to setup resources for Model Monitoring. + schema_name: The schema in which to setup resources for Model Monitoring. + statement_params: Optional statement params for queries. 
+ """ + table_manager.create_single_table( + session, + database_name, + schema_name, + SNOWML_MONITORING_METADATA_TABLE_NAME, + [ + (MONITOR_NAME_COL_NAME, "VARCHAR"), + (SOURCE_TABLE_NAME_COL_NAME, "VARCHAR"), + (FQ_MODEL_NAME_COL_NAME, "VARCHAR"), + (VERSION_NAME_COL_NAME, "VARCHAR"), + (FUNCTION_NAME_COL_NAME, "VARCHAR"), + (TASK_COL_NAME, "VARCHAR"), + (MONITORING_ENABLED_COL_NAME, "BOOLEAN"), + (TIMESTAMP_COL_NAME_COL_NAME, "VARCHAR"), + (PREDICTION_COL_NAMES_COL_NAME, "ARRAY"), + (LABEL_COL_NAMES_COL_NAME, "ARRAY"), + (ID_COL_NAMES_COL_NAME, "ARRAY"), + ], + statement_params=statement_params, + ) + + +def _create_baseline_table_name(model_name: str, version_name: str) -> str: + return f"_SNOWML_OBS_BASELINE_{model_name}_{version_name}" + + +def _infer_numeric_categoric_feature_column_names( + *, + source_table_schema: Mapping[str, types.DataType], + timestamp_column: sql_identifier.SqlIdentifier, + id_columns: List[sql_identifier.SqlIdentifier], + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], +) -> Tuple[List[sql_identifier.SqlIdentifier], List[sql_identifier.SqlIdentifier]]: + cols_to_remove = {timestamp_column, *id_columns, *prediction_columns, *label_columns} + cols_to_consider = [ + (col_name, source_table_schema[col_name]) for col_name in source_table_schema if col_name not in cols_to_remove + ] + numeric_cols = [ + sql_identifier.SqlIdentifier(column[0]) + for column in cols_to_consider + if isinstance(column[1], types._NumericType) + ] + categorical_cols = [ + sql_identifier.SqlIdentifier(column[0]) + for column in cols_to_consider + if isinstance(column[1], types.StringType) or isinstance(column[1], types.BooleanType) + ] + return (numeric_cols, categorical_cols) + + +class _ModelMonitorParams(TypedDict): + """Class to transfer model monitor parameters to the ModelMonitor class.""" + + monitor_name: Required[str] + fully_qualified_model_name: Required[str] + version_name: Required[str] + 
def __init__(
    self,
    session: session.Session,
    *,
    database_name: sql_identifier.SqlIdentifier,
    schema_name: sql_identifier.SqlIdentifier,
) -> None:
    """Client to manage monitoring metadata persisted in the given database and schema.

    Args:
        session: Active snowpark session.
        database_name: Name of the Database where monitoring resources are provisioned.
        schema_name: Name of the Schema where monitoring resources are provisioned.
    """
    # The base SQL client owns the session; the identifiers are kept for building fully qualified names.
    self._sql_client = _base._BaseSQLClient(session, database_name=database_name, schema_name=schema_name)
    self._database_name, self._schema_name = database_name, schema_name
def _validate_unique_columns(
    self,
    timestamp_column: sql_identifier.SqlIdentifier,
    id_columns: List[sql_identifier.SqlIdentifier],
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
) -> None:
    """Ensures no column name is used for more than one role.

    Args:
        timestamp_column: Name of the timestamp column.
        id_columns: List of id column names.
        prediction_columns: List of prediction column names.
        label_columns: List of label column names.

    Raises:
        ValueError: If any name appears more than once across the four groups.
    """
    combined = list(id_columns) + list(prediction_columns) + list(label_columns) + [timestamp_column]
    # A set collapses duplicates, so equal sizes mean every name is distinct.
    if len(combined) == len(set(combined)):
        return
    raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.")
def validate_monitor_warehouse(
    self,
    warehouse_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Validate warehouse provided for monitoring exists.

    Args:
        warehouse_name: Warehouse name
        statement_params: Optional set of statement params to include in queries.

    Raises:
        ValueError: If warehouse does not exist.
    """
    warehouse_found = db_utils.db_object_exists(
        session=self._sql_client._session,
        object_type=db_utils.SnowflakeDbObjectType.WAREHOUSE,
        object_name=warehouse_name,
        statement_params=statement_params,
    )
    if warehouse_found:
        return
    raise ValueError(f"Warehouse '{warehouse_name}' not found.")
def _create_dashboard_udtf_queries(
    self,
    monitor_name: sql_identifier.SqlIdentifier,
    model_name: sql_identifier.SqlIdentifier,
    model_version_name: sql_identifier.SqlIdentifier,
    task: type_hints.Task,
    score_type: output_score_type.OutputScoreType,
    output_columns: List[sql_identifier.SqlIdentifier],
    ground_truth_columns: List[sql_identifier.SqlIdentifier],
) -> Mapping[str, str]:
    """Render the dashboard UDTF query templates for a monitor.

    Each template is read from ``queries/<name>.ssql`` inside the
    ``snowflake.ml.monitoring._client`` package and filled in with
    ``string.Template`` substitution.

    Args:
        monitor_name: Name of the model monitor.
        model_name: Name of the monitored model.
        model_version_name: Name of the monitored model version.
        task: Model task; regression models get the regression UDTFs in addition to the common ones.
        score_type: Output score type (currently not consulted when choosing templates).
        output_columns: Prediction columns; only the first one is substituted.
        ground_truth_columns: Label columns; only the first one is substituted.

    Returns:
        Mapping from UDTF template name to the rendered query text.
    """
    template_root = files("snowflake.ml.monitoring._client")
    # TODO(apgupta): Expand list of queries based on model objective and score type.
    udtf_names = list(_DASHBOARD_UDTFS_COMMON_LIST)
    if task == type_hints.Task.TABULAR_REGRESSION:
        udtf_names += _DASHBOARD_UDTFS_REGRESSION_LIST

    monitoring_table = self.get_monitoring_table_fully_qualified_name(model_name, model_version_name)
    accuracy_table = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)
    substitutions = {
        "MODEL_MONITOR_NAME": monitor_name,
        "MONITORING_TABLE": monitoring_table,
        "MONITORING_PRED_LABEL_JOINED_TABLE": accuracy_table,
        "OUTPUT_COLUMN_NAME": output_columns[0],
        "GROUND_TRUTH_COLUMN_NAME": ground_truth_columns[0],
    }

    return {
        udtf_name: string.Template(template_root.joinpath(f"queries/{udtf_name}.ssql").read_text()).substitute(
            substitutions
        )
        for udtf_name in udtf_names
    }
def _validate_timestamp_column_type(
    self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
) -> None:
    """Ensures the timestamp column has a timestamp type.

    Args:
        table_schema: Dictionary of column names and types in the source table.
        timestamp_column: Name of the timestamp column.

    Raises:
        ValueError: If the timestamp column is not of type TimestampType.
    """
    timestamp_type = table_schema[timestamp_column]
    if isinstance(timestamp_type, types.TimestampType):
        return
    raise ValueError(
        f"Timestamp column: {timestamp_column} must be TimestampType. " f"Found: {timestamp_type}"
    )
def _validate_label_columns_types(
    self,
    table_schema: Mapping[str, types.DataType],
    label_columns: List[sql_identifier.SqlIdentifier],
) -> None:
    """Ensures all label columns share a single type.

    Args:
        table_schema: Dictionary of column names and types in the source table.
        label_columns: List of label column names.

    Raises:
        ValueError: If the label columns do not share the same type.
    """
    distinct_types: Set[types.DataType] = set()
    for column_name in label_columns:
        distinct_types.add(table_schema[column_name])

    # Zero or one distinct type means the label columns agree.
    if len(distinct_types) <= 1:
        return
    raise ValueError(f"Label column types must be the same. Found: {distinct_types}")
def _validate_source_table_features_shape(
    self,
    table_schema: Mapping[str, types.DataType],
    special_columns: Set[sql_identifier.SqlIdentifier],
    model_function: model_manifest_schema.ModelFunctionInfo,
) -> None:
    """Ensures the source table's feature columns match the model function's input types.

    The comparison is by multiset of types: column names and order are ignored, only
    how many columns there are of each type.

    Args:
        table_schema: Dictionary of column names and types in the source table.
        special_columns: Columns to leave out of the comparison.
        model_function: Model function info whose signature inputs are compared against the table.

    Raises:
        ValueError: If the type counts of the table's remaining columns and the function inputs differ.
    """
    feature_columns = {
        name: dtype for name, dtype in table_schema.items() if sql_identifier.SqlIdentifier(name) not in special_columns
    }
    schema_type_counts: typing.Counter[types.DataType] = Counter(feature_columns.values())

    inputs = model_function["signature"].inputs
    function_type_counts: typing.Counter[types.DataType] = Counter(spec.as_snowpark_type() for spec in inputs)

    if function_type_counts == schema_type_counts:
        return
    raise ValueError(
        "Model function input types do not match the source table input columns types. "
        f"Model function expected: {inputs} but got {feature_columns}"
    )
def get_model_monitor_by_model_version(
    self,
    *,
    model_db: sql_identifier.SqlIdentifier,
    model_schema: sql_identifier.SqlIdentifier,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> _ModelMonitorParams:
    """Fetch metadata for a Model Monitor by model version.

    Args:
        model_db: Database of model.
        model_schema: Schema of model.
        model_name: Model name.
        version_name: Model version name
        statement_params: Optional set of statement_params to include with queries.

    Returns:
        _ModelMonitorParams dict with Name of monitor, fully qualified model name,
        model version name, model function name, prediction_col, label_col.

    Raises:
        ValueError: If multiple ModelMonitors exist for the same model version.
    """
    # The prediction/label column lists are read from the result row below, so they must be part of
    # the projection (and validated) just like the other columns.
    res = (
        query_result_checker.SqlResultValidator(
            self._sql_client._session,
            f"""SELECT {MONITOR_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
            {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME},
            {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}
            FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
            WHERE {FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{model_name}'
            AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
            statement_params=statement_params,
        )
        .has_column(MONITOR_NAME_COL_NAME)
        .has_column(FQ_MODEL_NAME_COL_NAME)
        .has_column(VERSION_NAME_COL_NAME)
        .has_column(FUNCTION_NAME_COL_NAME)
        .has_column(PREDICTION_COL_NAMES_COL_NAME)
        .has_column(LABEL_COL_NAMES_COL_NAME)
        .validate()
    )
    if len(res) > 1:
        raise ValueError(
            f"Invalid state. Multiple Monitors exist for model: '{model_name}' and version: '{version_name}'"
        )

    row = res[0]
    # The column-name lists are decoded with json.loads before being wrapped as identifiers.
    return _ModelMonitorParams(
        monitor_name=row[MONITOR_NAME_COL_NAME],
        fully_qualified_model_name=row[FQ_MODEL_NAME_COL_NAME],
        version_name=row[VERSION_NAME_COL_NAME],
        function_name=row[FUNCTION_NAME_COL_NAME],
        prediction_columns=[
            sql_identifier.SqlIdentifier(prediction_column)
            for prediction_column in json.loads(row[PREDICTION_COL_NAMES_COL_NAME])
        ],
        label_columns=[
            sql_identifier.SqlIdentifier(label_column)
            for label_column in json.loads(row[LABEL_COL_NAMES_COL_NAME])
        ],
    )
+ + Args: + task: Model task + source_table_name: Source data table containing model outputs. + prediction_columns: columns in source data table corresponding to model outputs. + + Returns: + OutputScoreType for model. + """ + table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types( + self._sql_client._session, + self._database_name, + self._schema_name, + source_table_name, + ) + return output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_columns, task) + + def validate_source_table( + self, + source_table_name: sql_identifier.SqlIdentifier, + timestamp_column: sql_identifier.SqlIdentifier, + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], + id_columns: List[sql_identifier.SqlIdentifier], + model_function: model_manifest_schema.ModelFunctionInfo, + ) -> None: + # Validate source table exists + if not table_manager.validate_table_exist( + self._sql_client._session, + source_table_name, + f"{self._database_name}.{self._schema_name}", + ): + raise ValueError( + f"Table {source_table_name} does not exist in schema {self._database_name}.{self._schema_name}." 
def delete_baseline_table(
    self,
    fully_qualified_model_name: str,
    version_name: str,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Delete the baseline table corresponding to a particular model and version.

    Args:
        fully_qualified_model_name: Fully qualified name of the model.
        version_name: Name of the model version.
        statement_params: Optional set of statement_params to include with query.
    """
    # initialize_baseline_table and materialize_baseline_dataframe name the baseline table after the
    # bare model name, so strip the database/schema qualifiers before rebuilding the table name.
    # Using the fully qualified name here would target a table name that was never created.
    _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
    table_name = _create_baseline_table_name(model_name, version_name)
    self._sql_client._session.sql(
        f"""DROP TABLE IF EXISTS {self._database_name}.{self._schema_name}.{table_name}"""
    ).collect(statement_params=statement_params)
Model Registry. Creates public schema for metadata. + + Args: + monitor_name: Name of monitor object to create. + source_table_name: Name of source data table to monitor. + fully_qualified_model_name: fully qualified name of model to monitor '..'. + version_name: model version name to monitor. + function_name: function_name to monitor in model version. + timestamp_column: timestamp column name. + prediction_columns: list of prediction column names. + label_columns: list of label column names. + id_columns: list of id column names. + task: Task of the model, e.g. TABULAR_REGRESSION. + statement_params: Optional dict of statement_params to include with queries. + + Raises: + ValueError: If model version is already monitored. + """ + # Validate monitor does not already exist on model version. + if self.validate_existence(fully_qualified_model_name, version_name, statement_params): + raise ValueError(f"Model {fully_qualified_model_name} Version {version_name} is already monitored!") + + if self.validate_existence_by_name(monitor_name, statement_params): + raise ValueError(f"Model Monitor with name '{monitor_name}' already exists!") + + prediction_columns_for_select = formatting.format_value_for_select(prediction_columns) + label_columns_for_select = formatting.format_value_for_select(label_columns) + id_columns_for_select = formatting.format_value_for_select(id_columns) + query_result_checker.SqlResultValidator( + self._sql_client._session, + textwrap.dedent( + f"""INSERT INTO {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME} + ({MONITOR_NAME_COL_NAME}, {SOURCE_TABLE_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME}, + {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}, {TASK_COL_NAME}, + {MONITORING_ENABLED_COL_NAME}, {TIMESTAMP_COL_NAME_COL_NAME}, + {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}, + {ID_COL_NAMES_COL_NAME}) + SELECT '{monitor_name}', '{source_table_name}', '{fully_qualified_model_name}', + '{version_name}', 
def initialize_baseline_table(
    self,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    source_table_name: str,
    columns_to_drop: Optional[List[sql_identifier.SqlIdentifier]] = None,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Initializes the baseline table for a Model Version, deriving its columns from the source table.

    Args:
        model_name: name of model to monitor.
        version_name: model version name to monitor.
        source_table_name: name of the user's table containing their model data.
        columns_to_drop: special columns in the source table to be excluded from baseline tables.
        statement_params: Optional dict of statement_params to include with queries.
    """
    source_schema = table_manager.get_table_schema_types(
        self._sql_client._session,
        database=self._database_name,
        schema=self._schema_name,
        table_name=source_table_name,
    )
    excluded_columns = [] if columns_to_drop is None else columns_to_drop

    baseline_table_name = _create_baseline_table_name(model_name, version_name)
    baseline_columns = []
    for column_name, column_type in source_schema.items():
        if sql_identifier.SqlIdentifier(column_name) in excluded_columns:
            continue
        # Snowpark types are converted to their Snowflake type form for the table definition.
        baseline_columns.append((column_name, type_utils.convert_sp_to_sf_type(column_type)))

    table_manager.create_single_table(
        self._sql_client._session,
        self._database_name,
        self._schema_name,
        baseline_table_name,
        baseline_columns,
        statement_params=statement_params,
    )
+ """ + return query_result_checker.SqlResultValidator( + self._sql_client._session, + f"""SELECT * + FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}""", + statement_params=statement_params, + ).validate() + + def materialize_baseline_dataframe( + self, + baseline_df: DataFrame, + fully_qualified_model_name: str, + model_version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Materialize baseline dataframe to a permanent snowflake table. This method + truncates (overwrite without dropping) any existing data in the baseline table. + + Args: + baseline_df: dataframe containing baseline data that monitored data will be compared against. + fully_qualified_model_name: name of the model. + model_version_name: model version name to monitor. + statement_params: Optional dict of statement_params to include with queries. + + Raises: + ValueError: If no baseline table was initialized. + """ + + _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name) + baseline_table_name = _create_baseline_table_name(model_name, model_version_name) + + baseline_table_exists = db_utils.db_object_exists( + self._sql_client._session, + db_utils.SnowflakeDbObjectType.TABLE, + sql_identifier.SqlIdentifier(baseline_table_name), + database_name=self._database_name, + schema_name=self._schema_name, + statement_params=statement_params, + ) + if not baseline_table_exists: + raise ValueError( + f"Baseline table '{baseline_table_name}' does not exist for model: " + f"'{model_name}' and model_version: '{model_version_name}'" + ) + + fully_qualified_baseline_table_name = [self._database_name, self._schema_name, baseline_table_name] + + try: + # Truncate overwrites by clearing the rows in the table, instead of dropping the table. + # This lets us keep the schema to validate the baseline_df against. 
+ baseline_df.write.mode("truncate").save_as_table( + fully_qualified_baseline_table_name, statement_params=statement_params + ) + except exceptions.SnowparkSQLException as e: + raise ValueError( + f"""Failed to save baseline dataframe. + Ensure that the baseline dataframe columns match those provided in your monitored table: {e}""" + ) + + def _alter_monitor_dynamic_tables( + self, + operation: str, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + if operation not in {"SUSPEND", "RESUME"}: + raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables") + fq_monitor_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, version_name) + query_result_checker.SqlResultValidator( + self._sql_client._session, + f"""ALTER DYNAMIC TABLE {fq_monitor_dt_name} {operation}""", + statement_params=statement_params, + ).has_column("status").has_dimensions(1, 1).validate() + + fq_accuracy_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, version_name) + query_result_checker.SqlResultValidator( + self._sql_client._session, + f"""ALTER DYNAMIC TABLE {fq_accuracy_dt_name} {operation}""", + statement_params=statement_params, + ).has_column("status").has_dimensions(1, 1).validate() + + def suspend_monitor_dynamic_tables( + self, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._alter_monitor_dynamic_tables( + operation="SUSPEND", + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + + def resume_monitor_dynamic_tables( + self, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._alter_monitor_dynamic_tables( + operation="RESUME", + 
model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + + def create_dynamic_tables_for_monitor( + self, + *, + model_name: sql_identifier.SqlIdentifier, + model_version_name: sql_identifier.SqlIdentifier, + task: type_hints.Task, + source_table_name: sql_identifier.SqlIdentifier, + refresh_interval: model_monitor_interval.ModelMonitorRefreshInterval, + aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow, + warehouse_name: sql_identifier.SqlIdentifier, + timestamp_column: sql_identifier.SqlIdentifier, + id_columns: List[sql_identifier.SqlIdentifier], + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], + score_type: output_score_type.OutputScoreType, + ) -> None: + table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types( + self._sql_client._session, + self._database_name, + self._schema_name, + source_table_name, + ) + (numeric_features_names, categorical_feature_names) = _infer_numeric_categoric_feature_column_names( + source_table_schema=table_schema, + timestamp_column=timestamp_column, + id_columns=id_columns, + prediction_columns=prediction_columns, + label_columns=label_columns, + ) + features_dynamic_table_query = self._monitoring_dynamic_table_query( + model_name=model_name, + model_version_name=model_version_name, + source_table_name=source_table_name, + refresh_interval=refresh_interval, + aggregate_window=aggregation_window, + warehouse_name=warehouse_name, + timestamp_column=timestamp_column, + numeric_features=numeric_features_names, + categoric_features=categorical_feature_names, + prediction_columns=prediction_columns, + label_columns=label_columns, + ) + query_result_checker.SqlResultValidator(self._sql_client._session, features_dynamic_table_query).has_column( + "status" + ).has_dimensions(1, 1).validate() + + label_pred_join_table_query = self._monitoring_accuracy_table_query( + model_name=model_name, + 
model_version_name=model_version_name, + task=task, + source_table_name=source_table_name, + refresh_interval=refresh_interval, + aggregate_window=aggregation_window, + warehouse_name=warehouse_name, + timestamp_column=timestamp_column, + prediction_columns=prediction_columns, + label_columns=label_columns, + score_type=score_type, + ) + query_result_checker.SqlResultValidator(self._sql_client._session, label_pred_join_table_query).has_column( + "status" + ).has_dimensions(1, 1).validate() + + def _monitoring_dynamic_table_query( + self, + *, + model_name: sql_identifier.SqlIdentifier, + model_version_name: sql_identifier.SqlIdentifier, + source_table_name: sql_identifier.SqlIdentifier, + refresh_interval: ModelMonitorRefreshInterval, + aggregate_window: ModelMonitorAggregationWindow, + warehouse_name: sql_identifier.SqlIdentifier, + timestamp_column: sql_identifier.SqlIdentifier, + numeric_features: List[sql_identifier.SqlIdentifier], + categoric_features: List[sql_identifier.SqlIdentifier], + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], + ) -> str: + """ + Generates a dynamic table query for Observability - Monitoring. + + Args: + model_name: Model name to monitor. + model_version_name: Model version name to monitor. + source_table_name: Name of source data table to monitor. + refresh_interval: Refresh interval in minutes. + aggregate_window: Aggregate window minutes. + warehouse_name: Warehouse name to use for dynamic table. + timestamp_column: Timestamp column name. + numeric_features: List of numeric features to capture. + categoric_features: List of categoric features to capture. + prediction_columns: List of columns that contain model inference outputs. + label_columns: List of columns that contain ground truth values. + + Raises: + ValueError: If multiple output/ground truth columns are specified. MultiClass models are not yet supported. + + Returns: + Dynamic table query. 
+ """ + # output and ground cols are list to keep interface extensible. + # for prpr only one label and one output col will be supported + if len(prediction_columns) != 1 or len(label_columns) != 1: + raise ValueError("Multiple Output columns are not supported in monitoring") + + monitoring_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, model_version_name) + + feature_cols_query_list = [] + for feature in numeric_features + prediction_columns + label_columns: + feature_cols_query_list.append( + """ + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE({col}), + 'count', count_if({col} is not null), + 'count_null', count_if({col} is null), + 'min', min({col}), + 'max', max({col}), + 'sum', sum({col}) + ) AS {col}""".format( + col=feature + ) + ) + + for col in categoric_features: + feature_cols_query_list.append( + f""" + {self._database_name}.{self._schema_name}.OBJECT_SUM(to_varchar({col})) AS {col}""" + ) + feature_cols_query = ",".join(feature_cols_query_list) + + return f""" + CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name} + TARGET_LAG = '{refresh_interval.minutes} minutes' + WAREHOUSE = {warehouse_name} + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS + SELECT + TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,{feature_cols_query} + FROM + {source_table_name} + GROUP BY + 1 + """ + + def _monitoring_accuracy_table_query( + self, + *, + model_name: sql_identifier.SqlIdentifier, + model_version_name: sql_identifier.SqlIdentifier, + task: type_hints.Task, + source_table_name: sql_identifier.SqlIdentifier, + refresh_interval: ModelMonitorRefreshInterval, + aggregate_window: ModelMonitorAggregationWindow, + warehouse_name: sql_identifier.SqlIdentifier, + timestamp_column: sql_identifier.SqlIdentifier, + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], + score_type: output_score_type.OutputScoreType, + ) -> str: + # output and ground 
cols are list to keep interface extensible. + # for prpr only one label and one output col will be supported + if len(prediction_columns) != 1 or len(label_columns) != 1: + raise ValueError("Multiple Output columns are not supported in monitoring") + if task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION: + return self._monitoring_classification_accuracy_table_query( + model_name=model_name, + model_version_name=model_version_name, + source_table_name=source_table_name, + refresh_interval=refresh_interval, + aggregate_window=aggregate_window, + warehouse_name=warehouse_name, + timestamp_column=timestamp_column, + prediction_columns=prediction_columns, + label_columns=label_columns, + score_type=score_type, + ) + else: + return self._monitoring_regression_accuracy_table_query( + model_name=model_name, + model_version_name=model_version_name, + source_table_name=source_table_name, + refresh_interval=refresh_interval, + aggregate_window=aggregate_window, + warehouse_name=warehouse_name, + timestamp_column=timestamp_column, + prediction_columns=prediction_columns, + label_columns=label_columns, + ) + + def _monitoring_regression_accuracy_table_query( + self, + *, + model_name: sql_identifier.SqlIdentifier, + model_version_name: sql_identifier.SqlIdentifier, + source_table_name: sql_identifier.SqlIdentifier, + refresh_interval: ModelMonitorRefreshInterval, + aggregate_window: ModelMonitorAggregationWindow, + warehouse_name: sql_identifier.SqlIdentifier, + timestamp_column: sql_identifier.SqlIdentifier, + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], + ) -> str: + """ + Generates a dynamic table query for Monitoring - regression model accuracy. + + Args: + model_name: Model name to monitor. + model_version_name: Model version name to monitor. + source_table_name: Name of source data table to monitor. + refresh_interval: Refresh interval in minutes. + aggregate_window: Aggregate window minutes. 
+ warehouse_name: Warehouse name to use for dynamic table. + timestamp_column: Timestamp column name. + prediction_columns: List of output columns. + label_columns: List of ground truth columns. + + Returns: + Dynamic table query. + + Raises: + ValueError: If output columns are not same as ground truth columns. + + """ + + if len(prediction_columns) != len(label_columns): + raise ValueError(f"Mismatch in output & ground truth columns: {prediction_columns} != {label_columns}") + + monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name) + + output_cols_query_list = [] + + output_cols_query_list.append( + f""" + OBJECT_CONSTRUCT( + 'sum_difference_label_pred', sum({prediction_columns[0]} - {label_columns[0]}), + 'sum_log_difference_square_label_pred', + sum( + case + when {prediction_columns[0]} > -1 and {label_columns[0]} > -1 + then pow(ln({prediction_columns[0]} + 1) - ln({label_columns[0]} + 1),2) + else null + END + ), + 'sum_difference_squares_label_pred', + sum( + pow( + {prediction_columns[0]} - {label_columns[0]}, + 2 + ) + ), + 'sum_absolute_regression_labels', sum(abs({label_columns[0]})), + 'sum_absolute_percentage_error', + sum( + abs( + div0null( + ({prediction_columns[0]} - {label_columns[0]}), + {label_columns[0]} + ) + ) + ), + 'sum_absolute_difference_label_pred', + sum( + abs({prediction_columns[0]} - {label_columns[0]}) + ), + 'sum_prediction', sum({prediction_columns[0]}), + 'sum_label', sum({label_columns[0]}), + 'count', count(*) + ) AS AGGREGATE_METRICS, + APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch, + APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch""" + ) + output_cols_query = ", ".join(output_cols_query_list) + + return f""" + CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name} + TARGET_LAG = '{refresh_interval.minutes} minutes' + WAREHOUSE = {warehouse_name} + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS + SELECT + 
TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp, + 'class_regression' label_class,{output_cols_query} + FROM + {source_table_name} + GROUP BY + 1 + """ + + def _monitoring_classification_accuracy_table_query( + self, + *, + model_name: sql_identifier.SqlIdentifier, + model_version_name: sql_identifier.SqlIdentifier, + source_table_name: sql_identifier.SqlIdentifier, + refresh_interval: ModelMonitorRefreshInterval, + aggregate_window: ModelMonitorAggregationWindow, + warehouse_name: sql_identifier.SqlIdentifier, + timestamp_column: sql_identifier.SqlIdentifier, + prediction_columns: List[sql_identifier.SqlIdentifier], + label_columns: List[sql_identifier.SqlIdentifier], + score_type: output_score_type.OutputScoreType, + ) -> str: + monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name) + + # Initialize the select clause components + select_clauses = [] + + select_clauses.append( + f""" + {prediction_columns[0]}, + {label_columns[0]}, + CASE + WHEN {label_columns[0]} = 1 THEN 'class_positive' + ELSE 'class_negative' + END AS label_class""" + ) + + # Join all the select clauses into a single string + select_clause = f"{timestamp_column} AS timestamp," + ",".join(select_clauses) + + # Create the final CTE query + cte_query = f""" + WITH filtered_data AS ( + SELECT + {select_clause} + FROM + {source_table_name} + )""" + + # Initialize the select clause components + select_clauses = [] + + score_type_agg_clause = "" + if score_type == output_score_type.OutputScoreType.PROBITS: + score_type_agg_clause = f""" + 'sum_log_loss', + CASE + WHEN label_class = 'class_positive' THEN sum(-ln({prediction_columns[0]})) + ELSE sum(-ln(1 - {prediction_columns[0]})) + END,""" + else: + score_type_agg_clause = f""" + 'tp', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 1), + 'tn', count_if({label_columns[0]} = 0 AND {prediction_columns[0]} = 0), + 'fp', count_if({label_columns[0]} = 0 AND 
{prediction_columns[0]} = 1), + 'fn', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 0),""" + + select_clauses.append( + f""" + label_class, + OBJECT_CONSTRUCT( + 'sum_prediction', sum({prediction_columns[0]}), + 'sum_label', sum({label_columns[0]}),{score_type_agg_clause} + 'count', count(*) + ) AS AGGREGATE_METRICS, + APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch, + APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch""" + ) + + # Join all the select clauses into a single string + select_clause = ",\n".join(select_clauses) + + return f""" + CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name} + TARGET_LAG = '{refresh_interval.minutes} minutes' + WAREHOUSE = {warehouse_name} + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS{cte_query} + select + time_slice(timestamp, {aggregate_window.minutes}, 'MINUTE') timestamp,{select_clause} + FROM + filtered_data + group by + 1, + 2 + """ diff --git a/snowflake/ml/monitoring/_client/monitor_sql_client_test.py b/snowflake/ml/monitoring/_client/monitor_sql_client_test.py new file mode 100644 index 00000000..a23058ba --- /dev/null +++ b/snowflake/ml/monitoring/_client/monitor_sql_client_test.py @@ -0,0 +1,1373 @@ +from typing import cast +from unittest import mock + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import model_signature, type_hints +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.ml.monitoring._client import monitor_sql_client +from snowflake.ml.monitoring.entities import output_score_type +from snowflake.ml.monitoring.entities.model_monitor_interval import ( + ModelMonitorAggregationWindow, + ModelMonitorRefreshInterval, +) +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import DataFrame, Row, Session, types + + +class ModelMonitorSqlClientTest(absltest.TestCase): + def setUp(self) -> None: + 
self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.test_db_name = sql_identifier.SqlIdentifier("SNOWML_OBSERVABILITY") + self.test_schema_name = sql_identifier.SqlIdentifier("DATA") + + self.test_monitor_name = sql_identifier.SqlIdentifier("TEST") + self.test_source_table_name = sql_identifier.SqlIdentifier("MODEL_OUTPUTS") + self.test_model_version_name = sql_identifier.SqlIdentifier("TEST_MODEL_VERSION") + self.test_model_name = sql_identifier.SqlIdentifier("TEST_MODEL") + self.test_fq_model_name = f"{self.test_db_name}.{self.test_schema_name}.{self.test_model_name}" + self.test_function_name = sql_identifier.SqlIdentifier("PREDICT") + self.test_timestamp_column = sql_identifier.SqlIdentifier("TIMESTAMP") + self.test_prediction_column_name = sql_identifier.SqlIdentifier("PREDICTION") + self.test_label_column_name = sql_identifier.SqlIdentifier("LABEL") + self.test_id_column_name = sql_identifier.SqlIdentifier("ID") + self.test_baseline_table_name_sql = "_SNOWML_OBS_BASELINE_TEST_MODEL_TEST_MODEL_VERSION" + self.test_wh_name = sql_identifier.SqlIdentifier("ML_OBS_WAREHOUSE") + + session = cast(Session, self.m_session) + self.monitor_sql_client = monitor_sql_client._ModelMonitorSQLClient( + session, database_name=self.test_db_name, schema_name=self.test_schema_name + ) + + self.mon_table_name = ( + f"{monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX}_" + + self.test_model_name + + f"_{self.test_model_version_name}" + ) + self.acc_table_name = ( + f"{monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX}_" + + self.test_model_name + + f"_{self.test_model_version_name}" + ) + + def test_validate_source_table(self) -> None: + mocked_table_out = mock.MagicMock(name="schema") + self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) + mocked_table_out.schema = mock.MagicMock(name="schema") + mocked_table_out.schema.fields = [ + types.StructField(self.test_timestamp_column, types.TimestampType()), + 
types.StructField(self.test_prediction_column_name, types.DoubleType()), + types.StructField(self.test_label_column_name, types.DoubleType()), + types.StructField(self.test_id_column_name, types.StringType()), + ] + + self.m_session.add_mock_sql( + query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", + result=mock_data_frame.MockDataFrame([Row(name=self.test_source_table_name)]), + ) + self.monitor_sql_client.validate_source_table( + source_table_name=self.test_source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + id_columns=[sql_identifier.SqlIdentifier("ID")], + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + model_function=model_manifest_schema.ModelFunctionInfo( + name="PREDICT", + target_method="predict", + target_method_function_type="FUNCTION", + signature=model_signature.ModelSignature(inputs=[], outputs=[]), + is_partitioned=False, + ), + ) + self.m_session.table.assert_called_once_with( + f"{self.test_db_name}.{self.test_schema_name}.{self.test_source_table_name}" + ) + self.m_session.finalize() + + def test_validate_source_table_shape(self) -> None: + mocked_table_out = mock.MagicMock(name="schema") + self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) + mocked_table_out.schema = mock.MagicMock(name="schema") + mocked_table_out.schema.fields = [ + types.StructField(self.test_timestamp_column, types.TimestampType()), + types.StructField(self.test_prediction_column_name, types.DoubleType()), + types.StructField(self.test_label_column_name, types.DoubleType()), + types.StructField(self.test_id_column_name, types.StringType()), + types.StructField("feature1", types.StringType()), + ] + + self.m_session.add_mock_sql( + query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", + 
result=mock_data_frame.MockDataFrame([Row(name=self.test_source_table_name)]), + ) + self.monitor_sql_client.validate_source_table( + source_table_name=self.test_source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + id_columns=[sql_identifier.SqlIdentifier("ID")], + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + model_function=model_manifest_schema.ModelFunctionInfo( + name="PREDICT", + target_method="predict", + target_method_function_type="FUNCTION", + signature=model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec("input_feature_0", model_signature.DataType.STRING), + ], + outputs=[], + ), + is_partitioned=False, + ), + ) + self.m_session.table.assert_called_once_with( + f"{self.test_db_name}.{self.test_schema_name}.{self.test_source_table_name}" + ) + self.m_session.finalize() + + def test_validate_source_table_shape_does_not_match_function_signature(self) -> None: + mocked_table_out = mock.MagicMock(name="schema") + self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) + mocked_table_out.schema = mock.MagicMock(name="schema") + mocked_table_out.schema.fields = [ + types.StructField(self.test_timestamp_column, types.TimestampType()), + types.StructField(self.test_prediction_column_name, types.DoubleType()), + types.StructField(self.test_label_column_name, types.DoubleType()), + types.StructField(self.test_id_column_name, types.StringType()), + types.StructField("feature1", types.StringType()), + ] + + self.m_session.add_mock_sql( + query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", + result=mock_data_frame.MockDataFrame([Row(name=self.test_source_table_name)]), + ) + + expected_msg = ( + r"Model function input types do not match the source table input columns types\. 
Model function expected: " + r"\[FeatureSpec\(dtype=DataType\.STRING, name='input_feature_0'\), FeatureSpec\(dtype=DataType\.STRING, " + r"name='unexpected_feature'\)\] but got \{'FEATURE1': StringType\(\)\}" + ) + with self.assertRaisesRegex(ValueError, expected_msg): + self.monitor_sql_client.validate_source_table( + source_table_name=self.test_source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + id_columns=[sql_identifier.SqlIdentifier("ID")], + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + model_function=model_manifest_schema.ModelFunctionInfo( + name="PREDICT", + target_method="predict", + target_method_function_type="FUNCTION", + signature=model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec("input_feature_0", model_signature.DataType.STRING), + model_signature.FeatureSpec("unexpected_feature", model_signature.DataType.STRING), + ], + outputs=[], + ), + is_partitioned=False, + ), + ) + self.m_session.finalize() + + def test_validate_monitor_warehouse(self) -> None: + self.m_session.add_mock_sql( + query=f"""SHOW WAREHOUSES LIKE '{self.test_wh_name}'""", + result=mock_data_frame.MockDataFrame([]), + ) + with self.assertRaisesRegex(ValueError, f"Warehouse '{self.test_wh_name}' not found"): + self.monitor_sql_client.validate_monitor_warehouse(self.test_wh_name) + + def test_validate_source_table_not_exists(self) -> None: + self.m_session.add_mock_sql( + query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", + result=mock_data_frame.MockDataFrame([]), + ) + expected_msg = ( + f"Table {self.test_source_table_name} does not exist in schema {self.test_db_name}.{self.test_schema_name}." 
+ ) + with self.assertRaisesRegex(ValueError, expected_msg): + self.monitor_sql_client.validate_source_table( + source_table_name=self.test_source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + id_columns=[sql_identifier.SqlIdentifier("ID")], + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + model_function=model_manifest_schema.ModelFunctionInfo( + name="PREDICT", + target_method="predict", + target_method_function_type="FUNCTION", + signature=model_signature.ModelSignature(inputs=[], outputs=[]), + is_partitioned=False, + ), + ) + self.m_session.finalize() + + def test_validate_columns_exist_in_source_table(self) -> None: + source_table_name = self.test_source_table_name + + table_schema = { + "feature1": types.StringType(), + "feature2": types.StringType(), + "feature3": types.StringType(), + "TIMESTAMP": types.TimestampType(), + "PREDICTION": types.DoubleType(), + "LABEL": types.DoubleType(), + "ID": types.StringType(), + } + self.monitor_sql_client._validate_columns_exist_in_source_table( + table_schema=table_schema, + source_table_name=source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + id_columns=[sql_identifier.SqlIdentifier("ID")], + ) + + table_schema = { + "feature1": types.StringType(), + "feature2": types.StringType(), + "feature3": types.StringType(), + "PREDICTION": types.DoubleType(), + "LABEL": types.DoubleType(), + "ID": types.StringType(), + } + with self.assertRaisesRegex(ValueError, "Timestamp column TIMESTAMP does not exist in table MODEL_OUTPUTS"): + self.monitor_sql_client._validate_columns_exist_in_source_table( + table_schema=table_schema, + source_table_name=source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + 
prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + id_columns=[sql_identifier.SqlIdentifier("ID")], + ) + + table_schema = { + "feature1": types.StringType(), + "feature2": types.StringType(), + "feature3": types.StringType(), + "TIMESTAMP": types.TimestampType(), + "LABEL": types.DoubleType(), + "ID": types.StringType(), + } + + with self.assertRaisesRegex( + ValueError, r"Prediction column\(s\): \['PREDICTION'\] do not exist in table MODEL_OUTPUTS." + ): + self.monitor_sql_client._validate_columns_exist_in_source_table( + table_schema=table_schema, + source_table_name=source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + id_columns=[sql_identifier.SqlIdentifier("ID")], + ) + + table_schema = { + "feature1": types.StringType(), + "feature2": types.StringType(), + "feature3": types.StringType(), + "TIMESTAMP": types.TimestampType(), + "PREDICTION": types.DoubleType(), + "ID": types.StringType(), + } + with self.assertRaisesRegex(ValueError, r"Label column\(s\): \['LABEL'\] do not exist in table MODEL_OUTPUTS."): + self.monitor_sql_client._validate_columns_exist_in_source_table( + table_schema=table_schema, + source_table_name=source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + id_columns=[sql_identifier.SqlIdentifier("ID")], + ) + + table_schema = { + "feature1": types.StringType(), + "feature2": types.StringType(), + "feature3": types.StringType(), + "TIMESTAMP": types.TimestampType(), + "PREDICTION": types.DoubleType(), + "LABEL": types.DoubleType(), + } + with self.assertRaisesRegex(ValueError, r"ID column\(s\): \['ID'\] do not exist in table MODEL_OUTPUTS"): + 
self.monitor_sql_client._validate_columns_exist_in_source_table( + table_schema=table_schema, + source_table_name=source_table_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + id_columns=[sql_identifier.SqlIdentifier("ID")], + ) + + def test_validate_column_types(self) -> None: + self.monitor_sql_client._validate_column_types( + table_schema={ + "PREDICTION1": types.DoubleType(), + "PREDICTION2": types.DoubleType(), + "LABEL1": types.DoubleType(), + "LABEL2": types.DoubleType(), + "ID": types.StringType(), + "TIMESTAMP": types.TimestampType(types.TimestampTimeZone("ltz")), + }, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + prediction_columns=[ + sql_identifier.SqlIdentifier("PREDICTION1"), + sql_identifier.SqlIdentifier("PREDICTION2"), + ], + id_columns=[sql_identifier.SqlIdentifier("ID")], + label_columns=[sql_identifier.SqlIdentifier("LABEL1"), sql_identifier.SqlIdentifier("LABEL2")], + ) + + def test_validate_prediction_column_types(self) -> None: + with self.assertRaisesRegex(ValueError, "Prediction column types must be the same. Found: .*"): + self.monitor_sql_client._validate_prediction_columns_types( + table_schema={ + "PREDICTION1": types.DoubleType(), + "PREDICTION2": types.StringType(), + }, + prediction_columns=[ + sql_identifier.SqlIdentifier("PREDICTION1"), + sql_identifier.SqlIdentifier("PREDICTION2"), + ], + ) + + def test_validate_label_column_types(self) -> None: + with self.assertRaisesRegex(ValueError, "Label column types must be the same. 
Found:"): + self.monitor_sql_client._validate_label_columns_types( + table_schema={ + "LABEL1": types.DoubleType(), + "LABEL2": types.StringType(), + }, + label_columns=[sql_identifier.SqlIdentifier("LABEL1"), sql_identifier.SqlIdentifier("LABEL2")], + ) + + def test_validate_timestamp_column_type(self) -> None: + with self.assertRaisesRegex(ValueError, "Timestamp column: TIMESTAMP must be TimestampType"): + self.monitor_sql_client._validate_timestamp_column_type( + table_schema={ + "TIMESTAMP": types.StringType(), + }, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + ) + + def test_validate_id_columns_types(self) -> None: + with self.assertRaisesRegex(ValueError, "Id columns must all be StringType"): + self.monitor_sql_client._validate_id_columns_types( + table_schema={ + "ID": types.DoubleType(), + }, + id_columns=[ + sql_identifier.SqlIdentifier("ID"), + ], + ) + + def test_validate_multiple_id_columns_types(self) -> None: + with self.assertRaisesRegex(ValueError, "Id columns must all be StringType. 
Found"): + self.monitor_sql_client._validate_id_columns_types( + table_schema={ + "ID1": types.StringType(), + "ID2": types.DecimalType(), + }, + id_columns=[ + sql_identifier.SqlIdentifier("ID1"), + sql_identifier.SqlIdentifier("ID2"), + ], + ) + + def test_validate_id_columns_types_all_string(self) -> None: + self.monitor_sql_client._validate_id_columns_types( + table_schema={ + "ID1": types.StringType(36), + "ID2": types.StringType(64), + "ID3": types.StringType(), + }, + id_columns=[ + sql_identifier.SqlIdentifier("ID1"), + sql_identifier.SqlIdentifier("ID2"), + sql_identifier.SqlIdentifier("ID3"), + ], + ) + + def test_monitoring_dynamic_table_query_single_numeric_single_categoric(self) -> None: + query = self.monitor_sql_client._monitoring_dynamic_table_query( + model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + source_table_name=self.test_source_table_name, + refresh_interval=ModelMonitorRefreshInterval("15 minutes"), + aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, + warehouse_name=self.test_wh_name, + timestamp_column=self.test_timestamp_column, + numeric_features=[sql_identifier.SqlIdentifier("NUM_0")], + categoric_features=[sql_identifier.SqlIdentifier("STR_COL_0")], + prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + ) + + expected = f""" + CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.mon_table_name} + TARGET_LAG = '15 minutes' + WAREHOUSE = ML_OBS_WAREHOUSE + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS + SELECT + TIME_SLICE(TIMESTAMP, 60, 'MINUTE') timestamp, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_0), + 'count', count_if(NUM_0 is not null), + 'count_null', count_if(NUM_0 is null), + 'min', min(NUM_0), + 'max', max(NUM_0), + 'sum', sum(NUM_0) + ) AS NUM_0, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(OUTPUT), + 'count', count_if(OUTPUT is not null), + 'count_null', 
count_if(OUTPUT is null), + 'min', min(OUTPUT), + 'max', max(OUTPUT), + 'sum', sum(OUTPUT) + ) AS OUTPUT, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(LABEL), + 'count', count_if(LABEL is not null), + 'count_null', count_if(LABEL is null), + 'min', min(LABEL), + 'max', max(LABEL), + 'sum', sum(LABEL) + ) AS LABEL, + SNOWML_OBSERVABILITY.DATA.OBJECT_SUM(to_varchar(STR_COL_0)) AS STR_COL_0 + FROM + MODEL_OUTPUTS + GROUP BY + 1 + """ + self.assertEqual(query, expected) + + def test_monitoring_dynamic_table_query_multi_feature(self) -> None: + query = self.monitor_sql_client._monitoring_dynamic_table_query( + model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + source_table_name=self.test_source_table_name, + refresh_interval=ModelMonitorRefreshInterval("15 minutes"), + aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, + warehouse_name=self.test_wh_name, + timestamp_column=self.test_timestamp_column, + numeric_features=[ + sql_identifier.SqlIdentifier("NUM_0"), + sql_identifier.SqlIdentifier("NUM_1"), + sql_identifier.SqlIdentifier("NUM_2"), + ], + categoric_features=[sql_identifier.SqlIdentifier("STR_COL_0"), sql_identifier.SqlIdentifier("STR_COL_1")], + prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + ) + self.assertEqual( + query, + f""" + CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.mon_table_name} + TARGET_LAG = '15 minutes' + WAREHOUSE = ML_OBS_WAREHOUSE + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS + SELECT + TIME_SLICE(TIMESTAMP, 60, 'MINUTE') timestamp, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_0), + 'count', count_if(NUM_0 is not null), + 'count_null', count_if(NUM_0 is null), + 'min', min(NUM_0), + 'max', max(NUM_0), + 'sum', sum(NUM_0) + ) AS NUM_0, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_1), + 'count', count_if(NUM_1 is not null), + 'count_null', 
count_if(NUM_1 is null), + 'min', min(NUM_1), + 'max', max(NUM_1), + 'sum', sum(NUM_1) + ) AS NUM_1, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_2), + 'count', count_if(NUM_2 is not null), + 'count_null', count_if(NUM_2 is null), + 'min', min(NUM_2), + 'max', max(NUM_2), + 'sum', sum(NUM_2) + ) AS NUM_2, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(OUTPUT), + 'count', count_if(OUTPUT is not null), + 'count_null', count_if(OUTPUT is null), + 'min', min(OUTPUT), + 'max', max(OUTPUT), + 'sum', sum(OUTPUT) + ) AS OUTPUT, + OBJECT_CONSTRUCT( + 'sketch', APPROX_PERCENTILE_ACCUMULATE(LABEL), + 'count', count_if(LABEL is not null), + 'count_null', count_if(LABEL is null), + 'min', min(LABEL), + 'max', max(LABEL), + 'sum', sum(LABEL) + ) AS LABEL, + SNOWML_OBSERVABILITY.DATA.OBJECT_SUM(to_varchar(STR_COL_0)) AS STR_COL_0, + SNOWML_OBSERVABILITY.DATA.OBJECT_SUM(to_varchar(STR_COL_1)) AS STR_COL_1 + FROM + MODEL_OUTPUTS + GROUP BY + 1 + """, + ) + + def test_monitoring_accuracy_regression_dynamic_table_query_single_prediction(self) -> None: + query = self.monitor_sql_client._monitoring_regression_accuracy_table_query( + model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + source_table_name=self.test_source_table_name, + refresh_interval=ModelMonitorRefreshInterval("15 minutes"), + aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, + warehouse_name=self.test_wh_name, + timestamp_column=self.test_timestamp_column, + prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + ) + self.assertEqual( + query, + f""" + CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.acc_table_name} + TARGET_LAG = '15 minutes' + WAREHOUSE = ML_OBS_WAREHOUSE + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS + SELECT + TIME_SLICE(TIMESTAMP, 60, 'MINUTE') timestamp, + 'class_regression' label_class, + OBJECT_CONSTRUCT( + 
'sum_difference_label_pred', sum(OUTPUT - LABEL), + 'sum_log_difference_square_label_pred', + sum( + case + when OUTPUT > -1 and LABEL > -1 + then pow(ln(OUTPUT + 1) - ln(LABEL + 1),2) + else null + END + ), + 'sum_difference_squares_label_pred', + sum( + pow( + OUTPUT - LABEL, + 2 + ) + ), + 'sum_absolute_regression_labels', sum(abs(LABEL)), + 'sum_absolute_percentage_error', + sum( + abs( + div0null( + (OUTPUT - LABEL), + LABEL + ) + ) + ), + 'sum_absolute_difference_label_pred', + sum( + abs(OUTPUT - LABEL) + ), + 'sum_prediction', sum(OUTPUT), + 'sum_label', sum(LABEL), + 'count', count(*) + ) AS AGGREGATE_METRICS, + APPROX_PERCENTILE_ACCUMULATE(OUTPUT) prediction_sketch, + APPROX_PERCENTILE_ACCUMULATE(LABEL) label_sketch + FROM + MODEL_OUTPUTS + GROUP BY + 1 + """, + ) + + def test_monitoring_accuracy_classification_probit_dynamic_table_query_single_prediction(self) -> None: + query = self.monitor_sql_client._monitoring_classification_accuracy_table_query( + model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + source_table_name=self.test_source_table_name, + refresh_interval=ModelMonitorRefreshInterval("15 minutes"), + aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, + warehouse_name=self.test_wh_name, + timestamp_column=self.test_timestamp_column, + prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + score_type=output_score_type.OutputScoreType.PROBITS, + ) + self.assertEqual( + query, + f""" + CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.acc_table_name} + TARGET_LAG = '15 minutes' + WAREHOUSE = ML_OBS_WAREHOUSE + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS + WITH filtered_data AS ( + SELECT + TIMESTAMP AS timestamp, + OUTPUT, + LABEL, + CASE + WHEN LABEL = 1 THEN 'class_positive' + ELSE 'class_negative' + END AS label_class + FROM + MODEL_OUTPUTS + ) + select + time_slice(timestamp, 60, 'MINUTE') timestamp, + 
label_class, + OBJECT_CONSTRUCT( + 'sum_prediction', sum(OUTPUT), + 'sum_label', sum(LABEL), + 'sum_log_loss', + CASE + WHEN label_class = 'class_positive' THEN sum(-ln(OUTPUT)) + ELSE sum(-ln(1 - OUTPUT)) + END, + 'count', count(*) + ) AS AGGREGATE_METRICS, + APPROX_PERCENTILE_ACCUMULATE(OUTPUT) prediction_sketch, + APPROX_PERCENTILE_ACCUMULATE(LABEL) label_sketch + FROM + filtered_data + group by + 1, + 2 + """, + ) + + def test_monitoring_accuracy_classification_class_dynamic_table_query_single_prediction(self) -> None: + query = self.monitor_sql_client._monitoring_classification_accuracy_table_query( + model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + source_table_name=self.test_source_table_name, + refresh_interval=ModelMonitorRefreshInterval("15 minutes"), + aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, + warehouse_name=self.test_wh_name, + timestamp_column=self.test_timestamp_column, + prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + score_type=output_score_type.OutputScoreType.CLASSIFICATION, + ) + self.assertEqual( + query, + f""" + CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.acc_table_name} + TARGET_LAG = '15 minutes' + WAREHOUSE = ML_OBS_WAREHOUSE + REFRESH_MODE = AUTO + INITIALIZE = ON_CREATE + AS + WITH filtered_data AS ( + SELECT + TIMESTAMP AS timestamp, + OUTPUT, + LABEL, + CASE + WHEN LABEL = 1 THEN 'class_positive' + ELSE 'class_negative' + END AS label_class + FROM + MODEL_OUTPUTS + ) + select + time_slice(timestamp, 60, 'MINUTE') timestamp, + label_class, + OBJECT_CONSTRUCT( + 'sum_prediction', sum(OUTPUT), + 'sum_label', sum(LABEL), + 'tp', count_if(LABEL = 1 AND OUTPUT = 1), + 'tn', count_if(LABEL = 0 AND OUTPUT = 0), + 'fp', count_if(LABEL = 0 AND OUTPUT = 1), + 'fn', count_if(LABEL = 1 AND OUTPUT = 0), + 'count', count(*) + ) AS AGGREGATE_METRICS, + APPROX_PERCENTILE_ACCUMULATE(OUTPUT) 
prediction_sketch, + APPROX_PERCENTILE_ACCUMULATE(LABEL) label_sketch + FROM + filtered_data + group by + 1, + 2 + """, + ) + + def test_monitoring_accuracy_dynamic_table_query_multi_prediction(self) -> None: + with self.assertRaises(ValueError): + _ = self.monitor_sql_client._monitoring_accuracy_table_query( + model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + source_table_name=self.test_source_table_name, + refresh_interval=ModelMonitorRefreshInterval("15 minutes"), + aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, + warehouse_name=self.test_wh_name, + timestamp_column=self.test_timestamp_column, + prediction_columns=[sql_identifier.SqlIdentifier("LABEL"), sql_identifier.SqlIdentifier("output_1")], + label_columns=[sql_identifier.SqlIdentifier("LABEL"), sql_identifier.SqlIdentifier("label_1")], + score_type=output_score_type.OutputScoreType.REGRESSION, + ) + + def test_validate_existence_by_name(self) -> None: + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE MONITOR_NAME = '{self.test_monitor_name}' + """, + result=mock_data_frame.MockDataFrame([]), + ) + res = self.monitor_sql_client.validate_existence_by_name(self.test_monitor_name) + self.assertFalse(res) + + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE MONITOR_NAME = '{self.test_monitor_name}' + """, + result=mock_data_frame.MockDataFrame( + [ + Row( + FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, + MODEL_VERSION_NAME=self.test_model_version_name, + ) + ] + ), + ) + res = self.monitor_sql_client.validate_existence_by_name(self.test_monitor_name) + self.assertTrue(res) + self.m_session.finalize() + + def test_validate_existence(self) -> None: + 
self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' + AND MODEL_VERSION_NAME = '{self.test_model_version_name}' + """, + result=mock_data_frame.MockDataFrame([]), + ) + res = self.monitor_sql_client.validate_existence(self.test_fq_model_name, self.test_model_version_name) + self.assertFalse(res) + + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' + AND MODEL_VERSION_NAME = '{self.test_model_version_name}' + """, + result=mock_data_frame.MockDataFrame( + [ + Row( + FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, + MODEL_VERSION_NAME=self.test_model_version_name, + ) + ] + ), + ) + res = self.monitor_sql_client.validate_existence(self.test_fq_model_name, self.test_model_version_name) + self.assertTrue(res) + + self.m_session.finalize() + + def test_create_monitor_on_model_version(self) -> None: + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' + AND MODEL_VERSION_NAME = '{self.test_model_version_name}' + """, + result=mock_data_frame.MockDataFrame([]), + ) + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE MONITOR_NAME = '{self.test_monitor_name}' + """, + result=mock_data_frame.MockDataFrame([]), + ) + + self.m_session.add_mock_sql( + query=f"""INSERT INTO SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + (MONITOR_NAME, SOURCE_TABLE_NAME, FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME, + FUNCTION_NAME, TASK, IS_ENABLED, + 
TIMESTAMP_COLUMN_NAME, PREDICTION_COLUMN_NAMES, LABEL_COLUMN_NAMES, ID_COLUMN_NAMES) + SELECT '{self.test_monitor_name}', '{self.test_source_table_name}', + '{self.test_fq_model_name}', '{self.test_model_version_name}', '{self.test_function_name}', + 'TABULAR_BINARY_CLASSIFICATION', TRUE, + '{self.test_timestamp_column}', ARRAY_CONSTRUCT('{self.test_prediction_column_name}'), + ARRAY_CONSTRUCT('{self.test_label_column_name}'), ARRAY_CONSTRUCT('{self.test_id_column_name}')""", + result=mock_data_frame.MockDataFrame([Row(**{"number of rows inserted": 1})]), + ) + + self.monitor_sql_client.create_monitor_on_model_version( + monitor_name=self.test_monitor_name, + source_table_name=self.test_source_table_name, + fully_qualified_model_name=self.test_fq_model_name, + version_name=self.test_model_version_name, + function_name=self.test_function_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + id_columns=[sql_identifier.SqlIdentifier("ID")], + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + statement_params=None, + ) + self.m_session.finalize() + + def test_create_monitor_on_model_version_fails_if_model_exists(self) -> None: + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' + AND MODEL_VERSION_NAME = '{self.test_model_version_name}' + """, + result=mock_data_frame.MockDataFrame( + [ + Row( + FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, + MODEL_VERSION_NAME=self.test_model_version_name, + ) + ] + ), + ) + expected_msg = f"Model {self.test_fq_model_name} Version {self.test_model_version_name} is already monitored!" 
+ with self.assertRaisesRegex(ValueError, expected_msg): + self.monitor_sql_client.create_monitor_on_model_version( + monitor_name=self.test_monitor_name, + source_table_name=self.test_source_table_name, + fully_qualified_model_name=self.test_fq_model_name, + version_name=self.test_model_version_name, + function_name=self.test_function_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + id_columns=[sql_identifier.SqlIdentifier("ID")], + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + statement_params=None, + ) + + self.m_session.finalize() + + def test_create_monitor_on_model_version_fails_if_monitor_name_exists(self) -> None: + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' + AND MODEL_VERSION_NAME = '{self.test_model_version_name}' + """, + result=mock_data_frame.MockDataFrame([]), + ) + self.m_session.add_mock_sql( + query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME + FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA + WHERE MONITOR_NAME = '{self.test_monitor_name}' + """, + result=mock_data_frame.MockDataFrame( + [ + Row( + FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, + MODEL_VERSION_NAME=self.test_model_version_name, + ) + ] + ), + ) + + expected_msg = f"Model Monitor with name '{self.test_monitor_name}' already exists!" 
+ with self.assertRaisesRegex(ValueError, expected_msg): + self.monitor_sql_client.create_monitor_on_model_version( + monitor_name=self.test_monitor_name, + source_table_name=self.test_source_table_name, + fully_qualified_model_name=self.test_fq_model_name, + version_name=self.test_model_version_name, + function_name=self.test_function_name, + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + id_columns=[sql_identifier.SqlIdentifier("ID")], + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + statement_params=None, + ) + + self.m_session.finalize() + + def test_validate_unique_columns(self) -> None: + self.monitor_sql_client._validate_unique_columns( + id_columns=[sql_identifier.SqlIdentifier("ID")], + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + ) + + def test_validate_unique_columns_column_used_twice(self) -> None: + with self.assertRaisesRegex( + ValueError, "Column names must be unique across id, timestamp, prediction, and label columns." 
+ ): + self.monitor_sql_client._validate_unique_columns( + id_columns=[sql_identifier.SqlIdentifier("ID")], + timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), + prediction_columns=[ + sql_identifier.SqlIdentifier("PREDICTION"), + # This is a duplicate with the id column + sql_identifier.SqlIdentifier("ID"), + ], + label_columns=[sql_identifier.SqlIdentifier("LABEL")], + ) + + def test_infer_numeric_categoric_column_names(self) -> None: + from snowflake.snowpark import types + + timestamp_col = sql_identifier.SqlIdentifier("TS_COL") + id_col = sql_identifier.SqlIdentifier("ID_COL") + output_column = sql_identifier.SqlIdentifier("OUTPUT") + label_column = sql_identifier.SqlIdentifier("LABEL") + test_schema = { + timestamp_col: types.TimeType(), + id_col: types.FloatType(), + output_column: types.FloatType(), + label_column: types.FloatType(), + "STR_COL": types.StringType(16777216), + "LONG_COL": types.LongType(), + "FLOAT_COL": types.FloatType(), + "DOUBLE_COL": types.DoubleType(), + "BINARY_COL": types.BinaryType(), + "ARRAY_COL": types.ArrayType(), + "NULL_COL": types.NullType(), + } + + expected_numeric = [ + sql_identifier.SqlIdentifier("LONG_COL"), + sql_identifier.SqlIdentifier("FLOAT_COL"), + sql_identifier.SqlIdentifier("DOUBLE_COL"), + ] + expected_categoric = [ + sql_identifier.SqlIdentifier("STR_COL"), + ] + + numeric, categoric = monitor_sql_client._infer_numeric_categoric_feature_column_names( + source_table_schema=test_schema, + timestamp_column=timestamp_col, + id_columns=[id_col], + prediction_columns=[output_column], + label_columns=[label_column], + ) + self.assertListEqual(expected_numeric, numeric) + self.assertListEqual(expected_categoric, categoric) + + def test_initialize_baseline_table(self) -> None: + mocked_table_out = mock.MagicMock(name="schema") + self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) + mocked_table_out.schema = mock.MagicMock(name="schema") + mocked_table_out.schema.fields = [ + 
types.StructField(self.test_timestamp_column, types.TimestampType()), + types.StructField(self.test_prediction_column_name, types.DoubleType()), + types.StructField(self.test_label_column_name, types.DoubleType()), + types.StructField(self.test_id_column_name, types.StringType()), + ] + + self.m_session.add_mock_sql( + query=f"""CREATE TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_BASELINE_""" + f"""{self.test_model_name}_{self.test_model_version_name}""" + f"""(PREDICTION DOUBLE, LABEL DOUBLE)""", + result=mock_data_frame.MockDataFrame( + [ + Row( + name="PREDICTION", + type="DOUBLE", + ), + Row( + name="LABEL", + type="DOUBLE", + ), + ] + ), + ) + + self.monitor_sql_client.initialize_baseline_table( + model_name=self.test_model_name, + version_name=self.test_model_version_name, + source_table_name=self.test_source_table_name, + columns_to_drop=[self.test_id_column_name, self.test_timestamp_column], + ) + + def test_materialize_baseline_dataframe(self) -> None: + mocked_dataframe = mock_data_frame.MockDataFrame( + [ + Row(TIMESTAMP="2022-01-01 00:00:00", PREDICTION=0.8, LABEL=1.0, ID="12345"), + Row(TIMESTAMP="2022-01-02 00:00:00", PREDICTION=0.6, LABEL=0.0, ID="67890"), + ] + ) + self.m_session.add_mock_sql( + f"SHOW TABLES LIKE '{self.test_baseline_table_name_sql}' IN SNOWML_OBSERVABILITY.DATA", + mock_data_frame.MockDataFrame([Row(name=self.test_baseline_table_name_sql)]), + ) + + mocked_dataframe.write = mock.MagicMock(name="write") + save_as_table = mock.MagicMock(name="save_as_table") + mocked_dataframe.write.mode = mock.MagicMock(name="mode", return_value=save_as_table) + + self.monitor_sql_client.materialize_baseline_dataframe( + baseline_df=cast(DataFrame, mocked_dataframe), + fully_qualified_model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + ) + + mocked_dataframe.write.mode.assert_called_once_with("truncate") + save_as_table.save_as_table.assert_called_once_with( + [self.test_db_name, self.test_schema_name, 
self.test_baseline_table_name_sql], + statement_params=mock.ANY, + ) + + def test_materialize_baseline_dataframe_table_not_exists(self) -> None: + mocked_dataframe = mock_data_frame.MockDataFrame( + [ + Row(TIMESTAMP="2022-01-01 00:00:00", PREDICTION=0.8, LABEL=1.0, ID="12345"), + Row(TIMESTAMP="2022-01-02 00:00:00", PREDICTION=0.6, LABEL=0.0, ID="67890"), + ] + ) + self.m_session.add_mock_sql( + f"SHOW TABLES LIKE '{self.test_baseline_table_name_sql}' IN SNOWML_OBSERVABILITY.DATA", + mock_data_frame.MockDataFrame([]), + ) + + expected_msg = ( + f"Baseline table '{self.test_baseline_table_name_sql}' does not exist for model: " + "'TEST_MODEL' and model_version: 'TEST_MODEL_VERSION'" + ) + with self.assertRaisesRegex(ValueError, expected_msg): + self.monitor_sql_client.materialize_baseline_dataframe( + baseline_df=cast(DataFrame, mocked_dataframe), + fully_qualified_model_name=self.test_model_name, + model_version_name=self.test_model_version_name, + ) + + def test_initialize_baseline_table_different_data_kinds(self) -> None: + mocked_table_out = mock.MagicMock(name="schema") + self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) + mocked_table_out.schema = mock.MagicMock(name="schema") + mocked_table_out.schema.fields = [ + types.StructField(self.test_timestamp_column, types.TimestampType()), + types.StructField(self.test_prediction_column_name, types.DoubleType()), + types.StructField(self.test_label_column_name, types.DoubleType()), + types.StructField(self.test_id_column_name, types.StringType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE1"), types.StringType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE2"), types.DoubleType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE3"), types.FloatType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE4"), types.DecimalType(38, 9)), + types.StructField(sql_identifier.SqlIdentifier("FEATURE5"), types.IntegerType()), + 
types.StructField(sql_identifier.SqlIdentifier("FEATURE6"), types.LongType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE7"), types.ShortType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE8"), types.BinaryType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE9"), types.BooleanType()), + types.StructField(sql_identifier.SqlIdentifier("FEATURE10"), types.TimestampType()), + types.StructField( + sql_identifier.SqlIdentifier("FEATURE11"), types.TimestampType(types.TimestampTimeZone("ltz")) + ), + types.StructField( + sql_identifier.SqlIdentifier("FEATURE12"), types.TimestampType(types.TimestampTimeZone("ntz")) + ), + types.StructField( + sql_identifier.SqlIdentifier("FEATURE13"), types.TimestampType(types.TimestampTimeZone("tz")) + ), + ] + + self.m_session.add_mock_sql( + query=f"""CREATE TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_BASELINE_""" + f"""{self.test_model_name}_{self.test_model_version_name}""" + f"""(PREDICTION DOUBLE, LABEL DOUBLE, + FEATURE1 STRING, FEATURE2 DOUBLE, FEATURE3 FLOAT, FEATURE4 NUMBER(38, 9), FEATURE5 INT, + FEATURE6 BIGINT, FEATURE7 SMALLINT, FEATURE8 BINARY, FEATURE9 BOOLEAN, FEATURE10 TIMESTAMP, + FEATURE11 TIMESTAMP_LTZ, FEATURE12 TIMESTAMP_NTZ, FEATURE13 TIMESTAMP_TZ)""", + result=mock_data_frame.MockDataFrame( + [ + Row( + name="PREDICTION", + type="DOUBLE", + ), + Row( + name="LABEL", + type="DOUBLE", + ), + Row( + name="FEATURE1", + type="STRING", + ), + Row( + name="FEATURE2", + type="DOUBLE", + ), + Row( + name="FEATURE3", + type="FLOAT", + ), + Row( + name="FEATURE4", + type="NUMBER", + ), + Row( + name="FEATURE5", + type="INTEGER", + ), + Row( + name="FEATURE6", + type="INTEGER", + ), + Row( + name="FEATURE7", + type="INTEGER", + ), + Row( + name="FEATURE8", + type="BINARY", + ), + Row( + name="FEATURE9", + type="BOOLEAN", + ), + Row( + name="FEATURE10", + type="TIMESTAMP", + ), + Row( + name="FEATURE11", + type="TIMESTAMP_LTZ", + ), + Row( + name="FEATURE12", + 
type="TIMESTAMP_NTZ", + ), + Row( + name="FEATURE13", + type="TIMESTAMP_TZ", + ), + ] + ), + ) + + self.monitor_sql_client.initialize_baseline_table( + model_name=self.test_model_name, + version_name=self.test_model_version_name, + source_table_name=self.test_source_table_name, + columns_to_drop=[self.test_timestamp_column, self.test_id_column_name], + ) + + def test_get_model_monitor_by_model_version(self) -> None: + model_db = sql_identifier.SqlIdentifier("MODEL_DB") + model_schema = sql_identifier.SqlIdentifier("MODEL_SCHEMA") + self.m_session.add_mock_sql( + f"""SELECT {monitor_sql_client.MONITOR_NAME_COL_NAME}, {monitor_sql_client.FQ_MODEL_NAME_COL_NAME}, + {monitor_sql_client.VERSION_NAME_COL_NAME}, {monitor_sql_client.FUNCTION_NAME_COL_NAME} + FROM {self.test_db_name}.{self.test_schema_name}.{monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} + WHERE {monitor_sql_client.FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{self.test_model_name}' + AND {monitor_sql_client.VERSION_NAME_COL_NAME} = '{self.test_model_version_name}'""", + result=mock_data_frame.MockDataFrame( + [ + Row( + MONITOR_NAME=self.test_monitor_name, + FULLY_QUALIFIED_MODEL_NAME=f"{model_db}.{model_schema}.{self.test_model_name}", + MODEL_VERSION_NAME=self.test_model_version_name, + FUNCTION_NAME=self.test_function_name, + PREDICTION_COLUMN_NAMES="[]", + LABEL_COLUMN_NAMES="[]", + ) + ] + ), + ) + # name, fq_model_name, version_name, function_name + monitor_params = self.monitor_sql_client.get_model_monitor_by_model_version( + model_db=model_db, + model_schema=model_schema, + model_name=self.test_model_name, + version_name=self.test_model_version_name, + ) + self.assertEqual(monitor_params["monitor_name"], str(self.test_monitor_name)) + self.assertEqual( + monitor_params["fully_qualified_model_name"], f"{model_db}.{model_schema}.{self.test_model_name}" + ) + self.assertEqual(monitor_params["version_name"], str(self.test_model_version_name)) + 
self.assertEqual(monitor_params["function_name"], str(self.test_function_name)) + + self.m_session.finalize() # TODO: Move to tearDown() for all tests. + + def test_get_model_monitor_by_model_version_fails_if_multiple(self) -> None: + model_db = sql_identifier.SqlIdentifier("MODEL_DB") + model_schema = sql_identifier.SqlIdentifier("MODEL_SCHEMA") + self.m_session.add_mock_sql( + f"""SELECT {monitor_sql_client.MONITOR_NAME_COL_NAME}, {monitor_sql_client.FQ_MODEL_NAME_COL_NAME}, + {monitor_sql_client.VERSION_NAME_COL_NAME}, {monitor_sql_client.FUNCTION_NAME_COL_NAME} + FROM {self.test_db_name}.{self.test_schema_name}.{monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} + WHERE {monitor_sql_client.FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{self.test_model_name}' + AND {monitor_sql_client.VERSION_NAME_COL_NAME} = '{self.test_model_version_name}'""", + result=mock_data_frame.MockDataFrame( + [ + Row( + MONITOR_NAME=self.test_monitor_name, + FULLY_QUALIFIED_MODEL_NAME=f"{model_db}.{model_schema}.{self.test_model_name}", + MODEL_VERSION_NAME=self.test_model_version_name, + FUNCTION_NAME=self.test_function_name, + ), + Row( + MONITOR_NAME=self.test_monitor_name, + FULLY_QUALIFIED_MODEL_NAME=f"{model_db}.{model_schema}.{self.test_model_name}", + MODEL_VERSION_NAME=self.test_model_version_name, + FUNCTION_NAME=self.test_function_name, + ), + ] + ), + ) + with self.assertRaisesRegex(ValueError, "Invalid state. Multiple Monitors exist for model:"): + self.monitor_sql_client.get_model_monitor_by_model_version( + model_db=model_db, + model_schema=model_schema, + model_name=self.test_model_name, + version_name=self.test_model_version_name, + ) + + self.m_session.finalize() # TODO: Move to tearDown() for all tests. 
+ + def test_dashboard_udtf_queries(self) -> None: + queries_map = self.monitor_sql_client._create_dashboard_udtf_queries( + self.test_monitor_name, + self.test_model_version_name, + self.test_model_name, + type_hints.Task.TABULAR_REGRESSION, + output_score_type.OutputScoreType.REGRESSION, + output_columns=[self.test_prediction_column_name], + ground_truth_columns=[self.test_label_column_name], + ) + self.assertIn("rmse", queries_map) + EXPECTED_RMSE = """CREATE OR REPLACE FUNCTION TEST_RMSE() + RETURNS TABLE(event_timestamp TIMESTAMP_NTZ, value FLOAT) + AS +$$ +WITH metric_of_interest as ( + select + time_slice(timestamp, 1, 'hour') as event_timestamp, + AGGREGATE_METRICS:"sum_difference_squares_label_pred" as aggregate_field, + AGGREGATE_METRICS:"count" as "count" + from + SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_ACCURACY__TEST_MODEL_VERSION_TEST_MODEL +), metric_combine as ( + select + event_timestamp, + CAST(SUM(NVL(aggregate_field, 0)) as DOUBLE) as metric_sum, + SUM("count") as metric_count + from + metric_of_interest + where + cast(aggregate_field as varchar) not in ('inf','-inf','NaN') + group by + 1 +) select + event_timestamp, + SQRT(DIV0(metric_sum,metric_count)) as VALUE +from metric_combine +order by 1 desc +$$; +""" + self.assertEqual(queries_map["rmse"], EXPECTED_RMSE) + + self.assertIn("record_count", queries_map) + EXPECTED_RECORD_COUNT = """CREATE OR REPLACE FUNCTION TEST_PREDICTION_COUNT() + RETURNS TABLE(event_timestamp TIMESTAMP_NTZ, count FLOAT) + AS + $$ +SELECT + time_slice(timestamp, 1, 'hour') as "event_timestamp", + sum(get(PREDICTION,'count')) as count +from + SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_MONITORING__TEST_MODEL_VERSION_TEST_MODEL +group by + 1 +order by + 1 desc + $$; +""" + self.assertEqual(queries_map["record_count"], EXPECTED_RECORD_COUNT) + + def test_get_all_model_monitor_metadata(self) -> None: + expected_result = [Row(MONITOR_NAME="monitor")] + self.m_session.add_mock_sql( + query="SELECT * FROM 
SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA", + result=mock_data_frame.MockDataFrame(expected_result), + ) + res = self.monitor_sql_client.get_all_model_monitor_metadata() + self.assertEqual(res, expected_result) + + def test_suspend_monitor_dynamic_tables(self) -> None: + self.m_session.add_mock_sql( + f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.mon_table_name} SUSPEND""", + result=mock_data_frame.MockDataFrame([Row(status="Success")]), + ) + self.m_session.add_mock_sql( + f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.acc_table_name} SUSPEND""", + result=mock_data_frame.MockDataFrame([Row(status="Success")]), + ) + self.monitor_sql_client.suspend_monitor_dynamic_tables(self.test_model_name, self.test_model_version_name) + self.m_session.finalize() + + def test_resume_monitor_dynamic_tables(self) -> None: + self.m_session.add_mock_sql( + f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.mon_table_name} RESUME""", + result=mock_data_frame.MockDataFrame([Row(status="Success")]), + ) + self.m_session.add_mock_sql( + f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.acc_table_name} RESUME""", + result=mock_data_frame.MockDataFrame([Row(status="Success")]), + ) + self.monitor_sql_client.resume_monitor_dynamic_tables(self.test_model_name, self.test_model_version_name) + self.m_session.finalize() + + def test_delete_monitor_metadata(self) -> None: + monitor = "TEST_MONITOR" + self.m_session.add_mock_sql( + query=f"DELETE FROM {self.test_db_name}.{self.test_schema_name}." 
+ f"{monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} WHERE " + f"{monitor_sql_client.MONITOR_NAME_COL_NAME} = '{monitor}'", + result=mock_data_frame.MockDataFrame([]), + ) + self.monitor_sql_client.delete_monitor_metadata(monitor) + + def test_delete_baseline_table(self) -> None: + model = "TEST_MODEL" + version = "TEST_VERSION" + table = monitor_sql_client._create_baseline_table_name(model, version) + self.m_session.add_mock_sql( + query=f"DROP TABLE IF EXISTS {self.test_db_name}.{self.test_schema_name}.{table}", + result=mock_data_frame.MockDataFrame([]), + ) + self.monitor_sql_client.delete_baseline_table(model, version) + + def test_delete_dynamic_tables(self) -> None: + model = "TEST_MODEL" + model_id = sql_identifier.SqlIdentifier(model) + fully_qualified_model = f"{self.test_db_name}.{self.test_schema_name}.{model}" + version = "TEST_VERSION" + version_id = sql_identifier.SqlIdentifier(version) + monitoring_table = self.monitor_sql_client.get_monitoring_table_fully_qualified_name(model_id, version_id) + accuracy_table = self.monitor_sql_client.get_accuracy_monitoring_table_fully_qualified_name( + model_id, version_id + ) + self.m_session.add_mock_sql( + query=f"DROP DYNAMIC TABLE IF EXISTS {monitoring_table}", + result=mock_data_frame.MockDataFrame([]), + ) + self.m_session.add_mock_sql( + query=f"DROP DYNAMIC TABLE IF EXISTS {accuracy_table}", + result=mock_data_frame.MockDataFrame([]), + ) + self.monitor_sql_client.delete_dynamic_tables(fully_qualified_model, version) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/monitoring/_client/queries/record_count.ssql b/snowflake/ml/monitoring/_client/queries/record_count.ssql new file mode 100644 index 00000000..4337cc6c --- /dev/null +++ b/snowflake/ml/monitoring/_client/queries/record_count.ssql @@ -0,0 +1,14 @@ +CREATE OR REPLACE FUNCTION ${MODEL_MONITOR_NAME}_PREDICTION_COUNT() + RETURNS TABLE(event_timestamp TIMESTAMP_NTZ, count FLOAT) + AS + $$$$ +SELECT + 
time_slice(timestamp, 1, 'hour') as "event_timestamp", + sum(get($OUTPUT_COLUMN_NAME,'count')) as count +from + $MONITORING_TABLE +group by + 1 +order by + 1 desc + $$$$; diff --git a/snowflake/ml/monitoring/_client/queries/rmse.ssql b/snowflake/ml/monitoring/_client/queries/rmse.ssql new file mode 100644 index 00000000..f922f4ac --- /dev/null +++ b/snowflake/ml/monitoring/_client/queries/rmse.ssql @@ -0,0 +1,28 @@ +CREATE OR REPLACE FUNCTION ${MODEL_MONITOR_NAME}_RMSE() + RETURNS TABLE(event_timestamp TIMESTAMP_NTZ, value FLOAT) + AS +$$$$ +WITH metric_of_interest as ( + select + time_slice(timestamp, 1, 'hour') as event_timestamp, + AGGREGATE_METRICS:"sum_difference_squares_label_pred" as aggregate_field, + AGGREGATE_METRICS:"count" as "count" + from + $MONITORING_PRED_LABEL_JOINED_TABLE +), metric_combine as ( + select + event_timestamp, + CAST(SUM(NVL(aggregate_field, 0)) as DOUBLE) as metric_sum, + SUM("count") as metric_count + from + metric_of_interest + where + cast(aggregate_field as varchar) not in ('inf','-inf','NaN') + group by + 1 +) select + event_timestamp, + SQRT(DIV0(metric_sum,metric_count)) as VALUE +from metric_combine +order by 1 desc +$$$$; diff --git a/snowflake/ml/monitoring/entities/BUILD.bazel b/snowflake/ml/monitoring/entities/BUILD.bazel new file mode 100644 index 00000000..77faa62c --- /dev/null +++ b/snowflake/ml/monitoring/entities/BUILD.bazel @@ -0,0 +1,37 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "entities_lib", + srcs = [ + "model_monitor_config.py", + "model_monitor_interval.py", + "output_score_type.py", + ], + deps = [ + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:type_hints", + ], +) + +py_test( + name = "output_score_type_test", + srcs = [ + "output_score_type_test.py", + ], + deps = [ + ":entities_lib", + ], +) + +py_test( + name = "model_monitor_interval_test", + srcs = [ + 
"model_monitor_interval_test.py", + ], + deps = [ + ":entities_lib", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/monitoring/entities/model_monitor_config.py b/snowflake/ml/monitoring/entities/model_monitor_config.py new file mode 100644 index 00000000..f4083d14 --- /dev/null +++ b/snowflake/ml/monitoring/entities/model_monitor_config.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import List + +from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.monitoring.entities import model_monitor_interval + + +@dataclass +class ModelMonitorTableConfig: + source_table: str + timestamp_column: str + prediction_columns: List[str] + label_columns: List[str] + id_columns: List[str] + + +@dataclass +class ModelMonitorConfig: + model_version: model_version_impl.ModelVersion + + # Python model function name + model_function_name: str + background_compute_warehouse_name: str + # TODO: Add support for pythonic notion of time. 
+ refresh_interval: str = model_monitor_interval.ModelMonitorRefreshInterval.DAILY + aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow = ( + model_monitor_interval.ModelMonitorAggregationWindow.WINDOW_1_DAY + ) diff --git a/snowflake/ml/monitoring/entities/model_monitor_interval.py b/snowflake/ml/monitoring/entities/model_monitor_interval.py new file mode 100644 index 00000000..f9ec1ddd --- /dev/null +++ b/snowflake/ml/monitoring/entities/model_monitor_interval.py @@ -0,0 +1,46 @@ +from enum import Enum + + +class ModelMonitorAggregationWindow(Enum): + WINDOW_1_HOUR = 60 + WINDOW_1_DAY = 24 * 60 + + def __init__(self, minutes: int) -> None: + super().__init__() + self.minutes = minutes + + +class ModelMonitorRefreshInterval: + EVERY_30_MINUTES = "30 minutes" + HOURLY = "1 hours" + EVERY_6_HOURS = "6 hours" + EVERY_12_HOURS = "12 hours" + DAILY = "1 days" + WEEKLY = "7 days" + BIWEEKLY = "14 days" + MONTHLY = "30 days" + + _ALLOWED_TIME_UNITS = {"minutes": 1, "hours": 60, "days": 24 * 60} + + def __init__(self, raw_time_str: str) -> None: + try: + num_units_raw, time_units = raw_time_str.strip().split(" ") + num_units = int(num_units_raw) # try to cast + except Exception as e: + raise ValueError( + f"""Failed to parse refresh interval with exception {e}. + Provide ' '. +See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info.""" + ) + if time_units.lower() not in self._ALLOWED_TIME_UNITS: + raise ValueError( + """Invalid time unit in refresh interval. Provide ' '. 
+See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info.""" + ) + minutes_multiplier = self._ALLOWED_TIME_UNITS[time_units.lower()] + self.minutes = num_units * minutes_multiplier + + def __eq__(self, value: object) -> bool: + if not isinstance(value, ModelMonitorRefreshInterval): + return False + return self.minutes == value.minutes diff --git a/snowflake/ml/monitoring/entities/model_monitor_interval_test.py b/snowflake/ml/monitoring/entities/model_monitor_interval_test.py new file mode 100644 index 00000000..e8a2f913 --- /dev/null +++ b/snowflake/ml/monitoring/entities/model_monitor_interval_test.py @@ -0,0 +1,41 @@ +from absl.testing import absltest + +from snowflake.ml.monitoring.entities import model_monitor_interval + + +class ModelMonitorIntervalTest(absltest.TestCase): + def setUp(self) -> None: + super().setUp() + + def test_validate_monitor_config(self) -> None: + with self.assertRaisesRegex(ValueError, "Failed to parse refresh interval with exception"): + model_monitor_interval.ModelMonitorRefreshInterval("UNINITIALIZED") + + with self.assertRaisesRegex(ValueError, "Invalid time unit in refresh interval."): + model_monitor_interval.ModelMonitorRefreshInterval("4 years") + + with self.assertRaisesRegex(ValueError, "Failed to parse refresh interval with exception."): + model_monitor_interval.ModelMonitorRefreshInterval("2.5 hours") + ri = model_monitor_interval.ModelMonitorRefreshInterval("1 hours") + self.assertEqual(ri.minutes, 60) + + def test_predefined_refresh_intervals(self) -> None: + min_30 = model_monitor_interval.ModelMonitorRefreshInterval.EVERY_30_MINUTES + hr_1 = model_monitor_interval.ModelMonitorRefreshInterval.HOURLY + hr_6 = model_monitor_interval.ModelMonitorRefreshInterval.EVERY_6_HOURS + day_1 = model_monitor_interval.ModelMonitorRefreshInterval.DAILY + day_7 = model_monitor_interval.ModelMonitorRefreshInterval.WEEKLY + day_14 = 
model_monitor_interval.ModelMonitorRefreshInterval.BIWEEKLY + day_30 = model_monitor_interval.ModelMonitorRefreshInterval.MONTHLY + + self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(min_30).minutes, 30) + self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(hr_1).minutes, 60) + self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(hr_6).minutes, 6 * 60) + self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_1).minutes, 24 * 60) + self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_7).minutes, 7 * 24 * 60) + self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_14).minutes, 14 * 24 * 60) + self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_30).minutes, 30 * 24 * 60) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/monitoring/entities/output_score_type.py b/snowflake/ml/monitoring/entities/output_score_type.py new file mode 100644 index 00000000..a34eca24 --- /dev/null +++ b/snowflake/ml/monitoring/entities/output_score_type.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from enum import Enum +from typing import List, Mapping + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import type_hints +from snowflake.snowpark import types + +# Accepted data types for each OutputScoreType. 
+REGRESSION_DATA_TYPES = ( + types.ByteType, + types.ShortType, + types.IntegerType, + types.LongType, + types.FloatType, + types.DoubleType, + types.DecimalType, +) +CLASSIFICATION_DATA_TYPES = ( + types.ByteType, + types.ShortType, + types.IntegerType, + types.BooleanType, + types.BinaryType, +) +PROBITS_DATA_TYPES = ( + types.ByteType, + types.ShortType, + types.IntegerType, + types.LongType, + types.FloatType, + types.DoubleType, + types.DecimalType, +) + + +# OutputScoreType enum +class OutputScoreType(Enum): + UNKNOWN = "UNKNOWN" + REGRESSION = "REGRESSION" + CLASSIFICATION = "CLASSIFICATION" + PROBITS = "PROBITS" + + @classmethod + def deduce_score_type( + cls, + table_schema: Mapping[str, types.DataType], + prediction_columns: List[sql_identifier.SqlIdentifier], + task: type_hints.Task, + ) -> OutputScoreType: + """Find the score type for monitoring given a table schema and the task. + + Args: + table_schema: Dictionary of column names and types in the source table. + prediction_columns: List of prediction columns. + task: Enum value for the task of the model. + + Returns: + Enum value for the score type, informing monitoring table set up. + + Raises: + ValueError: If prediction type fails to align with task. + """ + # Already validated we have just one prediction column type + prediction_column_type = {table_schema[column_name] for column_name in prediction_columns}.pop() + + if task == type_hints.Task.TABULAR_REGRESSION: + if isinstance(prediction_column_type, REGRESSION_DATA_TYPES): + return OutputScoreType.REGRESSION + else: + raise ValueError( + f"Expected prediction column type to be one of {REGRESSION_DATA_TYPES} " + f"for REGRESSION task. Found: {prediction_column_type}." 
+ ) + + elif task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION: + if isinstance(prediction_column_type, CLASSIFICATION_DATA_TYPES): + return OutputScoreType.CLASSIFICATION + elif isinstance(prediction_column_type, PROBITS_DATA_TYPES): + return OutputScoreType.PROBITS + else: + raise ValueError( + f"Expected prediction column type to be one of {CLASSIFICATION_DATA_TYPES} " + f"or one of {PROBITS_DATA_TYPES} for CLASSIFICATION task. " + f"Found: {prediction_column_type}." + ) + + else: + raise ValueError(f"Received unsupported task for model monitoring: {task}.") diff --git a/snowflake/ml/monitoring/entities/output_score_type_test.py b/snowflake/ml/monitoring/entities/output_score_type_test.py new file mode 100644 index 00000000..13f1c54e --- /dev/null +++ b/snowflake/ml/monitoring/entities/output_score_type_test.py @@ -0,0 +1,93 @@ +import re +from typing import List, Mapping, Tuple + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import type_hints +from snowflake.ml.monitoring.entities import output_score_type +from snowflake.snowpark import types + +DEDUCE_SCORE_TYPE_ACCEPTED_COMBINATIONS: List[ + Tuple[ + Mapping[str, types.DataType], + List[sql_identifier.SqlIdentifier], + type_hints.Task, + output_score_type.OutputScoreType, + ] +] = [ + ( + {"PREDICTION1": types.FloatType()}, + [sql_identifier.SqlIdentifier("PREDICTION1")], + type_hints.Task.TABULAR_REGRESSION, + output_score_type.OutputScoreType.REGRESSION, + ), + ( + {"PREDICTION1": types.DecimalType(38, 1)}, + [sql_identifier.SqlIdentifier("PREDICTION1")], + type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + output_score_type.OutputScoreType.PROBITS, + ), + ( + {"PREDICTION1": types.BinaryType()}, + [sql_identifier.SqlIdentifier("PREDICTION1")], + type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + output_score_type.OutputScoreType.CLASSIFICATION, + ), +] + + +DEDUCE_SCORE_TYPE_FAILURE_COMBINATIONS: List[ + Tuple[Mapping[str, 
types.DataType], List[sql_identifier.SqlIdentifier], type_hints.Task, str] +] = [ + ( + {"PREDICTION1": types.BinaryType()}, + [sql_identifier.SqlIdentifier("PREDICTION1")], + type_hints.Task.TABULAR_REGRESSION, + f"Expected prediction column type to be one of {output_score_type.REGRESSION_DATA_TYPES} " + f"for REGRESSION task. Found: {types.BinaryType()}.", + ), + ( + {"PREDICTION1": types.StringType()}, + [sql_identifier.SqlIdentifier("PREDICTION1")], + type_hints.Task.TABULAR_BINARY_CLASSIFICATION, + f"Expected prediction column type to be one of {output_score_type.CLASSIFICATION_DATA_TYPES} " + f"or one of {output_score_type.PROBITS_DATA_TYPES} for CLASSIFICATION task. " + f"Found: {types.StringType()}.", + ), + ( + {"PREDICTION1": types.BinaryType()}, + [sql_identifier.SqlIdentifier("PREDICTION1")], + type_hints.Task.UNKNOWN, + f"Received unsupported task for model monitoring: {type_hints.Task.UNKNOWN}.", + ), +] + + +class OutputScoreTypeTest(absltest.TestCase): + def test_deduce_score_type(self) -> None: + # Success cases + for ( + table_schema, + prediction_column_names, + task, + expected_score_type, + ) in DEDUCE_SCORE_TYPE_ACCEPTED_COMBINATIONS: + actual_score_type = output_score_type.OutputScoreType.deduce_score_type( + table_schema, prediction_column_names, task + ) + self.assertEqual(actual_score_type, expected_score_type) + + # Failure cases + for ( + table_schema, + prediction_column_names, + task, + expected_error, + ) in DEDUCE_SCORE_TYPE_FAILURE_COMBINATIONS: + with self.assertRaisesRegex(ValueError, re.escape(expected_error)): + output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_column_names, task) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/monitoring/monitor.py b/snowflake/ml/monitoring/monitor.py deleted file mode 100644 index be0569f4..00000000 --- a/snowflake/ml/monitoring/monitor.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Dict, Optional, Tuple - -from typing_extensions 
import TypedDict - -from snowflake import snowpark -from snowflake.ml._internal import telemetry -from snowflake.snowpark import functions - -_PROJECT = "MLOps" -_SUBPROJECT = "Monitor" - - -class BucketConfig(TypedDict): - """ "Options for bucketizing the data.""" - - min: int - max: int - size: int - - -@telemetry.send_api_usage_telemetry( - project=_PROJECT, - subproject=_SUBPROJECT, -) -@snowpark._internal.utils.private_preview(version="1.0.10") # TODO: update versions when release -def compare_udfs_outputs( - base_udf_name: str, - test_udf_name: str, - input_data_df: snowpark.DataFrame, - bucket_config: Optional[BucketConfig] = None, -) -> snowpark.DataFrame: - """Compare outputs of 2 UDFs. Outputs are bucketized the based on bucketConfig. - This is useful when someone retrain a Model and deploy as UDF to compare against earlier UDF as ground truth. - NOTE: Only supports UDFs with single Column output. - - Args: - base_udf_name: used as control ground truth UDF. - test_udf_name: output of this UDF is compared against that of `base_udf`. - input_data_df: Input data used for computing metric. - bucket_config: must have the kv as {"min":xx, "max":xx, "size"}, keys in lowercase; it's width_bucket - Sqloperator's config, using https://docs.snowflake.com/en/sql-reference/functions/width_bucket. - - Returns: - snowpark.DataFrame. 
- "BASEUDF" is base_udf's bucketized output, "TESTUDF" is test_udf's bucketized output, - """ - if bucket_config: - assert len(bucket_config) == 3 - assert "min" in bucket_config and "max" in bucket_config and "size" in bucket_config - - argStr = ",".join(input_data_df.columns) - query1Str = _get_udf_query_str("BASEUDF", f"{base_udf_name}({argStr})", input_data_df, bucket_config) - query2Str = _get_udf_query_str("TESTUDF", f"{test_udf_name}({argStr})", input_data_df, bucket_config) - - if bucket_config: - finalStr = ( - "select A.bucket, BASEUDF, TESTUDF \n from ({}) as A \n join ({}) as B \n on A.bucket=B.bucket".format( - query1Str, query2Str - ) - ) - else: # don't bucket at all - finalStr = "select {},{} \n from ({})".format(query1Str, query2Str, input_data_df.queries["queries"][0]) - - assert input_data_df._session is not None - return input_data_df._session.sql(finalStr) - - -@telemetry.send_api_usage_telemetry( - project=_PROJECT, - subproject=_SUBPROJECT, -) -@snowpark._internal.utils.private_preview(version="1.0.10") # TODO: update versions when release -def get_basic_stats(df: snowpark.DataFrame) -> Tuple[Dict[str, int], Dict[str, int]]: - """Get basic stats of 2 Columns - Note this isn't public API. 
Only support min, max, stddev, HLL--cardinality estimate - - Args: - df: input Snowpark Dataframe, must have 2 and only 2 columns - - Returns: - 2 Dict for 2 columns' stats - """ - projStr = "" - stats = ["MIN", "MAX", "STDDEV", "HLL"] - assert len(df.columns) == 2 - for colName in df.columns: - for stat in stats: - projStr += f"{stat}({colName}) as {colName}_{stat}," - finalStr = "select {} \n from ({})".format(projStr[:-1], df.queries["queries"][0]) - assert df._session is not None - resDf = df._session.sql(finalStr).to_pandas() - d1 = {} - col1 = df.columns[0] - d2 = {} - col2 = df.columns[1] - for stat in stats: - d1[stat] = resDf.iloc[0][f"{col1}_{stat}"] - d2[stat] = resDf.iloc[0][f"{col2}_{stat}"] - return d1, d2 - - -@telemetry.send_api_usage_telemetry( - project=_PROJECT, - subproject=_SUBPROJECT, -) -@snowpark._internal.utils.private_preview(version="1.0.10") # TODO: update versions when release -def jensenshannon(df1: snowpark.DataFrame, colname1: str, df2: snowpark.DataFrame, colname2: str) -> float: - """ - Similar to scipy implementation: - https://github.com/scipy/scipy/blob/e4dec2c5993faa381bb4f76dce551d0d79734f8f/scipy/spatial/distance.py#L1174 - It's server solution, all computing being in Snowflake warehouse, so will be significantly faster than client. - - Args: - df1: 1st Snowpark Dataframe; - colname1: the col to be selected in df1 - df2: 2nd Snowpark Dataframe; - colname2: the col to be selected in df2 - Supported data Tyte: any data type that Snowflake supports, including VARIANT, OBJECT...etc. 
- - Returns: - a jensenshannon value - """ - df1 = df1.select(colname1) - df1 = ( - df1.group_by(colname1) - .agg(functions.count(colname1).alias("c1")) - .select(functions.col(colname1).alias("d1"), "c1") - ) - df2 = df2.select(colname2) - df2 = ( - df2.group_by(colname2) - .agg(functions.count(colname2).alias("c2")) - .select(functions.col(colname2).alias("d2"), "c2") - ) - - dfsum = df1.select("c1").agg(functions.sum("c1").alias("SUM1")) - sum1 = dfsum.collect()[0].as_dict()["SUM1"] - dfsum = df2.select("c2").agg(functions.sum("c2").alias("SUM2")) - sum2 = dfsum.collect()[0].as_dict()["SUM2"] - - df1 = df1.select("d1", functions.sql_expr("c1 / " + str(sum1)).alias("p")) - minp = df1.select(functions.min("P").alias("MINP")).collect()[0].as_dict()["MINP"] - df2 = df2.select("d2", functions.sql_expr("c2 / " + str(sum2)).alias("q")) - minq = df2.select(functions.min("Q").alias("MINQ")).collect()[0].as_dict()["MINQ"] - - DECAY_FACTOR = 0.5 - df = df1.join(df2, df1.d1 == df2.d2, "fullouter").select( - "d1", - "d2", - functions.sql_expr( - """ - CASE - WHEN p is NULL THEN {}*{} - ELSE p - END - """.format( - minp, DECAY_FACTOR - ) - ).alias("p"), - functions.sql_expr( - """ - CASE - WHEN q is NULL THEN {}*{} - ELSE q - END - """.format( - minq, DECAY_FACTOR - ) - ).alias("q"), - ) - - df = df.select("p", "q", functions.sql_expr("(p+q)/2.0").alias("m")) - df = df.select( - functions.sql_expr( - """ - CASE - WHEN p > 0 AND m > 0 THEN p * LOG(2, p/m) - ELSE 0 - END - """ - ).alias("left"), - functions.sql_expr( - """ - CASE - WHEN q > 0 AND m > 0 THEN q * LOG(2, q/m) - ELSE 0 - END - """ - ).alias("right"), - ) - resdf = df.select(functions.sql_expr("sqrt((sum(left) + sum(right)) / 2.0)").alias("JS")) - return float(resdf.collect()[0].as_dict()["JS"]) - - -def _get_udf_query_str( - name: str, col: str, df: snowpark.DataFrame, bucket_config: Optional[BucketConfig] = None -) -> str: - if bucket_config: - return "select count(1) as {}, width_bucket({}, {}, {}, {}) bucket 
from ({}) group by bucket".format( - name, col, bucket_config["min"], bucket_config["max"], bucket_config["size"], df.queries["queries"][0] - ) - else: # don't bucket at all - return f"{col} as {name}" diff --git a/snowflake/ml/monitoring/pyproject.toml b/snowflake/ml/monitoring/pyproject.toml deleted file mode 100644 index e90e7152..00000000 --- a/snowflake/ml/monitoring/pyproject.toml +++ /dev/null @@ -1,50 +0,0 @@ -[build-system] -requires = ["setuptools >= 61.0"] -build-backend = "setuptools.build_meta" - -[project] -name = "snowflake-ml-python" -version = "0.1.0" -authors = [ - {name = "Snowflake, Inc", email = "support@snowflake.com"} -] -description = "The machine learning client library that is used for interacting with Snowflake to build machine learning solutions." -license = {file = "LICENSE.txt"} -classifiers = [ - "Development Status :: 3 - Alpha", - "Environment :: Console", - "Environment :: Other Environment", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Information Technology", - "Intended Audience :: System Administrators", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Topic :: Database", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Application Frameworks", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Scientific/Engineering :: Information Analysis" -] -requires-python = ">=3.8, <4" -dependencies = [ - "numpy", - "shap", - "snowflake-connector-python[pandas]", - "snowflake-snowpark-python>=1.4.0,<2" -] - -[project.urls] -Homepage = "https://github.com/snowflakedb/snowflake-ml-python" -Documentation = "https://docs.snowflake.com/developer-guide/snowpark-ml" -Repository = 
"https://github.com/snowflakedb/snowflake-ml-python" -Issues = "https://github.com/snowflakedb/snowflake-ml-python/issues" -Changelog = "https://github.com/snowflakedb/snowflake-ml-python/blob/master/CHANGELOG.md" - -[tool.setuptools.packages.find] -where = ["."] -include = ["snowflake.ml.monitoring*"] diff --git a/snowflake/ml/monitoring/tests/BUILD.bazel b/snowflake/ml/monitoring/tests/BUILD.bazel deleted file mode 100644 index 594aae84..00000000 --- a/snowflake/ml/monitoring/tests/BUILD.bazel +++ /dev/null @@ -1,20 +0,0 @@ -load("//bazel:py_rules.bzl", "py_test") - -package(default_visibility = [ - "//bazel:snowml_public_common", - "//snowflake/ml/monitoring", -]) - -SHARD_COUNT = 3 - -TIMEOUT = "long" # 900s - -py_test( - name = "monitor_test", - timeout = "long", - srcs = ["monitor_test.py"], - deps = [ - "//snowflake/ml/monitoring:monitoring_lib", - "//snowflake/ml/utils:connection_params", - ], -) diff --git a/snowflake/ml/monitoring/tests/monitor_test.py b/snowflake/ml/monitoring/tests/monitor_test.py deleted file mode 100644 index 35da9075..00000000 --- a/snowflake/ml/monitoring/tests/monitor_test.py +++ /dev/null @@ -1,181 +0,0 @@ -#!/usr/bin/env python3 -import math -from typing import Any, List - -import numpy as np -import shap -from absl.testing import absltest -from sklearn.ensemble import RandomForestClassifier - -from snowflake import snowpark -from snowflake.ml.monitoring import monitor -from snowflake.ml.monitoring.shap import ShapExplainer -from snowflake.ml.utils import connection_params -from snowflake.snowpark import functions, types - - -def rel_entropy(x: float, y: float) -> float: - if np.isnan(x) or np.isnan(y): - return np.NAN - elif x > 0 and y > 0: - return x * math.log2(x / y) - elif x == 0 and y >= 0: - return 0 - else: - return np.inf - - -# This is the official JS algorithm -def JS_helper(p1: List[float], q1: List[float]) -> Any: - p = np.asarray(p1) - q = np.asarray(q1) - m = (p + q) / 2.0 - tmp = np.column_stack((p, m)) - left = 
np.array([rel_entropy(x, y) for x, y in tmp]) - tmp = np.column_stack((q, m)) - right = np.array([rel_entropy(x, y) for x, y in tmp]) - left_sum = np.sum(left) - right_sum = np.sum(right) - js = left_sum + right_sum - return np.sqrt(js / 2.0) - - -class MonitorTest(absltest.TestCase): - """Test Covariance matrix.""" - - def setUp(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = snowpark.Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - def tearDown(self) -> None: - self._session.close() - - def test_compare_udfs(self) -> None: - inputDf = self._session.create_dataframe( - [ - snowpark.Row(-2, -5), - snowpark.Row(8, 7), - ], - schema=["COL1", "COL2"], - ) - self._session.udf.register( - lambda x, y: x + y, - return_type=snowpark.types.IntegerType(), - input_types=[snowpark.types.IntegerType(), snowpark.types.IntegerType()], - name="add1", - replace=True, - ) - self._session.udf.register( - lambda x, y: x + y + 1, - return_type=snowpark.types.IntegerType(), - input_types=[snowpark.types.IntegerType(), snowpark.types.IntegerType()], - name="add2", - replace=True, - ) - res = monitor.compare_udfs_outputs("add1", "add2", inputDf) - pdf = res.to_pandas() - assert pdf.iloc[0][0] == -7 and pdf.iloc[0][1] == -6 - - resBucketize = monitor.compare_udfs_outputs("add1", "add2", inputDf, {"min": 0, "max": 20, "size": 2}) - pdfBucketize = resBucketize.to_pandas() - assert pdfBucketize.iloc[0][1] == 1 and pdfBucketize.iloc[0][2] == 1 - - def test_get_basic_stats(self) -> None: - inputDf = self._session.create_dataframe( - [ - snowpark.Row(-2, -5), - snowpark.Row(8, 7), - snowpark.Row(100, 98), - ], - schema=["MODEL1", "MODEL2"], - ) - d1, d2 = monitor.get_basic_stats(inputDf) - assert d1["HLL"] == d2["HLL"] == 3 - assert d1["MIN"] == -2 and d2["MIN"] == -5 - assert d1["MAX"] == 100 and d2["MAX"] == 98 - - def test_jensenshannon(self) -> None: - df1 = self._session.create_dataframe( - [ - 
snowpark.Row(-3), - snowpark.Row(-2), - snowpark.Row(8), - snowpark.Row(100), - ], - schema=["col1"], - ) - - df2 = self._session.create_dataframe( - [ - snowpark.Row(-2), - snowpark.Row(8), - snowpark.Row(100), - snowpark.Row(140), - ], - schema=["col2"], - ) - - df3 = self._session.create_dataframe( - [ - snowpark.Row(-3), - snowpark.Row(-2), - snowpark.Row(8), - snowpark.Row(8), - snowpark.Row(8), - snowpark.Row(100), - ], - schema=["col1"], - ) - - js = monitor.jensenshannon(df1, "col1", df2, "col2") - assert abs(js - JS_helper([0.125, 0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25, 0.125])) <= 1e-5 - js = monitor.jensenshannon(df1, "col1", df3, "col1") - assert abs(js - JS_helper([0.25, 0.25, 0.25, 0.25], [1.0 / 6, 1.0 / 6, 0.5, 1.0 / 6])) <= 1e-5 - - def test_shap(self) -> None: - X_train = np.random.randint(1, 90, (4, 5)) - y_train = np.random.randint(0, 3, (4, 1)) - - clf = RandomForestClassifier(max_depth=3, random_state=0) - clf.fit(X_train, y_train) - - test_sample = np.array([[3, 2, 1, 4, 5]]) - - inputDf = self._session.create_dataframe( - [snowpark.Row(3, 2, 1, 4, 5)], - schema=["COL1", "COL2", "COL3", "COL4", "COL5"], - ) - - sf_explainer = ShapExplainer(self._session, clf.predict, X_train) - shapdf2 = sf_explainer.get_shap(inputDf) - shapdf2_1 = sf_explainer(inputDf) - assert shapdf2_1 is not None - v2 = shapdf2.collect()[0].as_dict(True)["SHAP"] - v2 = v2.replace("\n", "").strip("[] ").split(",") - - shap_explainer1 = shap.Explainer(clf.predict, X_train) - shap_values1 = shap_explainer1(test_sample) - - self._session.add_packages("numpy", "shap") - - def get_shap(input: list) -> list: # type: ignore[type-arg] - shap_explainer = shap.Explainer(clf.predict, X_train) - shap_values = shap_explainer(np.array([input])) - return shap_values.values.tolist() # type: ignore[no-any-return] - - shapudf = self._session.udf.register(get_shap, input_types=[types.ArrayType()], return_type=types.ArrayType()) - - shapdf1 = inputDf.select( - 
functions.array_construct("COL1", "COL2", "COL3", "COL4", "COL5").alias("INPUT") - ).select(functions.get(shapudf("INPUT"), 0).alias("SHAP")) - v1 = shapdf1.collect()[0].as_dict(True)["SHAP"] - v1 = v1.replace("\n", "").strip("[] ").split(",") - - assert abs(float(v1[0]) - shap_values1.values[0][0]) <= 1e-5 - assert abs(float(v1[1]) - shap_values1.values[0][1]) <= 1e-5 - assert abs(float(v2[0]) - shap_values1.values[0][0]) <= 1e-5 - assert abs(float(v2[1]) - shap_values1.values[0][1]) <= 1e-5 - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/registry/BUILD.bazel b/snowflake/ml/registry/BUILD.bazel index 08b6eb7e..32d573b6 100644 --- a/snowflake/ml/registry/BUILD.bazel +++ b/snowflake/ml/registry/BUILD.bazel @@ -2,55 +2,6 @@ load("//bazel:py_rules.bzl", "py_library", "py_package", "py_test") package(default_visibility = ["//visibility:public"]) -py_library( - name = "model_registry", - srcs = [ - "model_registry.py", - ], - deps = [ - ":schema", - "//snowflake/ml/_internal:telemetry", - "//snowflake/ml/_internal/utils:formatting", - "//snowflake/ml/_internal/utils:identifier", - "//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/_internal/utils:spcs_attribution_utils", - "//snowflake/ml/_internal/utils:table_manager", - "//snowflake/ml/_internal/utils:uri", - "//snowflake/ml/dataset", - "//snowflake/ml/model:_api", - "//snowflake/ml/model:deploy_platforms", - "//snowflake/ml/model:model_signature", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/modeling/framework", - ], -) - -py_test( - name = "model_registry_test", - srcs = ["model_registry_test.py"], - deps = [ - ":model_registry", - ":schema", - "//snowflake/ml/test_utils:mock_data_frame", - "//snowflake/ml/test_utils:mock_session", - ], -) - -py_library( - name = "schema", - srcs = [ - "_initial_schema.py", - "_schema.py", - "_schema_upgrade_plans.py", - "_schema_version_manager.py", - ], - visibility = ["//bazel:snowml_public_common"], - deps = [ - 
"//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/_internal/utils:table_manager", - ], -) - py_library( name = "registry_impl", srcs = [ @@ -63,6 +14,7 @@ py_library( "//snowflake/ml/model", "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", + "//snowflake/ml/monitoring:model_monitor_impl", "//snowflake/ml/registry/_manager:model_manager", ], ) @@ -73,9 +25,7 @@ py_library( "__init__.py", ], deps = [ - ":model_registry", ":registry_impl", - ":schema", ], ) @@ -96,7 +46,6 @@ py_test( name = "package_visibility_test", srcs = ["package_visibility_test.py"], deps = [ - ":model_registry", ":registry", ], ) @@ -105,7 +54,6 @@ py_package( name = "model_registry_pkg", packages = ["snowflake.ml"], deps = [ - ":model_registry", ":registry", ], ) diff --git a/snowflake/ml/registry/_initial_schema.py b/snowflake/ml/registry/_initial_schema.py deleted file mode 100644 index 10ba648e..00000000 --- a/snowflake/ml/registry/_initial_schema.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Any, Dict, List, Tuple - -from snowflake import snowpark -from snowflake.ml._internal.utils import identifier, query_result_checker, table_manager - -# THIS FILE CONTAINS INITIAL REGISTRY SCHEMA. -# !!!!!!! WARNING !!!!!!! -# Please do not modify initial schema and use schema evolution mechanism in SchemaVersionManager to change the schema. -# If you are touching this file, make sure you understand what you are doing. 
- -_INITIAL_VERSION: int = 0 - -_MODELS_TABLE_NAME: str = "_SYSTEM_REGISTRY_MODELS" -_METADATA_TABLE_NAME: str = "_SYSTEM_REGISTRY_METADATA" -_DEPLOYMENT_TABLE_NAME: str = "_SYSTEM_REGISTRY_DEPLOYMENTS" -_ARTIFACT_TABLE_NAME: str = "_SYSTEM_REGISTRY_ARTIFACTS" - -_INITIAL_REGISTRY_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("CREATION_CONTEXT", "VARCHAR"), - ("CREATION_ENVIRONMENT_SPEC", "OBJECT"), - ("CREATION_ROLE", "VARCHAR"), - ("CREATION_TIME", "TIMESTAMP_TZ"), - ("ID", "VARCHAR PRIMARY KEY RELY"), - ("INPUT_SPEC", "OBJECT"), - ("NAME", "VARCHAR"), - ("OUTPUT_SPEC", "OBJECT"), - ("RUNTIME_ENVIRONMENT_SPEC", "OBJECT"), - ("TRAINING_DATASET_ID", "VARCHAR"), - ("TYPE", "VARCHAR"), - ("URI", "VARCHAR"), - ("VERSION", "VARCHAR"), -] - -_INITIAL_METADATA_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("ATTRIBUTE_NAME", "VARCHAR"), - ("EVENT_ID", "VARCHAR UNIQUE NOT NULL"), - ("EVENT_TIMESTAMP", "TIMESTAMP_TZ"), - ("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"), - ("OPERATION", "VARCHAR"), - ("ROLE", "VARCHAR"), - ("SEQUENCE_ID", "BIGINT AUTOINCREMENT START 0 INCREMENT 1 PRIMARY KEY"), - ("VALUE", "OBJECT"), -] - -_INITIAL_DEPLOYMENTS_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("CREATION_TIME", "TIMESTAMP_TZ"), - ("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"), - ("DEPLOYMENT_NAME", "VARCHAR"), - ("OPTIONS", "VARIANT"), - ("TARGET_PLATFORM", "VARCHAR"), - ("ROLE", "VARCHAR"), - ("STAGE_PATH", "VARCHAR"), - ("SIGNATURE", "VARIANT"), - ("TARGET_METHOD", "VARCHAR"), -] - -_INITIAL_ARTIFACT_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("ID", "VARCHAR"), - ("TYPE", "VARCHAR"), - ("NAME", "VARCHAR"), - ("VERSION", "VARCHAR"), - ("CREATION_ROLE", "VARCHAR"), - ("CREATION_TIME", "TIMESTAMP_TZ"), - ("ARTIFACT_SPEC", "OBJECT"), - # Below is out-of-line constraints of Snowflake table. 
- # See https://docs.snowflake.com/en/sql-reference/sql/create-table - ("PRIMARY KEY", "(ID, TYPE) RELY"), -] - -_INITIAL_TABLE_SCHEMAS = { - _MODELS_TABLE_NAME: _INITIAL_REGISTRY_TABLE_SCHEMA, - _METADATA_TABLE_NAME: _INITIAL_METADATA_TABLE_SCHEMA, - _DEPLOYMENT_TABLE_NAME: _INITIAL_DEPLOYMENTS_TABLE_SCHEMA, - _ARTIFACT_TABLE_NAME: _INITIAL_ARTIFACT_TABLE_SCHEMA, -} - - -def create_initial_registry_tables( - session: snowpark.Session, - database_name: str, - schema_name: str, - statement_params: Dict[str, Any], -) -> None: - """Creates initial set of tables for registry. This is the legacy schema from which schema evolution is supported. - - Args: - session: Active session to create tables. - database_name: Name of database in which tables will be created. - schema_name: Name of schema in which tables will be created. - statement_params: Statement parameters for telemetry tracking. - """ - model_table_full_path = table_manager.get_fully_qualified_table_name(database_name, schema_name, _MODELS_TABLE_NAME) - - for table_name, schema_template in _INITIAL_TABLE_SCHEMAS.items(): - table_schema = [(k, v.format(registry_table_name=model_table_full_path)) for k, v in schema_template] - table_manager.create_single_table( - session=session, - database_name=database_name, - schema_name=schema_name, - table_name=table_name, - table_schema=table_schema, - statement_params=statement_params, - ) - - -def check_access(session: snowpark.Session, database_name: str, schema_name: str) -> None: - """Check that the required tables exist and are accessible by the current role. - - Args: - session: Active session to execution SQL queries. - database_name: Name of database where schema tables live. - schema_name: Name of schema where schema tables live. 
- """ - query_result_checker.SqlResultValidator( - session, - query=f"SHOW DATABASES LIKE '{identifier.get_unescaped_names(database_name)}'", - ).has_dimensions(expected_rows=1).validate() - - query_result_checker.SqlResultValidator( - session, - query=f"SHOW SCHEMAS LIKE '{identifier.get_unescaped_names(schema_name)}' IN DATABASE {database_name}", - ).has_dimensions(expected_rows=1).validate() - - full_qualified_schema_name = table_manager.get_fully_qualified_schema_name(database_name, schema_name) - - table_manager.validate_table_exist( - session, - identifier.get_unescaped_names(_MODELS_TABLE_NAME), - full_qualified_schema_name, - ) - table_manager.validate_table_exist( - session, - identifier.get_unescaped_names(_METADATA_TABLE_NAME), - full_qualified_schema_name, - ) - table_manager.validate_table_exist( - session, - identifier.get_unescaped_names(_DEPLOYMENT_TABLE_NAME), - full_qualified_schema_name, - ) - - # TODO(zzhu): Also check validity of views. diff --git a/snowflake/ml/registry/_manager/model_manager.py b/snowflake/ml/registry/_manager/model_manager.py index ae34c95c..b7956d43 100644 --- a/snowflake/ml/registry/_manager/model_manager.py +++ b/snowflake/ml/registry/_manager/model_manager.py @@ -50,7 +50,7 @@ def log_model( sample_input_data: Optional[model_types.SupportedDataType] = None, code_paths: Optional[List[str]] = None, ext_modules: Optional[List[ModuleType]] = None, - model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, + task: model_types.Task = model_types.Task.UNKNOWN, options: Optional[model_types.ModelSaveOption] = None, statement_params: Optional[Dict[str, Any]] = None, ) -> model_version_impl.ModelVersion: @@ -90,7 +90,7 @@ def log_model( sample_input_data=sample_input_data, code_paths=code_paths, ext_modules=ext_modules, - model_objective=model_objective, + task=task, options=options, statement_params=statement_params, ) @@ -110,7 +110,7 @@ def _log_model( sample_input_data: 
Optional[model_types.SupportedDataType] = None, code_paths: Optional[List[str]] = None, ext_modules: Optional[List[ModuleType]] = None, - model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, + task: model_types.Task = model_types.Task.UNKNOWN, options: Optional[model_types.ModelSaveOption] = None, statement_params: Optional[Dict[str, Any]] = None, ) -> model_version_impl.ModelVersion: @@ -159,7 +159,7 @@ def _log_model( code_paths=code_paths, ext_modules=ext_modules, options=options, - model_objective=model_objective, + task=task, ) statement_params = telemetry.add_statement_params_custom_tags( statement_params, model_metadata.telemetry_metadata() diff --git a/snowflake/ml/registry/_manager/model_manager_test.py b/snowflake/ml/registry/_manager/model_manager_test.py index 2a29985e..393bb06e 100644 --- a/snowflake/ml/registry/_manager/model_manager_test.py +++ b/snowflake/ml/registry/_manager/model_manager_test.py @@ -212,7 +212,7 @@ def test_log_model_minimal(self) -> None: code_paths=None, ext_modules=None, options=None, - model_objective=type_hints.ModelObjective.UNKNOWN, + task=type_hints.Task.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -281,7 +281,7 @@ def test_log_model_1(self) -> None: code_paths=None, ext_modules=None, options=None, - model_objective=type_hints.ModelObjective.UNKNOWN, + task=type_hints.Task.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -335,7 +335,7 @@ def test_log_model_2(self) -> None: code_paths=None, ext_modules=None, options=m_options, - model_objective=type_hints.ModelObjective.UNKNOWN, + task=type_hints.Task.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -392,7 +392,7 @@ def test_log_model_3(self) -> None: code_paths=m_code_paths, ext_modules=m_ext_modules, options=None, - model_objective=type_hints.ModelObjective.UNKNOWN, + task=type_hints.Task.UNKNOWN, ) 
mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -449,7 +449,7 @@ def test_log_model_4(self) -> None: code_paths=None, ext_modules=None, options=None, - model_objective=type_hints.ModelObjective.UNKNOWN, + task=type_hints.Task.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -550,7 +550,7 @@ def test_log_model_fully_qualified(self) -> None: code_paths=None, ext_modules=None, options=None, - model_objective=type_hints.ModelObjective.UNKNOWN, + task=type_hints.Task.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, diff --git a/snowflake/ml/registry/_schema.py b/snowflake/ml/registry/_schema.py deleted file mode 100644 index 9b5e1609..00000000 --- a/snowflake/ml/registry/_schema.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Dict, List, Tuple, Type - -from snowflake.ml.registry import _initial_schema, _schema_upgrade_plans - -# BUMP THIS VERSION WHENEVER YOU CHANGE ANY SCHEMA TABLES. -# ALSO UPDATE SCHEMA UPGRADE PLANS. 
-_CURRENT_SCHEMA_VERSION = 3 - -_REGISTRY_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("CREATION_CONTEXT", "VARCHAR"), - ("CREATION_ENVIRONMENT_SPEC", "OBJECT"), - ("CREATION_ROLE", "VARCHAR"), - ("CREATION_TIME", "TIMESTAMP_TZ"), - ("ID", "VARCHAR PRIMARY KEY RELY"), - ("INPUT_SPEC", "OBJECT"), - ("NAME", "VARCHAR"), - ("OUTPUT_SPEC", "OBJECT"), - ("RUNTIME_ENVIRONMENT_SPEC", "OBJECT"), - ("ARTIFACT_IDS", "ARRAY"), - ("TYPE", "VARCHAR"), - ("URI", "VARCHAR"), - ("VERSION", "VARCHAR"), -] - -_METADATA_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("ATTRIBUTE_NAME", "VARCHAR"), - ("EVENT_ID", "VARCHAR UNIQUE NOT NULL"), - ("EVENT_TIMESTAMP", "TIMESTAMP_TZ"), - ("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"), - ("OPERATION", "VARCHAR"), - ("ROLE", "VARCHAR"), - ("SEQUENCE_ID", "BIGINT AUTOINCREMENT START 0 INCREMENT 1 PRIMARY KEY"), - ("VALUE", "OBJECT"), -] - -_DEPLOYMENTS_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("CREATION_TIME", "TIMESTAMP_TZ"), - ("MODEL_ID", "VARCHAR FOREIGN KEY REFERENCES {registry_table_name}(ID) RELY"), - ("DEPLOYMENT_NAME", "VARCHAR"), - ("OPTIONS", "VARIANT"), - ("TARGET_PLATFORM", "VARCHAR"), - ("ROLE", "VARCHAR"), - ("STAGE_PATH", "VARCHAR"), - ("SIGNATURE", "VARIANT"), - ("TARGET_METHOD", "VARCHAR"), -] - -_ARTIFACT_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("ID", "VARCHAR"), - ("TYPE", "VARCHAR"), - ("NAME", "VARCHAR"), - ("VERSION", "VARCHAR"), - ("CREATION_ROLE", "VARCHAR"), - ("CREATION_TIME", "TIMESTAMP_TZ"), - ("ARTIFACT_SPEC", "VARCHAR"), - # Below is out-of-line constraints of Snowflake table. - # See https://docs.snowflake.com/en/sql-reference/sql/create-table - ("PRIMARY KEY", "(ID, TYPE) RELY"), -] - -# Note, one can add/remove tables from this tuple as well. As long as correct schema update process is followed. -# In case of a new table, they should not be defined in _initial_schema. 
-_CURRENT_TABLE_SCHEMAS = { - _initial_schema._MODELS_TABLE_NAME: _REGISTRY_TABLE_SCHEMA, - _initial_schema._METADATA_TABLE_NAME: _METADATA_TABLE_SCHEMA, - _initial_schema._DEPLOYMENT_TABLE_NAME: _DEPLOYMENTS_TABLE_SCHEMA, - _initial_schema._ARTIFACT_TABLE_NAME: _ARTIFACT_TABLE_SCHEMA, -} - - -_SCHEMA_UPGRADE_PLANS: Dict[int, Type[_schema_upgrade_plans.BaseSchemaUpgradePlans]] = { - # Currently _CURRENT_SCHEMA_VERSION == _initial_schema._INITIAL_VERSION, so no entry. - # But if schema evolves it must contain: - # Key = a version number - # Value = a subclass of _schema_upgrades.BaseSchemaUpgrade - # NOTE, all version from _INITIAL_VERSION + 1 till _CURRENT_SCHEMA_VERSION must exists. - 1: _schema_upgrade_plans.AddTrainingDatasetIdIfNotExists, - 2: _schema_upgrade_plans.ReplaceTrainingDatasetIdWithArtifactIds, - 3: _schema_upgrade_plans.ChangeArtifactSpecFromObjectToVarchar, -} - -assert len(_SCHEMA_UPGRADE_PLANS) == _CURRENT_SCHEMA_VERSION - _initial_schema._INITIAL_VERSION diff --git a/snowflake/ml/registry/_schema_upgrade_plans.py b/snowflake/ml/registry/_schema_upgrade_plans.py deleted file mode 100644 index fa79e539..00000000 --- a/snowflake/ml/registry/_schema_upgrade_plans.py +++ /dev/null @@ -1,116 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional - -from snowflake import snowpark -from snowflake.ml._internal.utils import table_manager -from snowflake.ml.registry import _initial_schema - - -class BaseSchemaUpgradePlans(ABC): - """Abstract Class for specifying schema upgrades for registry.""" - - def __init__( - self, - session: snowpark.Session, - database_name: str, - schema_name: str, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - self._session = session - self._database = database_name - self._schema = schema_name - self._statement_params = statement_params - - @abstractmethod - def upgrade(self) -> None: - """Convert schema from previous version to `_current_version`.""" - pass - - -class 
AddTrainingDatasetIdIfNotExists(BaseSchemaUpgradePlans): - """Add Column TRAINING_DATASET_ID in registry schema table.""" - - def __init__( - self, - session: snowpark.Session, - database_name: str, - schema_name: str, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - super().__init__(session, database_name, schema_name, statement_params) - - def upgrade(self) -> None: - full_schema_path = f"{self._database}.{self._schema}" - table_schema_dict = table_manager.get_table_schema( - self._session, _initial_schema._MODELS_TABLE_NAME, full_schema_path - ) - new_column = "TRAINING_DATASET_ID" - if new_column not in table_schema_dict: - self._session.sql( - f"""ALTER TABLE {self._database}.{self._schema}.{_initial_schema._MODELS_TABLE_NAME} - ADD COLUMN {new_column} VARCHAR - """ - ).collect(statement_params=self._statement_params) - - -class ReplaceTrainingDatasetIdWithArtifactIds(BaseSchemaUpgradePlans): - """Drop column `TRAINING_DATASET_ID`, add `ARTIFACT_IDS`.""" - - def __init__( - self, - session: snowpark.Session, - database_name: str, - schema_name: str, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - super().__init__(session, database_name, schema_name, statement_params) - - def upgrade(self) -> None: - full_schema_path = f"{self._database}.{self._schema}" - old_column = "TRAINING_DATASET_ID" - self._session.sql( - f"""ALTER TABLE {full_schema_path}.{_initial_schema._MODELS_TABLE_NAME} - DROP COLUMN {old_column} - """ - ).collect(statement_params=self._statement_params) - - new_column = "ARTIFACT_IDS" - self._session.sql( - f"""ALTER TABLE {full_schema_path}.{_initial_schema._MODELS_TABLE_NAME} - ADD COLUMN {new_column} ARRAY - """ - ).collect(statement_params=self._statement_params) - - -class ChangeArtifactSpecFromObjectToVarchar(BaseSchemaUpgradePlans): - """Change artifact spec type from object to varchar. 
It's fine to drop the column as it's empty.""" - - def __init__( - self, - session: snowpark.Session, - database_name: str, - schema_name: str, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - super().__init__(session, database_name, schema_name, statement_params) - - def upgrade(self) -> None: - full_schema_path = f"{self._database}.{self._schema}" - update_col = "ARTIFACT_SPEC" - self._session.sql( - f"""ALTER TABLE {full_schema_path}.{_initial_schema._ARTIFACT_TABLE_NAME} - DROP COLUMN {update_col} - """ - ).collect(statement_params=self._statement_params) - - self._session.sql( - f"""ALTER TABLE {full_schema_path}.{_initial_schema._ARTIFACT_TABLE_NAME} - ADD COLUMN {update_col} VARCHAR - """ - ).collect(statement_params=self._statement_params) - - self._session.sql( - f"""COMMENT ON COLUMN {full_schema_path}.{_initial_schema._ARTIFACT_TABLE_NAME}.{update_col} IS - 'This column is VARCHAR but supposed to store a valid JSON object' - """ - ).collect(statement_params=self._statement_params) diff --git a/snowflake/ml/registry/_schema_version_manager.py b/snowflake/ml/registry/_schema_version_manager.py deleted file mode 100644 index 7580b3d6..00000000 --- a/snowflake/ml/registry/_schema_version_manager.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -from snowflake import snowpark -from snowflake.ml._internal.utils import identifier, query_result_checker, table_manager -from snowflake.ml.registry import _initial_schema, _schema - -_SCHEMA_VERSION_TABLE_NAME: str = "_SYSTEM_REGISTRY_SCHEMA_VERSION" - -_SCHEMA_VERSION_TABLE_SCHEMA: List[Tuple[str, str]] = [ - ("VERSION", "NUMBER"), - ("CREATION_TIME", "TIMESTAMP_TZ"), - ("INFO", "OBJECT"), -] - - -class SchemaVersionManager: - """Registry Schema Version Manager to deal with schema evolution.""" - - def __init__(self, session: snowpark.Session, database: str, schema: str) -> None: - """SchemaVersionManager constructor. 
- - Args: - session: Snowpark session - database: Database in which registry is being managed. - schema: Schema in which registry is being managed. - """ - self._session = session - self._database = database - self._schema = schema - - def get_deployed_version(self, statement_params: Optional[Dict[str, Any]] = None) -> int: - """Get current version of deployed schema. - - Args: - statement_params: Statement parameters for telemetry tracking. - - Returns: - Version of deployed schema. - """ - if not table_manager.validate_table_exist( - self._session, _SCHEMA_VERSION_TABLE_NAME, self._get_qualified_schema() - ): - return _initial_schema._INITIAL_VERSION - - result = ( - query_result_checker.SqlResultValidator( - session=self._session, - query=f"""SELECT MAX(VERSION) AS MAX_VERSION - FROM {self._get_qualified_schema()}.{_SCHEMA_VERSION_TABLE_NAME} - """, - statement_params=statement_params, - ) - .has_dimensions(expected_rows=1, expected_cols=1) - .has_column("MAX_VERSION") - .validate() - ) - cur_version = result[0]["MAX_VERSION"] - return int(cur_version) - - def validate_schema_version(self, statement_params: Optional[Dict[str, Any]] = None) -> None: - """Checks if currently deployed schema is up to date. - - Args: - statement_params: Statement parameters for telemetry tracking. - - Raises: - RuntimeError: if deployed schema different from registry schema. - """ - deployed_version = self.get_deployed_version(statement_params) - if deployed_version > _schema._CURRENT_SCHEMA_VERSION: - raise RuntimeError( - f"Deployed registry schema version ({deployed_version}) is ahead of current " - f"package ({_schema._CURRENT_SCHEMA_VERSION}). Please update the package." - ) - elif deployed_version < _schema._CURRENT_SCHEMA_VERSION: - raise RuntimeError( - f"Registry schema version ({_schema._CURRENT_SCHEMA_VERSION}) is ahead of deployed " - f"schema ({deployed_version}). Please call create_model_registry() to upgrade." 
- ) - - def try_upgrade(self, statement_params: Optional[Dict[str, Any]] = None) -> None: - """Upgrade deployed schema to current. - - Args: - statement_params: Statement parameters for telemetry tracking. - - Raises: - RuntimeError: Deployed schema is newer than package. - """ - deployed_version = self.get_deployed_version(statement_params) - if deployed_version > _schema._CURRENT_SCHEMA_VERSION: - raise RuntimeError( - f"Deployed registry schema version ({deployed_version}) is ahead of current " - f"package ({_schema._CURRENT_SCHEMA_VERSION}). Please update the package." - ) - - any_upgraded = False - for cur_version in range(deployed_version, _schema._CURRENT_SCHEMA_VERSION): - assert cur_version + 1 in _schema._SCHEMA_UPGRADE_PLANS, "version number not exist." - plan = _schema._SCHEMA_UPGRADE_PLANS[cur_version + 1]( - self._session, self._database, self._schema, statement_params - ) - plan.upgrade() - any_upgraded = True - - self._validate_schema() - - if any_upgraded: - self._create_or_update_version_table(statement_params) - - def _create_or_update_version_table(self, statement_params: Optional[Dict[str, Any]] = None) -> None: - if not table_manager.validate_table_exist( - self._session, _SCHEMA_VERSION_TABLE_NAME, self._get_qualified_schema() - ): - table_manager.create_single_table( - session=self._session, - database_name=self._database, - schema_name=self._schema, - table_name=_SCHEMA_VERSION_TABLE_NAME, - table_schema=_SCHEMA_VERSION_TABLE_SCHEMA, - statement_params=statement_params, - ) - query_result_checker.SqlResultValidator( - session=self._session, - query=f"""INSERT INTO {self._get_qualified_schema_version_table()} (VERSION, CREATION_TIME) - VALUES ({_schema._CURRENT_SCHEMA_VERSION}, CURRENT_TIMESTAMP()) - """, - statement_params=statement_params, - ).insertion_success(expected_num_rows=1).validate() - - def _validate_schema(self) -> None: - for table_name in _initial_schema._INITIAL_TABLE_SCHEMAS: - if table_name not in 
_schema._CURRENT_TABLE_SCHEMAS: - # This table must be deleted by transformations. - if table_manager.validate_table_exist( - self._session, - table_name, - table_manager.get_fully_qualified_schema_name(self._database, self._schema), - ): - raise RuntimeError( - f"Schema transformation error. A table '{table_name}' found, which should not exist." - ) - - exclude_cols = ["PRIMARY KEY"] - for table_name, expected_schema in _schema._CURRENT_TABLE_SCHEMAS.items(): - deployed_schema_dict = table_manager.get_table_schema( - self._session, table_name, self._get_qualified_schema() - ) - - # TODO check type as well. - for col_name, _ in expected_schema: - if col_name not in deployed_schema_dict and col_name not in exclude_cols: - raise RuntimeError(f"Schema table: {table_name} doesn't have required column:'{col_name}'.") - - def _get_qualified_schema(self) -> str: - return table_manager.get_fully_qualified_schema_name(self._database, self._schema) - - def _get_qualified_schema_version_table(self) -> str: - return table_manager.get_fully_qualified_table_name( - self._database, - self._schema, - identifier.get_inferred_name(_SCHEMA_VERSION_TABLE_NAME), - ) diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py deleted file mode 100644 index e5b04df8..00000000 --- a/snowflake/ml/registry/model_registry.py +++ /dev/null @@ -1,2048 +0,0 @@ -import inspect -import json -import sys -import textwrap -import types -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Union, - cast, -) -from uuid import uuid1 - -from absl import logging - -from snowflake import connector, snowpark -from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import ( - formatting, - identifier, - query_result_checker, - spcs_attribution_utils, - table_manager, - uri, -) -from snowflake.ml.model import ( - _api as model_api, - deploy_platforms, - model_signature, - type_hints as 
model_types, -) -from snowflake.ml.registry import _initial_schema, _schema_version_manager -from snowflake.snowpark._internal import utils as snowpark_utils - -if TYPE_CHECKING: - import pandas as pd - -_DEFAULT_REGISTRY_NAME: str = "_SYSTEM_MODEL_REGISTRY" -_DEFAULT_SCHEMA_NAME: str = "_SYSTEM_MODEL_REGISTRY_SCHEMA" -_MODELS_TABLE_NAME: str = "_SYSTEM_REGISTRY_MODELS" -_METADATA_TABLE_NAME: str = "_SYSTEM_REGISTRY_METADATA" -_DEPLOYMENT_TABLE_NAME: str = "_SYSTEM_REGISTRY_DEPLOYMENTS" - -# Metadata operation types. -_SET_METADATA_OPERATION: str = "SET" -_ADD_METADATA_OPERATION: str = "ADD" -_DROP_METADATA_OPERATION: str = "DROP" - -# Metadata types. -_METADATA_ATTRIBUTE_DESCRIPTION: str = "DESCRIPTION" -_METADATA_ATTRIBUTE_METRICS: str = "METRICS" -_METADATA_ATTRIBUTE_REGISTRATION: str = "REGISTRATION" -_METADATA_ATTRIBUTE_TAGS: str = "TAGS" -_METADATA_ATTRIBUTE_DEPLOYMENT: str = "DEPLOYMENTS" -_METADATA_ATTRIBUTE_DELETION: str = "DELETION" - -# Leaving out REGISTRATION/DEPLOYMENT events as they will be handled differently from all mutable attributes. -_LIST_METADATA_ATTRIBUTE: List[str] = [ - _METADATA_ATTRIBUTE_DESCRIPTION, - _METADATA_ATTRIBUTE_METRICS, - _METADATA_ATTRIBUTE_TAGS, -] -_TELEMETRY_PROJECT = "MLOps" -_TELEMETRY_SUBPROJECT = "ModelRegistry" - -_STAGE_PREFIX = "@" - - -def _create_registry_database( - session: snowpark.Session, - database_name: str, - statement_params: Dict[str, Any], -) -> None: - """Private helper to create the model registry database. - - The creation will be skipped if the target database already exists. - - Args: - session: Session object to communicate with Snowflake. - database_name: Desired name of the model registry database. - statement_params: Function usage statement parameters used in sql query executions. 
- """ - registry_databases = session.sql(f"SHOW DATABASES LIKE '{identifier.get_unescaped_names(database_name)}'").collect( - statement_params=statement_params - ) - if len(registry_databases) > 0: - logging.warning(f"The database {database_name} already exists. Skipping creation.") - return - - session.sql(f"CREATE DATABASE {database_name}").collect(statement_params=statement_params) - - -def _create_registry_schema( - session: snowpark.Session, - database_name: str, - schema_name: str, - statement_params: Dict[str, Any], -) -> None: - """Private helper to create the model registry schema. - - The creation will be skipped if the target schema already exists. - - Args: - session: Session object to communicate with Snowflake. - database_name: Desired name of the model registry database. - schema_name: Desired name of the schema used by this model registry inside the database. - statement_params: Function usage statement parameters used in sql query executions. - """ - # The default PUBLIC schema is created by default so it might already exist even in a new database. - registry_schemas = session.sql( - f"SHOW SCHEMAS LIKE '{identifier.get_unescaped_names(schema_name)}' IN DATABASE {database_name}" - ).collect(statement_params=statement_params) - - if len(registry_schemas) > 0: - logging.warning( - f"The schema {table_manager.get_fully_qualified_schema_name(database_name, schema_name)} already exists. " - + "Skipping creation." - ) - return - - session.sql(f"CREATE SCHEMA {table_manager.get_fully_qualified_schema_name(database_name, schema_name)}").collect( - statement_params=statement_params - ) - - -def _create_registry_views( - session: snowpark.Session, - database_name: str, - schema_name: str, - registry_table_name: str, - metadata_table_name: str, - deployment_table_name: str, - statement_params: Dict[str, Any], -) -> None: - """Create views on underlying ModelRegistry tables. - - Args: - session: Session object to communicate with Snowflake. 
- database_name: Desired name of the model registry database. - schema_name: Desired name of the schema used by this model registry inside the database. - registry_table_name: Name for the main model registry table. - metadata_table_name: Name for the metadata table used by the model registry. - deployment_table_name: Name for the deployment event table. - statement_params: Function usage statement parameters used in sql query executions. - """ - fully_qualified_schema_name = table_manager.get_fully_qualified_schema_name(database_name, schema_name) - - # From the documentation: Each DDL statement executes as a separate transaction. Races should not be an issue. - # https://docs.snowflake.com/en/sql-reference/transactions.html#ddl - - # Create a view on active permanent deployments. - _create_active_permanent_deployment_view( - session, - fully_qualified_schema_name, - registry_table_name, - deployment_table_name, - statement_params, - ) - - # Create views on most recent metadata items. - metadata_view_name_prefix = identifier.concat_names([metadata_table_name, "_LAST_"]) - metadata_view_template = formatting.unwrap( - """CREATE OR REPLACE TEMPORARY VIEW {database}.{schema}.{attribute_view} COPY GRANTS AS - SELECT DISTINCT MODEL_ID, {select_expression} AS {final_attribute_name} FROM {metadata_table} - WHERE ATTRIBUTE_NAME = '{attribute_name}'""" - ) - - # Create a separate view for the most recent item in each metadata column. 
- metadata_view_names = [] - metadata_select_fields = [] - for attribute_name in _LIST_METADATA_ATTRIBUTE: - view_name = identifier.concat_names([metadata_view_name_prefix, attribute_name]) - select_expression = ( - f"(LAST_VALUE(VALUE) OVER (PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['{attribute_name}']" - ) - sql = metadata_view_template.format( - database=database_name, - schema=schema_name, - select_expression=select_expression, - attribute_view=view_name, - attribute_name=attribute_name, - final_attribute_name=attribute_name, - metadata_table=metadata_table_name, - ) - session.sql(sql).collect(statement_params=statement_params) - metadata_view_names.append(view_name) - metadata_select_fields.append(f"{view_name}.{attribute_name} AS {attribute_name}") - - # Create a special view for the registration timestamp. - attribute_name = _METADATA_ATTRIBUTE_REGISTRATION - final_attribute_name = identifier.concat_names([attribute_name, "_TIMESTAMP"]) - view_name = identifier.concat_names([metadata_view_name_prefix, attribute_name]) - create_registration_view_sql = metadata_view_template.format( - database=database_name, - schema=schema_name, - select_expression="EVENT_TIMESTAMP", - attribute_view=view_name, - attribute_name=attribute_name, - final_attribute_name=final_attribute_name, - metadata_table=metadata_table_name, - ) - session.sql(create_registration_view_sql).collect(statement_params=statement_params) - metadata_view_names.append(view_name) - metadata_select_fields.append(f"{view_name}.{final_attribute_name} AS {final_attribute_name}") - - metadata_views_join = " ".join( - [ - "LEFT JOIN {view} ON ({view}.MODEL_ID = {registry_table}.ID)".format( - view=view, registry_table=registry_table_name - ) - for view in metadata_view_names - ] - ) - - # Create view to combine all attributes. 
- registry_view_name = identifier.concat_names([registry_table_name, "_VIEW"]) - metadata_select_fields_formatted = ",".join(metadata_select_fields) - session.sql( - f"""CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{registry_view_name} COPY GRANTS AS - SELECT {registry_table_name}.*, {metadata_select_fields_formatted} - FROM {registry_table_name} {metadata_views_join}""" - ).collect(statement_params=statement_params) - - -def _create_active_permanent_deployment_view( - session: snowpark.Session, - fully_qualified_schema_name: str, - registry_table_name: str, - deployment_table_name: str, - statement_params: Dict[str, Any], -) -> None: - """Create a view which lists all available permanent deployments. - - Args: - session: Session object to communicate with Snowflake. - fully_qualified_schema_name: Schema name to the target table. - registry_table_name: Name for the main model registry table. - deployment_table_name: Name of the deployment table. - statement_params: Function usage statement parameters used in sql query executions. - """ - - # Create a view on active permanent deployments - # Active deployments are those whose last operation is not DROP. 
- active_deployments_view_name = identifier.concat_names([deployment_table_name, "_VIEW"]) - active_deployments_view_expr = f""" - CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{active_deployments_view_name} - COPY GRANTS AS - SELECT - DEPLOYMENT_NAME, - MODEL_ID, - {registry_table_name}.NAME as MODEL_NAME, - {registry_table_name}.VERSION as MODEL_VERSION, - {deployment_table_name}.CREATION_TIME as CREATION_TIME, - TARGET_METHOD, - TARGET_PLATFORM, - SIGNATURE, - OPTIONS, - STAGE_PATH, - ROLE - FROM {deployment_table_name} - LEFT JOIN {registry_table_name} - ON {deployment_table_name}.MODEL_ID = {registry_table_name}.ID - """ - session.sql(active_deployments_view_expr).collect(statement_params=statement_params) - - -class ModelRegistry: - """Model Management API.""" - - def __init__( - self, - *, - session: snowpark.Session, - database_name: str = _DEFAULT_REGISTRY_NAME, - schema_name: str = _DEFAULT_SCHEMA_NAME, - create_if_not_exists: bool = False, - ) -> None: - """ - Opens an already-created registry. - - Args: - session: Session object to communicate with Snowflake. - database_name: Desired name of the model registry database. - schema_name: Desired name of the schema used by this model registry inside the database. - create_if_not_exists: create model registry if it's not exists already. - """ - - warnings.warn( - """ -The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0. -It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`, -except when specifically required. The old model registry will be removed once all its primary functionalities are -fully integrated into the new registry. 
- """, - DeprecationWarning, - stacklevel=2, - ) - if create_if_not_exists: - create_model_registry(session=session, database_name=database_name, schema_name=schema_name) - - self._name = identifier.get_inferred_name(database_name) - self._schema = identifier.get_inferred_name(schema_name) - self._registry_table = identifier.get_inferred_name(_MODELS_TABLE_NAME) - self._registry_table_view = identifier.concat_names([self._registry_table, "_VIEW"]) - self._metadata_table = identifier.get_inferred_name(_METADATA_TABLE_NAME) - self._deployment_table = identifier.get_inferred_name(_DEPLOYMENT_TABLE_NAME) - self._permanent_deployment_view = identifier.concat_names([self._deployment_table, "_VIEW"]) - self._permanent_deployment_stage = identifier.concat_names([self._deployment_table, "_STAGE"]) - self._session = session - self._svm = _schema_version_manager.SchemaVersionManager(self._session, self._name, self._schema) - - # A in-memory deployment info cache to store information of temporary deployments - # TODO(zhe): Use a temporary table to replace the in-memory cache. - self._temporary_deployments: Dict[str, model_types.Deployment] = {} - - _initial_schema.check_access(self._session, self._name, self._schema) - - statement_params = self._get_statement_params(inspect.currentframe()) - self._svm.validate_schema_version(statement_params) - - _create_registry_views( - session, - self._name, - self._schema, - self._registry_table, - self._metadata_table, - self._deployment_table, - statement_params, - ) - - # Private methods - - def _get_statement_params(self, frame: Optional[types.FrameType]) -> Dict[str, Any]: - return telemetry.get_function_usage_statement_params( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - function_name=telemetry.get_statement_params_full_func_name(frame, "ModelRegistry"), - ) - - def _get_new_unique_identifier(self) -> str: - """Create new unique identifier. 
- - Returns: - String identifier.""" - return uuid1().hex - - def _fully_qualified_registry_table_name(self) -> str: - """Get the fully qualified name to the current registry table.""" - return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._registry_table) - - def _fully_qualified_registry_view_name(self) -> str: - """Get the fully qualified name to the current registry view.""" - return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._registry_table_view) - - def _fully_qualified_metadata_table_name(self) -> str: - """Get the fully qualified name to the current metadata table.""" - return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._metadata_table) - - def _fully_qualified_deployment_table_name(self) -> str: - """Get the fully qualified name to the current deployment table.""" - return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._deployment_table) - - def _fully_qualified_permanent_deployment_view_name(self) -> str: - """Get the fully qualified name to the permanent deployment view.""" - return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._permanent_deployment_view) - - def _fully_qualified_schema_name(self) -> str: - """Get the fully qualified name to the current registry schema.""" - return table_manager.get_fully_qualified_schema_name(self._name, self._schema) - - def _fully_qualified_deployment_name(self, deployment_name: str) -> str: - """Get the fully qualified name to the given deployment.""" - return table_manager.get_fully_qualified_table_name(self._name, self._schema, deployment_name) - - def _insert_registry_entry( - self, *, id: str, name: str, version: str, properties: Dict[str, Any] - ) -> List[snowpark.Row]: - """Insert a new row into the model registry table. - - Args: - id: Model id to register. - name: Model Name string. - version: Model Version string. 
- properties: Dictionary of properties corresponding to table columns. - - Returns: - snowpark.Dataframe with the result of the operation. - - Raises: - DataError: Mismatch between different id fields. - """ - if not id: - raise connector.DataError("Model ID is required but none given.") - mandatory_args = {"ID": id, "NAME": name, "VERSION": version} - for k, v in mandatory_args.items(): - if k not in properties: - properties[k] = v - else: - if v and v != properties[k]: - raise connector.DataError( - formatting.unwrap( - f"""Parameter '{k.lower()}' is given and parameter 'properties' has the field '{k}' set but - the values do not match: {k.lower()}=="{v}" properties['{k}']=="{properties[k]}".""" - ) - ) - # Could do a multi-table insert here with some pros and cons: - # [PRO] Atomic insert across multiple tables. - # [CON] Code logic becomes messy depending on which fields are set. - # [CON] Harder to reuse existing methods like set_model_name. - # Context: https://docs.snowflake.com/en/sql-reference/sql/insert-multi-table.html - return table_manager.insert_table_entry( - self._session, - table=self._fully_qualified_registry_table_name(), - columns=properties, - ) - - def _insert_metadata_entry(self, *, id: str, attribute: str, value: Any, operation: str) -> List[snowpark.Row]: - """Insert a new row into the model metadata table. - - Args: - id: Model id to register. - attribute: name of the metadata attribute - value: new value of the metadata attribute - operation: the operation type of the metadata entry. - - Returns: - snowpark.DataFrame with the result of the operation. - - Raises: - DataError: Missing ID field. 
- """ - if not id: - raise connector.DataError("Model ID is required but none given.") - - columns: Dict[str, Any] = {} - columns["EVENT_TIMESTAMP"] = formatting.SqlStr("CURRENT_TIMESTAMP()") - columns["EVENT_ID"] = self._get_new_unique_identifier() - columns["MODEL_ID"] = id - columns["ROLE"] = self._session.get_current_role() - columns["OPERATION"] = operation - columns["ATTRIBUTE_NAME"] = attribute - columns["VALUE"] = value - - return table_manager.insert_table_entry( - self._session, - table=self._fully_qualified_metadata_table_name(), - columns=columns, - ) - - def _insert_deployment_entry( - self, - *, - id: str, - name: str, - platform: str, - stage_path: str, - signature: Dict[str, Any], - target_method: str, - options: Optional[ - Union[ - model_types.WarehouseDeployOptions, - model_types.SnowparkContainerServiceDeployOptions, - ] - ] = None, - ) -> List[snowpark.Row]: - """Insert a new row into the model deployment table. - - Each row in the deployment table is a deployment event. - - Args: - id: Model id of the deployed model. - name: Name of the deployment. - platform: The deployment target destination. - stage_path: The stage location where the deployment UDF is stored. - signature: The model signature. - target_method: The method name which is used for the deployment. - options: The deployment options. - - Returns: - A list of snowpark rows which is the insertion result. - - Raises: - DataError: Missing ID field. 
- """ - if not id: - raise connector.DataError("Model ID is required but none given.") - - columns: Dict[str, Any] = {} - columns["CREATION_TIME"] = formatting.SqlStr("CURRENT_TIMESTAMP()") - columns["MODEL_ID"] = id - columns["DEPLOYMENT_NAME"] = name - columns["TARGET_PLATFORM"] = platform - columns["STAGE_PATH"] = stage_path - columns["ROLE"] = self._session.get_current_role() - columns["SIGNATURE"] = signature - columns["TARGET_METHOD"] = target_method - columns["OPTIONS"] = options - - return table_manager.insert_table_entry( - self._session, - table=self._fully_qualified_deployment_table_name(), - columns=columns, - ) - - def _prepare_deployment_stage(self) -> str: - """Create a stage in the model registry for storing all permanent deployments. - - Returns: - Path to the stage that was created. - """ - schema = self._fully_qualified_schema_name() - fully_qualified_deployment_stage_name = f"{schema}.{self._permanent_deployment_stage}" - statement_params = self._get_statement_params(inspect.currentframe()) - self._session.sql( - f"CREATE STAGE IF NOT EXISTS {fully_qualified_deployment_stage_name} " - f"ENCRYPTION = (TYPE= 'SNOWFLAKE_SSE')" - ).collect(statement_params=statement_params) - return f"@{fully_qualified_deployment_stage_name}" - - def _prepare_model_stage(self, model_id: str) -> str: - """Create a stage in the model registry for storing the model with the given id. - - Creating a permanent stage here since we do not have a way to switch a stage from temporary to permanent. - This can result in orphaned stages in case the process fails. It might be better to try to create a - temporary stage, attempt to perform all operations and convert the temp stage into permanent once the - operation is complete. - - Args: - model_id: Internal model ID string. - - Returns: - Name of the stage that was created. - - Raises: - DatabaseError: Indicates that something went wrong when creating the stage. 
- """ - schema = self._fully_qualified_schema_name() - - # Uppercase the model_stage_name to avoid having to quote the the stage name. - stage_name = model_id.upper() - - model_stage_name = f"SNOWML_MODEL_{stage_name}" - fully_qualified_model_stage_name = f"{schema}.{model_stage_name}" - statement_params = self._get_statement_params(inspect.currentframe()) - - create_stage_result = self._session.sql( - f"CREATE OR REPLACE STAGE {fully_qualified_model_stage_name} ENCRYPTION = (TYPE= 'SNOWFLAKE_SSE')" - ).collect(statement_params=statement_params) - if not create_stage_result: - raise connector.DatabaseError("Unable to create stage for model. Operation returned not result.") - if len(create_stage_result) != 1: - raise connector.DatabaseError( - "Unable to create stage for model. Creating the model stage returned unexpected result: {}.".format( - str(create_stage_result) - ) - ) - - return fully_qualified_model_stage_name - - def _get_fully_qualified_stage_name_from_uri(self, model_uri: str) -> Optional[str]: - """Get fully qualified stage path pointed by the URI. - - Args: - model_uri: URI for which stage file is needed. - - Returns: - The fully qualified Snowflake stage location encoded by the given URI. Returns None if the URI is not - pointing to a Snowflake stage. - """ - raw_stage_path = uri.get_snowflake_stage_path_from_uri(model_uri) - if not raw_stage_path: - return None - (db, schema, stage, _) = identifier.parse_snowflake_stage_path(raw_stage_path) - return identifier.get_schema_level_object_identifier(db, schema, stage) - - def _list_selected_models( - self, - *, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, - ) -> snowpark.DataFrame: - """Retrieve the Snowpark dataframe of models matching the specified ID or (name and version). - - Args: - id: Model ID string. Required if either name or version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. 
Required if id is None. - - Returns: - A Snowpark dataframe representing the models that match the given constraints. - """ - models = self.list_models() - - if id: - filtered_models = models.filter(snowpark.Column("ID") == id) - else: - self._model_identifier_is_nonempty_or_raise(model_name, model_version) - - # The following two asserts is to satisfy mypy. - assert model_name - assert model_version - - filtered_models = models.filter(snowpark.Column("NAME") == model_name).filter( - snowpark.Column("VERSION") == model_version - ) - - return cast(snowpark.DataFrame, filtered_models) - - def _validate_exact_one_result( - self, selected_model: snowpark.DataFrame, model_identifier: str - ) -> List[snowpark.Row]: - """Validate the filtered model has exactly one result. - - Args: - selected_model: A snowpark dataframe representing the models that are filtered out. - model_identifier: A string which is used to filter the model. - - Returns: - A snowpark row which contains the metadata of the filtered model - - Raises: - KeyError: The target model doesn't exist. - DataError: The target model is not unique. - """ - statement_params = self._get_statement_params(inspect.currentframe()) - model_info = None - try: - model_info = ( - query_result_checker.ResultValidator(result=selected_model.collect(statement_params=statement_params)) - .has_dimensions(expected_rows=1) - .validate() - ) - except connector.DataError: - if model_info is None or len(model_info) == 0: - raise KeyError(f"The model {model_identifier} does not exist in the current registry.") - else: - raise connector.DataError( - formatting.unwrap( - f"""There are {len(model_info)} models {model_identifier}. 
This might indicate a problem with - the integrity of the model registry data.""" - ) - ) - return model_info - - def _get_metadata_attribute( - self, - attribute: str, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, - ) -> Any: - """Get the value of the given metadata attribute for target model with given (model name + model version) or id. - - Args: - attribute: Name of the attribute to get. - id: Model ID string. Required if either name or version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. - - Returns: - The value of the attribute that was requested. Can be None if the attribute is not set. - """ - selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version) - identifier = f"id {id}" if id else f"{model_name}/{model_version}" - model_info = self._validate_exact_one_result(selected_models, identifier) - return model_info[0][attribute] - - def _set_metadata_attribute( - self, - attribute: str, - value: Any, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, - operation: str = _SET_METADATA_OPERATION, - enable_model_presence_check: bool = True, - ) -> None: - """Set the value of the given metadata attribute for target model with given (model name + model version) or id. - - Args: - attribute: Name of the attribute to set. - value: Value to set. - id: Model ID string. Required if either name or version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. - operation: the operation type of the metadata entry. - enable_model_presence_check: If True, we will check if the model with the given ID is currently registered - before setting the metadata attribute. False by default meaning that by default we will check. 
- - Raises: - DataError: Failed to set the metadata attribute. - KeyError: The target model doesn't exist - """ - selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version) - identifier = f"id {id}" if id else f"{model_name}/{model_version}" - try: - model_info = self._validate_exact_one_result(selected_models, identifier) - except KeyError as e: - # If the target model doesn't exist, raise the error only if enable_model_presence_check is True. - if enable_model_presence_check: - raise e - - if not id: - id = model_info[0]["ID"] - assert id is not None - - try: - self._insert_metadata_entry( - id=id, - attribute=attribute, - value={attribute: value}, - operation=operation, - ) - except connector.DataError: - raise connector.DataError(f"Setting {attribute} for mode id {id} failed.") - - def _model_identifier_is_nonempty_or_raise(self, model_name: Optional[str], model_version: Optional[str]) -> None: - """Validate model_name and model_version are non-empty strings. - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Raises: - ValueError: Raised when either model_name and model_version is empty. - """ - if not model_name or not model_version: - raise ValueError("model_name and model_version have to be non-empty strings.") - - def _get_model_id(self, model_name: str, model_version: str) -> str: - """Get ID of the model with the given (model name + model version). - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Returns: - Id of the model. - - Raises: - DataError: The requested model could not be found. 
- """ - result = self._get_metadata_attribute("ID", model_name=model_name, model_version=model_version) - if not result: - raise connector.DataError(f"Model {model_name}/{model_version} doesn't exist.") - return str(result) - - def _get_model_path( - self, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, - ) -> str: - """Get the stage path for the model with the given (model name + model version) or `id` from the registry. - - Args: - id: Id of the model to deploy. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if id is None. - - Returns: - str: Stage path for the model. - - Raises: - DataError: When the model cannot be found or not be restored. - """ - statement_params = self._get_statement_params(inspect.currentframe()) - selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version) - identifier = f"id {id}" if id else f"{model_name}/{model_version}" - model_info = self._validate_exact_one_result(selected_models, identifier) - if not id: - id = model_info[0]["ID"] - model_uri = model_info[0]["URI"] - - if not uri.is_snowflake_stage_uri(model_uri): - raise connector.DataError( - f"Artifacts with URI scheme {uri.get_uri_scheme(model_uri)} are currently not supported." - ) - - model_stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri=model_uri) - - # Currently we assume only the model is on the stage. - model_file_list = self._session.sql(f"LIST @{model_stage_path}").collect(statement_params=statement_params) - if len(model_file_list) == 0: - raise connector.DataError(f"No files in model artifact for id {id} located at {model_uri}.") - return f"{_STAGE_PREFIX}{model_stage_path}" - - def _log_model_path( - self, - model_name: str, - model_version: str, - ) -> Tuple[str, str]: - """Generate a path in the Model Registry to store a model. 
- - Args: - model_name: The given name for the model. - model_version: Version string to be set for the model. - - Returns: - String of the auto-generate unique model identifier and path to store it. - """ - model_id = self._get_new_unique_identifier() - - # Copy model from local disk to remote stage. - # TODO(zhe): Check if we could use the same stage for multiple models. - fully_qualified_model_stage_name = self._prepare_model_stage(model_id=model_id) - - return model_id, fully_qualified_model_stage_name - - def _register_model_with_id( - self, - model_name: str, - model_version: str, - model_id: str, - *, - type: str, - uri: str, - input_spec: Optional[Dict[str, str]] = None, - output_spec: Optional[Dict[str, str]] = None, - description: Optional[str] = None, - tags: Optional[Dict[str, str]] = None, - ) -> None: - """Helper function to register model metadata. - - Args: - model_name: Name to be set for the model. The model name can NOT be changed after registration. The - combination of name and version is expected to be unique inside the registry. - model_version: Version string to be set for the model. The model version string can NOT be changed after - model registration. The combination of name and version is expected to be unique inside the registry. - model_id: The internal id for the model. - type: Type of the model. Only a subset of types are supported natively. - uri: Resource identifier pointing to the model artifact. There are no restrictions on the URI format, - however only a limited set of URI schemes is supported natively. - input_spec: The expected input schema of the model. Dictionary where the keys are - expected column names and the values are the value types. - output_spec: The expected output schema of the model. Dictionary where the keys - are expected column names and the values are the value types. - description: A description for the model. The description can be changed later. - tags: Key-value pairs of tags to be set for this model. 
Tags can be modified - after model registration. - - Raises: - DataError: The given model already exists. - DatabaseError: Unable to register the model properties into table. - """ - new_model: Dict[Any, Any] = {} - new_model["ID"] = model_id - new_model["NAME"] = model_name - new_model["VERSION"] = model_version - new_model["TYPE"] = type - new_model["URI"] = uri - new_model["INPUT_SPEC"] = input_spec - new_model["OUTPUT_SPEC"] = output_spec - new_model["CREATION_TIME"] = formatting.SqlStr("CURRENT_TIMESTAMP()") - new_model["CREATION_ROLE"] = self._session.get_current_role() - new_model["CREATION_ENVIRONMENT_SPEC"] = {"python": ".".join(map(str, sys.version_info[:3]))} - - existing_model_nums = self._list_selected_models(model_name=model_name, model_version=model_version).count() - if existing_model_nums: - raise connector.DataError( - f"Model {model_name}/{model_version} already exists. Unable to register the model." - ) - - if self._insert_registry_entry(id=model_id, name=model_name, version=model_version, properties=new_model): - self._set_metadata_attribute( - model_name=model_name, - model_version=model_version, - attribute=_METADATA_ATTRIBUTE_REGISTRATION, - value=new_model, - ) - if description: - self.set_model_description( - model_name=model_name, - model_version=model_version, - description=description, - ) - if tags: - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_TAGS, - value=tags, - model_name=model_name, - model_version=model_version, - ) - else: - raise connector.DatabaseError("Failed to insert the model properties to the registry table.") - - def _get_deployment(self, *, model_name: str, model_version: str, deployment_name: str) -> snowpark.Row: - statement_params = self._get_statement_params(inspect.currentframe()) - deployment_lst = ( - self._session.sql(f"SELECT * FROM {self._fully_qualified_permanent_deployment_view_name()}") - .filter(snowpark.Column("DEPLOYMENT_NAME") == deployment_name) - .filter(snowpark.Column("MODEL_NAME") == 
model_name) - .filter(snowpark.Column("MODEL_VERSION") == model_version) - ).collect(statement_params=statement_params) - if len(deployment_lst) == 0: - raise KeyError( - f"Unable to find deployment named {deployment_name} in the model {model_name}/{model_version}." - ) - assert len(deployment_lst) == 1, "_get_deployment should return exactly 1 deployment" - return cast(snowpark.Row, deployment_lst[0]) - - # Registry operations - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def list_models(self) -> snowpark.DataFrame: - """Lists models contained in the registry. - - Returns: - snowpark.DataFrame with the list of models. Access is read-only through the snowpark.DataFrame. - The resulting snowpark.dataframe will have an "id" column that uniquely identifies each model and can be - used to reference the model when performing operations. - """ - # Explicitly not calling collect. - return self._session.sql( - "SELECT * FROM {database}.{schema}.{view}".format( - database=self._name, schema=self._schema, view=self._registry_table_view - ) - ) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def set_tag( - self, - model_name: str, - model_version: str, - tag_name: str, - tag_value: Optional[str] = None, - ) -> None: - """Set model tag to the model with value. - - If the model tag already exists, the tag value will be overwritten. - - Args: - model_name: Model Name string. - model_version: Model Version string. - tag_name: Desired tag name string. - tag_value: (optional) New tag value string. If no value is given the value of the tag will be set to None. - """ - # This method uses a read-modify-write pattern for setting tags. - # TODO(amauser): Investigate the use of transactions to avoid race conditions. 
- model_tags = self.get_tags(model_name=model_name, model_version=model_version) - model_tags[tag_name] = tag_value - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_TAGS, - model_tags, - model_name=model_name, - model_version=model_version, - ) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def remove_tag(self, model_name: str, model_version: str, tag_name: str) -> None: - """Remove target model tag. - - Args: - model_name: Model Name string. - model_version: Model Version string. - tag_name: Desired tag name string. - - Raises: - DataError: If the model does not have the requested tag. - """ - # This method uses a read-modify-write pattern for updating tags. - - model_tags = self.get_tags(model_name=model_name, model_version=model_version) - try: - del model_tags[tag_name] - except KeyError: - raise connector.DataError( - f"Model {model_name}/{model_version} has no tag named {tag_name}. Full list of tags: {model_tags}" - ) - - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_TAGS, - model_tags, - model_name=model_name, - model_version=model_version, - ) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def has_tag( - self, - model_name: str, - model_version: str, - tag_name: str, - tag_value: Optional[str] = None, - ) -> bool: - """Check if a model has a tag with the given name and value. - - If no value is given, any value for the tag will return true. - - Args: - model_name: Model Name string. - model_version: Model Version string. - tag_name: Desired tag name string. - tag_value: (optional) Tag value to check. If not value is given, only the presence of the tag will be - checked. - - Returns: - True if the tag or tag and value combination is present for the model with the given id, False otherwise. 
- """ - tags = self.get_tags(model_name=model_name, model_version=model_version) - has_tag = tag_name in tags - if tag_value is None: - return has_tag - return has_tag and tags[tag_name] == str(tag_value) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def get_tag_value(self, model_name: str, model_version: str, tag_name: str) -> Any: - """Return the value of the tag for the model. - - The returned value can be None. If the tag does not exist, KeyError will be raised. - - Args: - model_name: Model Name string. - model_version: Model Version string. - tag_name: Desired tag name string. - - Returns: - Value string of the tag or None, if no value is set for the tag. - """ - return self.get_tags(model_name=model_name, model_version=model_version)[tag_name] - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def get_tags(self, model_name: Optional[str] = None, model_version: Optional[str] = None) -> Dict[str, Any]: - """Get all tags and values stored for the target model. - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Returns: - String-to-string dictionary containing all tags and values. The resulting dictionary can be empty. - """ - # Snowpark snowpark.dataframe returns dictionary objects as strings. We need to convert it back to a dictionary - # here. 
- result = self._get_metadata_attribute( - _METADATA_ATTRIBUTE_TAGS, model_name=model_name, model_version=model_version - ) - - if result: - ret: Dict[str, Optional[str]] = json.loads(result) - return ret - else: - return dict() - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def get_model_description(self, model_name: str, model_version: str) -> Optional[str]: - """Get the description of the model. - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Returns: - Description of the model or None. - """ - result = self._get_metadata_attribute( - _METADATA_ATTRIBUTE_DESCRIPTION, - model_name=model_name, - model_version=model_version, - ) - return None if result is None else json.loads(result) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def set_model_description( - self, - model_name: str, - model_version: str, - description: str, - ) -> None: - """Set the description of the model. - - Args: - model_name: Model Name string. - model_version: Model Version string. - description: Desired new model description. - """ - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_DESCRIPTION, - description, - model_name=model_name, - model_version=model_version, - ) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def get_history(self) -> snowpark.DataFrame: - """Return a dataframe with the history of operations performed on the model registry. - - The returned dataframe is order by time and can be filtered further. - - Returns: - snowpark.DataFrame with the history of the model. 
- """ - res = ( - self._session.table(self._fully_qualified_metadata_table_name()) - .order_by("EVENT_TIMESTAMP") - .select_expr( - "EVENT_TIMESTAMP", - "EVENT_ID", - "MODEL_ID", - "ROLE", - "OPERATION", - "ATTRIBUTE_NAME", - "VALUE[ATTRIBUTE_NAME]", - ) - ) - return cast(snowpark.DataFrame, res) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def get_model_history( - self, - model_name: str, - model_version: str, - ) -> snowpark.DataFrame: - """Return a dataframe with the history of operations performed on the desired model. - - The returned dataframe is order by time and can be filtered further. - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Returns: - snowpark.DataFrame with the history of the model. - """ - id = self._get_model_id(model_name=model_name, model_version=model_version) - return cast( - snowpark.DataFrame, - self.get_history().filter(snowpark.Column("MODEL_ID") == id), - ) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def set_metric( - self, - model_name: str, - model_version: str, - metric_name: str, - metric_value: object, - ) -> None: - """Set scalar model metric to value. - - If a metric with that name already exists for the model, the metric value will be overwritten. - - Args: - model_name: Model Name string. - model_version: Model Version string. - metric_name: Desired metric name. - metric_value: New metric value. - """ - # This method uses a read-modify-write pattern for setting tags. - # TODO(amauser): Investigate the use of transactions to avoid race conditions. 
- model_metrics = self.get_metrics(model_name=model_name, model_version=model_version) - model_metrics[metric_name] = metric_value - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_METRICS, - model_metrics, - model_name=model_name, - model_version=model_version, - ) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def remove_metric( - self, - model_name: str, - model_version: str, - metric_name: str, - ) -> None: - """Remove a specific metric entry from the model. - - Args: - model_name: Model Name string. - model_version: Model Version string. - metric_name: Desired metric name. - - Raises: - DataError: If the model does not have the requested metric. - """ - # This method uses a read-modify-write pattern for updating tags. - - model_metrics = self.get_metrics(model_name=model_name, model_version=model_version) - try: - del model_metrics[metric_name] - except KeyError: - raise connector.DataError( - f"Model {model_name}/{model_version} has no metric named {metric_name}. " - f"Full list of metrics: {model_metrics}" - ) - - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_METRICS, - model_metrics, - model_name=model_name, - model_version=model_version, - ) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def has_metric(self, model_name: str, model_version: str, metric_name: str) -> bool: - """Check if a model has a metric with the given name. - - Args: - model_name: Model Name string. - model_version: Model Version string. - metric_name: Desired metric name. - - Returns: - True if the metric is present for the model with the given id, False otherwise. 
- """ - metrics = self.get_metrics(model_name=model_name, model_version=model_version) - return metric_name in metrics - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def get_metric_value(self, model_name: str, model_version: str, metric_name: str) -> object: - """Return the value of the given metric for the model. - - The returned value can be None. If the metric does not exist, KeyError will be raised. - - Args: - model_name: Model Name string. - model_version: Model Version string. - metric_name: Desired metric name. - - Returns: - Value of the metric. Can be None if the metric was set to None. - """ - return self.get_metrics(model_name=model_name, model_version=model_version)[metric_name] - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def get_metrics(self, model_name: str, model_version: str) -> Dict[str, object]: - """Get all metrics and values stored for the given model. - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Returns: - String-to-float dictionary containing all metrics and values. The resulting dictionary can be empty. - """ - # Snowpark snowpark.dataframe returns dictionary objects as strings. We need to convert it back to a dictionary - # here. - result = self._get_metadata_attribute( - _METADATA_ATTRIBUTE_METRICS, - model_name=model_name, - model_version=model_version, - ) - - if result: - ret: Dict[str, object] = json.loads(result) - return ret - else: - return dict() - - # Combined Registry and Repository operations. 
- @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def log_model( - self, - model_name: str, - model_version: str, - *, - model: Any, - description: Optional[str] = None, - tags: Optional[Dict[str, str]] = None, - conda_dependencies: Optional[List[str]] = None, - pip_requirements: Optional[List[str]] = None, - signatures: Optional[Dict[str, model_signature.ModelSignature]] = None, - sample_input_data: Optional[Any] = None, - code_paths: Optional[List[str]] = None, - options: Optional[model_types.BaseModelSaveOption] = None, - ) -> Optional["ModelReference"]: - """Uploads and register a model to the Model Registry. - - Args: - model_name: The given name for the model. The combination (name + version) must be unique for each model. - model_version: Version string to be set for the model. The combination (name + version) must be unique for - each model. - model: Local model object in a supported format. - description: A description for the model. The description can be changed later. - tags: string-to-string dictionary of tag names and values to be set for the model. - conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to - specify a dependency. It is a recommended way to specify your dependencies using conda. When channel is - not specified, defaults channel will be used. When deploying to Snowflake Warehouse, defaults channel - would be replaced with the Snowflake Anaconda channel. - pip_requirements: List of PIP package specs. Model will not be able to deploy to the warehouse if there is - pip requirements. - signatures: Signatures of the model, which is a mapping from target method name to signatures of input and - output, which could be inferred by calling `infer_signature` method with sample input data. - sample_input_data: Sample of the input data for the model. 
- code_paths: Directory of code to import when loading and deploying the model. - options: Additional options when saving the model. - - Raises: - DataError: Raised when: - 1) the given model already exists; - ValueError: Raised when: # noqa: DAR402 - 1) Signatures and sample_input_data are both not provided and model is not a - snowflake estimator. - Exception: Raised when there is any error raised when saving the model. - - Returns: - Model Reference . None if failed. - """ - # Ideally, the whole operation should be a single transaction. Currently, transactions do not support stage - # operations. - - statement_params = self._get_statement_params(inspect.currentframe()) - self._svm.validate_schema_version(statement_params) - - self._model_identifier_is_nonempty_or_raise(model_name, model_version) - - existing_model_nums = self._list_selected_models(model_name=model_name, model_version=model_version).count() - if existing_model_nums: - raise connector.DataError(f"Model {model_name}/{model_version} already exists. Unable to log the model.") - model_id, fully_qualified_model_stage_name = self._log_model_path( - model_name=model_name, - model_version=model_version, - ) - stage_path = f"{_STAGE_PREFIX}{fully_qualified_model_stage_name}" - model = cast(model_types.SupportedModelType, model) - try: - model_composer = model_api.save_model( # type: ignore[call-overload, misc] - name=model_name, - session=self._session, - stage_path=stage_path, - model=model, - signatures=signatures, - metadata=tags, - conda_dependencies=conda_dependencies, - pip_requirements=pip_requirements, - sample_input_data=sample_input_data, - code_paths=code_paths, - options=options, - ) - except Exception: - # When model saving fails, clean up the model stage. 
- query_result_checker.SqlResultValidator( - self._session, f"DROP STAGE {fully_qualified_model_stage_name}" - ).has_dimensions(expected_rows=1, expected_cols=1).validate() - raise - - self._register_model_with_id( - model_name=model_name, - model_version=model_version, - model_id=model_id, - type=model_composer.packager.meta.model_type, - uri=uri.get_uri_from_snowflake_stage_path(stage_path), - description=description, - tags=tags, - ) - - return ModelReference(registry=self, model_name=model_name, model_version=model_version) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def load_model(self, model_name: str, model_version: str) -> Any: - """Loads the model with the given (model_name + model_version) from the registry into memory. - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Returns: - Restored model object. - """ - warnings.warn( - ( - "Please use with caution: " - "Using `load_model` method requires you to have the EXACT same Python environments " - "as the one when you logged the model. Any differences will potentially lead to errors.\n" - "Also, if your model contains custom code imported using `code_paths` argument when logging, " - "they will be added to your `sys.path`. It might lead to unexpected module importing issues. " - "If you run into such kind of problems, you need to restart your Python or Notebook kernel." 
- ), - category=UserWarning, - stacklevel=2, - ) - remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version) - restored_model = None - - restored_model = model_api.load_model(session=self._session, stage_path=remote_model_path) - - return restored_model.packager.model - - # Repository Operations - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def deploy( - self, - model_name: str, - model_version: str, - *, - deployment_name: str, - target_method: Optional[str] = None, - permanent: bool = False, - platform: deploy_platforms.TargetPlatform = deploy_platforms.TargetPlatform.WAREHOUSE, - options: Optional[ - Union[ - model_types.WarehouseDeployOptions, - model_types.SnowparkContainerServiceDeployOptions, - ] - ] = None, - ) -> model_types.Deployment: - """Deploy the model with the given deployment name. - - Args: - model_name: Model Name string. - model_version: Model Version string. - deployment_name: name of the generated UDF. - target_method: The method name to use in deployment. Can be omitted if only 1 method in the model. - permanent: Whether the deployment is permanent or not. Permanent deployment will generate a permanent UDF. - (Only applicable for Warehouse deployment) - platform: Target platform to deploy the model to. Currently supported platforms are defined as enum in - `snowflake.ml.model.deploy_platforms.TargetPlatform` - options: Optional options for model deployment. Defaults to None. - - Returns: - Deployment info. - - Raises: - RuntimeError: Raised when parameters are not properly enabled when deploying to Warehouse with temporary UDF - RuntimeError: Raised when deploying to SPCS with db/schema that starts with underscore. 
- """ - statement_params = self._get_statement_params(inspect.currentframe()) - self._svm.validate_schema_version(statement_params) - - if options is None: - options = {} - - deployment_stage_path = "" - - if platform == deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES: - if self._name.startswith("_") or self._schema.startswith("_"): - error_message = """\ - Model deployment to Snowpark Container Service does not support a database/schema name that starts with - an underscore. Please ensure you pass in a valid db/schema name when initializing the registry with: - - model_registry.create_model_registry( - session=session, - database_name=db, - schema_name=schema - ) - - registry = model_registry.ModelRegistry( - session=session, - database_name=db, - schema_name=schema - ) - """ - raise RuntimeError(textwrap.dedent(error_message)) - permanent = True - options = cast(model_types.SnowparkContainerServiceDeployOptions, options) - deployment_stage_path = f"{self._prepare_deployment_stage()}/{deployment_name}/" - elif platform == deploy_platforms.TargetPlatform.WAREHOUSE: - options = cast(model_types.WarehouseDeployOptions, options) - if permanent: - # Every deployment-generated UDF should reside in its own unique directory. As long as each deployment - # is allocated a distinct directory, multiple deployments can coexist within the same stage. 
- # Given that each permanent deployment possesses a unique deployment_name, sharing the same stage does - # not present any issues - deployment_stage_path = ( - options.get("permanent_udf_stage_location") - or f"{self._prepare_deployment_stage()}/{deployment_name}/" - ) - options["permanent_udf_stage_location"] = deployment_stage_path - - remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version) - model_id = self._get_model_id(model_name, model_version) - - # https://snowflakecomputing.atlassian.net/browse/SNOW-858376 - # During temporary deployment on the Warehouse, Snowpark creates an unencrypted temporary stage for UDF-related - # artifacts. However, UDF generation fails when importing from a mix of encrypted and unencrypted stages. - # The following workaround copies model between stages (PrPr as of July 7th, 2023) to transfer the SSE - # encrypted model zip from model stage to the temporary unencrypted stage. - if not permanent and platform == deploy_platforms.TargetPlatform.WAREHOUSE: - schema = self._fully_qualified_schema_name() - unencrypted_stage = ( - f"@{schema}.{snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)}" - ) - self._session.sql(f"CREATE TEMPORARY STAGE {unencrypted_stage[1:]}").collect() - try: - self._session.sql(f"COPY FILES INTO {unencrypted_stage} from {remote_model_path}").collect() - except Exception: - raise RuntimeError( - "Temporary deployment to the warehouse is currently not supported. 
Please use " - "permanent deployment by setting the 'permanent' parameter to True" - ) - remote_model_path = unencrypted_stage - - # Step 1: Deploy to get the UDF - deployment_info = model_api.deploy( - session=self._session, - name=self._fully_qualified_deployment_name(deployment_name), - platform=platform, - target_method=target_method, - stage_path=remote_model_path, - deployment_stage_path=deployment_stage_path, - model_id=model_id, - options=options, - ) - - # Step 2: Record the deployment - - # Assert to convince mypy. - assert deployment_info - if permanent: - self._insert_deployment_entry( - id=model_id, - name=deployment_name, - platform=deployment_info["platform"].value, - stage_path=deployment_stage_path, - signature=deployment_info["signature"].to_dict(), - target_method=deployment_info["target_method"], - options=options, - ) - - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_DEPLOYMENT, - {"name": deployment_name, "permanent": permanent}, - id=model_id, - operation=_ADD_METADATA_OPERATION, - ) - - # Store temporary deployment information in the in-memory cache. This allows for future referencing and - # tracking of its availability status. - if not permanent: - self._temporary_deployments[deployment_name] = deployment_info - - return deployment_info - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="1.0.1") - def list_deployments(self, model_name: str, model_version: str) -> snowpark.DataFrame: - """List all permanent deployments that originated from the given model. - - Temporary deployment info are currently not supported for listing. - - Args: - model_name: Model Name string. - model_version: Model Version string. - - Returns: - A snowpark dataframe that contains all deployments that associated with the given model. 
- """ - deployments_df = ( - self._session.sql(f"SELECT * FROM {self._fully_qualified_permanent_deployment_view_name()}") - .filter(snowpark.Column("MODEL_NAME") == model_name) - .filter(snowpark.Column("MODEL_VERSION") == model_version) - ) - res = deployments_df.select( - deployments_df["MODEL_NAME"], - deployments_df["MODEL_VERSION"], - deployments_df["DEPLOYMENT_NAME"], - deployments_df["CREATION_TIME"], - deployments_df["TARGET_METHOD"], - deployments_df["TARGET_PLATFORM"], - deployments_df["SIGNATURE"], - deployments_df["OPTIONS"], - deployments_df["STAGE_PATH"], - deployments_df["ROLE"], - ) - return cast(snowpark.DataFrame, res) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="1.0.1") - def get_deployment(self, model_name: str, model_version: str, *, deployment_name: str) -> snowpark.DataFrame: - """Get the permanent deployment with target name of the given model. - - Temporary deployment info are currently not supported. - - Args: - model_name: Model Name string. - model_version: Model Version string. - deployment_name: Deployment name string. - - Returns: - A snowpark dataframe that contains the information of the target deployment. - - Raises: - KeyError: Raised if the target deployment is not found. - """ - deployment = self.list_deployments(model_name, model_version).filter( - snowpark.Column("DEPLOYMENT_NAME") == deployment_name - ) - if deployment.count() == 0: - raise KeyError( - f"Unable to find deployment named {deployment_name} in the model {model_name}/{model_version}." 
- ) - return cast(snowpark.DataFrame, deployment) - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="1.0.1") - def delete_deployment(self, model_name: str, model_version: str, *, deployment_name: str) -> None: - """Delete the target permanent deployment of the given model. - - Deleting temporary deployment are currently not supported. - Temporary deployment will get cleaned automatically when the current session closed. - - Args: - model_name: Model Name string. - model_version: Model Version string. - deployment_name: Name of the deployment that is getting deleted. - - """ - deployment = self._get_deployment( - model_name=model_name, - model_version=model_version, - deployment_name=deployment_name, - ) - - # TODO(SNOW-759526): The following sequence should be a transaction. - # Step 1: Drop the UDF - self._session.sql( - f"DROP FUNCTION IF EXISTS {self._fully_qualified_deployment_name(deployment_name)}(OBJECT)" - ).collect() - - # Step 2: Remove the staged artifact - self._session.sql(f"REMOVE {deployment['STAGE_PATH']}").collect() - - # Step 3: Delete the deployment from the deployment table - query_result_checker.SqlResultValidator( - self._session, - f"""DELETE FROM {self._fully_qualified_deployment_table_name()} - WHERE MODEL_ID='{deployment['MODEL_ID']}' AND DEPLOYMENT_NAME='{deployment_name}' - """, - ).deletion_success(expected_num_rows=1).validate() - - # Step 4: Record the delete event - self._set_metadata_attribute( - _METADATA_ATTRIBUTE_DEPLOYMENT, - {"name": deployment_name}, - id=deployment["MODEL_ID"], - operation=_DROP_METADATA_OPERATION, - ) - - # Optional Step 5: Delete Snowpark container service. 
- if deployment["TARGET_PLATFORM"] == deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES.value: - service_name = identifier.get_schema_level_object_identifier( - self._name, self._schema, f"service_{deployment['MODEL_ID']}" - ) - spcs_attribution_utils.record_service_end(self._session, service_name) - query_result_checker.SqlResultValidator( - self._session, - f"DROP SERVICE IF EXISTS {service_name}", - ).validate() - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def delete_model( - self, - model_name: str, - model_version: str, - delete_artifact: bool = True, - ) -> None: - """Delete model with the given ID from the registry. - - The history of the model will still be preserved. - - Args: - model_name: Model Name string. - model_version: Model Version string. - delete_artifact: If True, the underlying model artifact will also be deleted, not just the entry in - the registry table. - """ - - # Check that a model with the given ID exists and there is only one of them. - # TODO(amauser): The following sequence should be a transaction. Transactions currently cannot contain DDL - # statements. - model_info = None - selected_models = self._list_selected_models(model_name=model_name, model_version=model_version) - identifier = f"{model_name}/{model_version}" - model_info = self._validate_exact_one_result(selected_models, identifier) - id = model_info[0]["ID"] - model_uri = model_info[0]["URI"] - - # Step 1/3: Delete the registry entry. - query_result_checker.SqlResultValidator( - self._session, - f"DELETE FROM {self._fully_qualified_registry_table_name()} WHERE ID='{id}'", - ).deletion_success(expected_num_rows=1).validate() - - # Step 2/3: Delete the artifact (if desired). 
- if delete_artifact: - if uri.is_snowflake_stage_uri(model_uri): - stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri) - query_result_checker.SqlResultValidator(self._session, f"DROP STAGE {stage_path}").has_dimensions( - expected_rows=1, expected_cols=1 - ).validate() - - # Step 3/3: Record the deletion event. - self._set_metadata_attribute( - id=id, - attribute=_METADATA_ATTRIBUTE_DELETION, - value={"delete_artifact": True, "URI": model_uri}, - enable_model_presence_check=False, - ) - - -class ModelReference: - """Wrapper class for ModelReference objects that proxy model metadata operations.""" - - def _remove_arg_from_docstring(self, arg: str, docstring: Optional[str]) -> Optional[str]: - """Remove the given parameter from a function docstring (Google convention).""" - if docstring is None: - return None - docstring_lines = docstring.split("\n") - - args_section_start = None - args_section_end = None - args_section_indent = None - arg_start = None - arg_end = None - arg_indent = None - for i in range(len(docstring_lines)): - line = docstring_lines[i] - lstrip_line = line.lstrip() - indent = len(line) - len(lstrip_line) - - if line.strip() == "Args:": - # Starting the Args section of the docstring (assuming Google-style). - args_section_start = i - # logging.info("TEST: args_section_start=" + str(args_section_start)) - args_section_indent = indent - continue - - # logging.info("TEST: " + lstrip_line) - if args_section_start and lstrip_line.startswith(f"{arg}:"): - # This is the arg we are looking for. - arg_start = i - # logging.info("TEST: arg_start=" + str(arg_start)) - arg_indent = indent - continue - - if arg_start and not arg_end and indent == arg_indent: - # We got the next arg, previous line was the last of the cut out arg docstring - # and we do have other args. Saving arg_end for python slice/range notation. - arg_end = i - continue - - if arg_start and (len(lstrip_line) == 0 or indent == args_section_indent): - # Arg section ends. 
- args_section_end = i - arg_end = arg_end if arg_end else i - # We have learned everything we need to know, no need to continue. - break - - if arg_start and not arg_end: - arg_end = len(docstring_lines) - - if args_section_start and not args_section_end: - args_section_end = len(docstring_lines) - - # Determine which lines from the "Args:" section of the docstring to skip or if we - # should skip the entire section. - keep_lines = set(range(len(docstring_lines))) - if args_section_start: - if arg_start == args_section_start + 1 and arg_end == args_section_end: - # Removed arg was the only arg, remove the entire section. - assert args_section_end - keep_lines.difference_update(range(args_section_start, args_section_end)) - else: - # Just remove the arg. - assert arg_start - assert arg_end - keep_lines.difference_update(range(arg_start, arg_end)) - - return "\n".join([docstring_lines[i] for i in sorted(keep_lines)]) - - def __init__( - self, - *, - registry: ModelRegistry, - model_name: str, - model_version: str, - ) -> None: - self._registry = registry - self._id = registry._get_model_id(model_name=model_name, model_version=model_version) - self._model_name = model_name - self._model_version = model_version - - # Wrap all functions of the ModelRegistry that have an "id" parameter and bind that parameter - # the the "_id" member of this class. - if hasattr(self.__class__, "init_complete"): - # Already did the generation of wrapped method. - return - - for name, obj in self._registry.__class__.__dict__.items(): - if ( - not inspect.isfunction(obj) - or "model_name" not in inspect.signature(obj).parameters - or "model_version" not in inspect.signature(obj).parameters - ): - continue - - # Ensure that we are not silently overwriting existing functions. 
- assert not hasattr(self.__class__, name) - - def build_method(m: Callable[..., Any]) -> Callable[..., Any]: - return lambda self, *args, **kwargs: m( - self._registry, - self._model_name, - self._model_version, - *args, - **kwargs, - ) - - method = build_method(m=obj) - setattr(self.__class__, name, method) - - docstring = self._remove_arg_from_docstring("model_name", obj.__doc__) - if docstring and "model_version" in docstring: - docstring = self._remove_arg_from_docstring("model_version", docstring) - setattr(self.__class__.__dict__[name], "__doc__", docstring) # noqa: B010 - - setattr(self.__class__, "init_complete", True) # noqa: B010 - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - def get_name(self) -> str: - return self._model_name - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - def get_version(self) -> str: - return self._model_version - - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="0.2.0") - def predict(self, deployment_name: str, data: Any) -> "pd.DataFrame": - """Predict using the deployed model in Snowflake. - - Args: - deployment_name: name of the generated UDF. - data: Data to run predict. - - Raises: - ValueError: The deployment with given name haven't been deployed. - - Returns: - A dataframe containing the result of prediction. - """ - # We will search temporary deployments from the local in-memory cache. - # If there is no hit, we try to search the remote deployment table. 
- di = self._registry._temporary_deployments.get(deployment_name) - - statement_params = telemetry.get_function_usage_statement_params( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - function_name=telemetry.get_statement_params_full_func_name( - inspect.currentframe(), self.__class__.__name__ - ), - ) - - self._registry._svm.validate_schema_version(statement_params) - - if di: - return model_api.predict( - session=self._registry._session, - deployment=di, - X=data, - statement_params=statement_params, - ) - - # Mypy enforce to refer to the registry for calling the function - deployment_collect = self._registry.get_deployment( - self._model_name, self._model_version, deployment_name=deployment_name - ).collect(statement_params=statement_params) - if not deployment_collect: - raise ValueError(f"The deployment with name {deployment_name} haven't been deployed") - deployment = deployment_collect[0] - platform = deploy_platforms.TargetPlatform(deployment["TARGET_PLATFORM"]) - target_method = deployment["TARGET_METHOD"] - signature = model_signature.ModelSignature.from_dict(json.loads(deployment["SIGNATURE"])) - options_dict = cast(Dict[str, Any], json.loads(deployment["OPTIONS"])) - platform_options = { - deploy_platforms.TargetPlatform.WAREHOUSE: model_types.WarehouseDeployOptions, - deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES: ( - model_types.SnowparkContainerServiceDeployOptions - ), - } - - if platform not in platform_options: - raise ValueError(f"Unsupported target Platform: {platform}") - options = platform_options[platform](options_dict) - di = model_types.Deployment( - name=self._registry._fully_qualified_deployment_name(deployment_name), - platform=platform, - target_method=target_method, - signature=signature, - options=options, - ) - return model_api.predict( - session=self._registry._session, - deployment=di, - X=data, - statement_params=statement_params, - ) - - -@telemetry.send_api_usage_telemetry( - 
project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, -) -@snowpark._internal.utils.private_preview(version="0.2.0") -def create_model_registry( - *, - session: snowpark.Session, - database_name: str = _DEFAULT_REGISTRY_NAME, - schema_name: str = _DEFAULT_SCHEMA_NAME, -) -> bool: - """Setup a new model registry. This should be run once per model registry by an administrator role. - - Args: - session: Session object to communicate with Snowflake. - database_name: Desired name of the model registry database. - schema_name: Desired name of the schema used by this model registry inside the database. - - Returns: - True if the creation of the model registry internal data structures was successful, - False otherwise. - """ - # Get the db & schema of the current session - old_db = session.get_current_database() - old_schema = session.get_current_schema() - - # These might be exposed as parameters in the future. - database_name = identifier.get_inferred_name(database_name) - schema_name = identifier.get_inferred_name(schema_name) - - statement_params = telemetry.get_function_usage_statement_params( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), ""), - ) - try: - _create_registry_database(session, database_name, statement_params) - _create_registry_schema(session, database_name, schema_name, statement_params) - - svm = _schema_version_manager.SchemaVersionManager(session, database_name, schema_name) - deployed_schema_version = svm.get_deployed_version(statement_params) - if deployed_schema_version == _initial_schema._INITIAL_VERSION: - # We do not know if registry is being created for the first time. - # So let's start with creating initial schema, which is idempotent anyways. 
- _initial_schema.create_initial_registry_tables(session, database_name, schema_name, statement_params) - - svm.try_upgrade(statement_params) - - finally: - if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] - # Restore the db & schema to the original ones - if old_db is not None and old_db != session.get_current_database(): - session.use_database(old_db) - if old_schema is not None and old_schema != session.get_current_schema(): - session.use_schema(old_schema) - return True diff --git a/snowflake/ml/registry/model_registry_test.py b/snowflake/ml/registry/model_registry_test.py deleted file mode 100644 index ba7ce9ab..00000000 --- a/snowflake/ml/registry/model_registry_test.py +++ /dev/null @@ -1,1251 +0,0 @@ -import datetime -import itertools -import json -from typing import Any, Dict, List, Union, cast - -from absl.testing import absltest - -from snowflake import connector, snowpark -from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import formatting, identifier, uri -from snowflake.ml.model import _api -from snowflake.ml.registry import _initial_schema, _schema, model_registry -from snowflake.ml.test_utils import mock_data_frame, mock_session - -_DATABASE_NAME = identifier.get_inferred_name("_SYSTEM_MODEL_REGISTRY") -_SCHEMA_NAME = identifier.get_inferred_name("_SYSTEM_MODEL_REGISTRY_SCHEMA") -_REGISTRY_TABLE_NAME = identifier.get_inferred_name("_SYSTEM_REGISTRY_MODELS") -_ARTIFACTS_TABLE_NAME = identifier.get_inferred_name("_SYSTEM_REGISTRY_ARTIFACTS") -_METADATA_TABLE_NAME = identifier.get_inferred_name("_SYSTEM_REGISTRY_METADATA") -_DEPLOYMENTS_TABLE_NAME = identifier.get_inferred_name("_SYSTEM_REGISTRY_DEPLOYMENTS") -_VERSION_TABLE_NAME = identifier.get_inferred_name("_SYSTEM_REGISTRY_SCHEMA_VERSION") -_FULLY_QUALIFIED_REGISTRY_TABLE_NAME = ".".join( - [ - _DATABASE_NAME, - _SCHEMA_NAME, - _REGISTRY_TABLE_NAME, - ] -) -_REGISTRY_SCHEMA_STRING = ", ".join([f"{k} {v}" for k, v in 
_initial_schema._INITIAL_REGISTRY_TABLE_SCHEMA]) -_METADATA_INSERT_COLUMNS_STRING = ",".join( - filter(lambda x: x != "SEQUENCE_ID", [item[0] for item in _initial_schema._INITIAL_METADATA_TABLE_SCHEMA]) -) -_METADATA_SCHEMA_STRING = ", ".join( - [ - f"{k} {v.format(registry_table_name=_FULLY_QUALIFIED_REGISTRY_TABLE_NAME)}" - for k, v in _initial_schema._INITIAL_METADATA_TABLE_SCHEMA - ] -) -_DEPLOYMENTS_SCHEMA_STRING = ",".join( - [ - f"{k} {v.format(registry_table_name=_FULLY_QUALIFIED_REGISTRY_TABLE_NAME)}" - for k, v in _initial_schema._INITIAL_DEPLOYMENTS_TABLE_SCHEMA - ] -) -_ARTIFACTS_SCHEMA_STRING = ",".join( - [ - f"{k} {v.format(registry_table_name=_FULLY_QUALIFIED_REGISTRY_TABLE_NAME)}" - for k, v in _initial_schema._INITIAL_ARTIFACT_TABLE_SCHEMA - ] -) - - -class ModelRegistryTest(absltest.TestCase): - """Testing ModelRegistry functions.""" - - def setUp(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self.maxDiff = None - self._session = mock_session.MockSession(conn=None, test_case=self) - self.event_id = "fedcba9876543210fedcba9876543210" - self.model_id = "0123456789abcdef0123456789abcdef" - self.model_name = "name" - self.model_version = "abc" - self.datetime = datetime.datetime(2022, 11, 4, 17, 1, 30, 153000) - - self._setup_mock_session() - - def tearDown(self) -> None: - """Complete test case. 
Ensure all expected operations have been observed.""" - self._session.finalize() - - def _setup_mock_session(self) -> None: - """Equip the mock session with mock variable/methods just for model registry.""" - self._session.get_current_database = absltest.mock.MagicMock(return_value=_DATABASE_NAME) - self._session.get_current_schema = absltest.mock.MagicMock(return_value=_SCHEMA_NAME) - self._session.use_database = absltest.mock.MagicMock() - self._session.use_schema = absltest.mock.MagicMock() - - def _mock_show_database_exists(self) -> None: - self.add_session_mock_sql( - query=f"SHOW DATABASES LIKE '{_DATABASE_NAME}'", - result=mock_data_frame.MockDataFrame(self.get_show_databases_success(name=_DATABASE_NAME)), - ) - - def _mock_show_database_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"SHOW DATABASES LIKE '{_DATABASE_NAME}'", - result=mock_data_frame.MockDataFrame([]).add_collect_result([], statement_params=statement_params), - ) - - def _mock_create_database_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"CREATE DATABASE IF NOT EXISTS {_DATABASE_NAME}", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status="MODEL_REGISTRY already exists, statement succeeded.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_database_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"CREATE DATABASE {_DATABASE_NAME}", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status="Database MODEL_REGISTRY successfully created.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_show_schema_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"SHOW SCHEMAS LIKE '{_SCHEMA_NAME}' IN DATABASE {_DATABASE_NAME}", - result=mock_data_frame.MockDataFrame(self.get_show_schemas_success(name=_SCHEMA_NAME)).add_collect_result( - 
self.get_show_schemas_success(name=_SCHEMA_NAME), - statement_params=statement_params, - ), - ) - - def _mock_show_schema_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"SHOW SCHEMAS LIKE '{_SCHEMA_NAME}' IN DATABASE {_DATABASE_NAME}", - result=mock_data_frame.MockDataFrame([]).add_collect_result([], statement_params=statement_params), - ) - - def _mock_create_schema_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"CREATE SCHEMA {_DATABASE_NAME}.{_SCHEMA_NAME}", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"SCHEMA {_SCHEMA_NAME} successfully created.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_registry_table_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME} - ({_REGISTRY_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"{_REGISTRY_TABLE_NAME} already exists, statement succeeded.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_artifacts_table_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {_DATABASE_NAME}.{_SCHEMA_NAME}.{_ARTIFACTS_TABLE_NAME} - ({_ARTIFACTS_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"{_ARTIFACTS_TABLE_NAME} already exists, statement succeeded.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_registry_table_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME} - ({_REGISTRY_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"Table {_REGISTRY_TABLE_NAME} successfully created.")], - 
collect_statement_params=statement_params, - ), - ) - - def _mock_create_artifacts_table_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {_DATABASE_NAME}.{_SCHEMA_NAME}.{_ARTIFACTS_TABLE_NAME} - ({_ARTIFACTS_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"Table {_ARTIFACTS_TABLE_NAME} successfully created.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_metadata_table_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME} - ({_METADATA_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"{_METADATA_TABLE_NAME} already exists, statement succeeded.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_metadata_table_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME} - ({_METADATA_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"Table {_METADATA_TABLE_NAME} successfully created.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_deployment_table_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {_DATABASE_NAME}.{_SCHEMA_NAME}.{_DEPLOYMENTS_TABLE_NAME} - ({_DEPLOYMENTS_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"{_DEPLOYMENTS_TABLE_NAME} already exists, statement succeeded.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_create_deployment_table_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS 
{_DATABASE_NAME}.{_SCHEMA_NAME}.{_DEPLOYMENTS_TABLE_NAME} - ({_DEPLOYMENTS_SCHEMA_STRING})""", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"Table {_DEPLOYMENTS_TABLE_NAME} successfully created.")], - collect_statement_params=statement_params, - ), - ) - - def _mock_show_version_table_not_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""SHOW TABLES LIKE '{_VERSION_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}""", - result=mock_data_frame.MockDataFrame([]).add_collect_result([], statement_params=statement_params), - ) - - def _mock_show_version_table_exists(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"""SHOW TABLES LIKE '{_VERSION_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}""", - result=mock_data_frame.MockDataFrame(self.get_show_tables_success(name=_VERSION_TABLE_NAME)), - ) - - def _mock_select_from_version_table(self, statement_params: Dict[str, str], schema_version: int = 0) -> None: - result_df = mock_data_frame.MockDataFrame().add_collect_result( - result=[snowpark.Row(MAX_VERSION=schema_version)] - ) - self.add_session_mock_sql( - query=(f"SELECT MAX(VERSION) AS MAX_VERSION FROM {_DATABASE_NAME}.{_SCHEMA_NAME}.{_VERSION_TABLE_NAME}"), - result=result_df, - ) - - def _mock_insert_into_version_table(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=( - f"""INSERT INTO {_DATABASE_NAME}.{_SCHEMA_NAME}.{_VERSION_TABLE_NAME} - (VERSION, CREATION_TIME) - VALUES ({_schema._CURRENT_SCHEMA_VERSION}, CURRENT_TIMESTAMP()) - """ - ), - result=mock_data_frame.MockDataFrame([snowpark.Row(**{"number of rows inserted": 1})]), - ) - - def _mock_desc_registry_table(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"DESC TABLE {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME}", - result=mock_data_frame.MockDataFrame(self.get_desc_registry_table_success()).add_collect_result( - 
self.get_desc_registry_table_success() - ), - ) - - def _mock_desc_metadata_table(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"DESC TABLE {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}", - result=mock_data_frame.MockDataFrame(self.get_desc_metadata_table_success()).add_collect_result( - self.get_desc_registry_table_success() - ), - ) - - def _mock_desc_deployments_table(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"DESC TABLE {_DATABASE_NAME}.{_SCHEMA_NAME}.{_DEPLOYMENTS_TABLE_NAME}", - result=mock_data_frame.MockDataFrame(self.get_desc_deployments_table_success()).add_collect_result( - self.get_desc_registry_table_success() - ), - ) - - def _mock_desc_artifacts_table(self, statement_params: Dict[str, str]) -> None: - self.add_session_mock_sql( - query=f"DESC TABLE {_DATABASE_NAME}.{_SCHEMA_NAME}.{_ARTIFACTS_TABLE_NAME}", - result=mock_data_frame.MockDataFrame(self.get_desc_artifacts_table_success()).add_collect_result( - self.get_desc_registry_table_success() - ), - ) - - def add_session_mock_sql(self, query: str, result: Any) -> None: - self._session.add_mock_sql(query=query, result=result) - - def get_model_registry(self) -> model_registry.ModelRegistry: - """Creates a valid model registry for testing.""" - self.setup_open_call() - - return model_registry.ModelRegistry(session=cast(snowpark.Session, self._session)) - - def get_show_tables_success( - self, name: str, database_name: str = _DATABASE_NAME, schema_name: str = _SCHEMA_NAME - ) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful listing of - tables.""" - return [ - snowpark.Row( - created_on=self.datetime, - name=name, - database_name=database_name, - schema_name=schema_name, - kind="TABLE", - comment="", - cluster_by="", - rows=0, - bytes=0, - owner="OWNER_ROLE", - retention_time=1, - change_tracking="OFF", - is_external="N", - 
enable_schema_evolution="N", - ) - ] - - def get_show_schemas_success( - self, name: str, database_name: str = _DATABASE_NAME, schema_name: str = _SCHEMA_NAME - ) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful listing of - schemas.""" - return [ - snowpark.Row( - created_on=self.datetime, - name=name, - is_default="N", - is_current="N", - database_name=database_name, - owner="OWNER_ROLE", - comment="", - options="", - retention_time=1, - ) - ] - - def get_show_databases_success(self, name: str) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful listing of - databases.""" - return [ - snowpark.Row( - created_on=self.datetime, - name=name, - is_default="N", - is_current="N", - origin="", - owner="OWNER_ROLE", - comment="", - options="", - retention_time=1, - ) - ] - - def get_desc_registry_table_success(self) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful desc table.""" - return [ - snowpark.Row(name="CREATION_CONTEXT", type="VARCHAR"), - snowpark.Row(name="CREATION_ENVIRONMENT_SPEC", type="OBJECT"), - snowpark.Row(name="CREATION_ROLE", type="VARCHAR"), - snowpark.Row(name="CREATION_TIME", type="TIMESTAMP_TZ"), - snowpark.Row(name="ID", type="VARCHAR PRIMARY KEY RELY"), - snowpark.Row(name="INPUT_SPEC", type="OBJECT"), - snowpark.Row(name="NAME", type="VARCHAR"), - snowpark.Row(name="OUTPUT_SPEC", type="OBJECT"), - snowpark.Row(name="RUNTIME_ENVIRONMENT_SPEC", type="OBJECT"), - snowpark.Row(name="ARTIFACT_IDS", type="ARRAY"), - snowpark.Row(name="TYPE", type="VARCHAR"), - snowpark.Row(name="URI", type="VARCHAR"), - snowpark.Row(name="VERSION", type="VARCHAR"), - ] - - def get_desc_metadata_table_success(self) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful desc table.""" - return [ - 
snowpark.Row(name="ATTRIBUTE_NAME", type="VARCHAR"), - snowpark.Row(name="EVENT_ID", type="VARCHAR"), - snowpark.Row(name="EVENT_TIMESTAMP", type="TIMESTAMP_TZ"), - snowpark.Row(name="MODEL_ID", type="VARCHAR"), - snowpark.Row(name="OPERATION", type="VARCHAR"), - snowpark.Row(name="ROLE", type="VARCHAR"), - snowpark.Row(name="SEQUENCE_ID", type="BIGINT"), - snowpark.Row(name="VALUE", type="OBJECT"), - ] - - def get_desc_deployments_table_success(self) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful desc table.""" - return [ - snowpark.Row(name="CREATION_TIME", type="TIMESTAMP_TZ"), - snowpark.Row(name="MODEL_ID", type="VARCHAR"), - snowpark.Row(name="DEPLOYMENT_NAME", type="VARCHAR"), - snowpark.Row(name="OPTIONS", type="VARIANT"), - snowpark.Row(name="TARGET_PLATFORM", type="VARCHAR"), - snowpark.Row(name="ROLE", type="VARCHAR"), - snowpark.Row(name="STAGE_PATH", type="VARCHAR"), - snowpark.Row(name="SIGNATURE", type="VARIANT"), - snowpark.Row(name="TARGET_METHOD", type="VARCHAR"), - ] - - def get_desc_artifacts_table_success(self) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful desc table.""" - return [ - snowpark.Row(name="ID", type="VARCHAR"), - snowpark.Row(name="TYPE", type="VARCHAR"), - snowpark.Row(name="NAME", type="VARCHAR"), - snowpark.Row(name="VERSION", type="VARCHAR"), - snowpark.Row(name="CREATION_ROLE", type="VARCHAR"), - snowpark.Row(name="CREATION_TIME", type="TIMESTAMP_TZ"), - snowpark.Row(name="ARTIFACT_SPEC", type="OBJECT"), - ] - - def setup_open_call(self) -> None: - self.add_session_mock_sql( - query=f"SHOW DATABASES LIKE '{_DATABASE_NAME}'", - result=mock_data_frame.MockDataFrame( - self.get_show_databases_success(name=_DATABASE_NAME) - ).add_collect_result(self.get_show_databases_success(name=_DATABASE_NAME)), - ) - self.add_session_mock_sql( - query=f"SHOW SCHEMAS LIKE '{_SCHEMA_NAME}' IN DATABASE 
{_DATABASE_NAME}", - result=mock_data_frame.MockDataFrame(self.get_show_schemas_success(name=_SCHEMA_NAME)).add_collect_result( - self.get_show_schemas_success(name=_SCHEMA_NAME) - ), - ) - self.add_session_mock_sql( - query=f"SHOW TABLES LIKE '{_REGISTRY_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}", - result=mock_data_frame.MockDataFrame( - self.get_show_tables_success(name=_REGISTRY_TABLE_NAME) - ).add_collect_result(self.get_show_tables_success(name=_REGISTRY_TABLE_NAME)), - ) - self.add_session_mock_sql( - query=f"SHOW TABLES LIKE '{_METADATA_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}", - result=mock_data_frame.MockDataFrame( - self.get_show_tables_success(name=_METADATA_TABLE_NAME) - ).add_collect_result(self.get_show_tables_success(name=_METADATA_TABLE_NAME)), - ) - self.add_session_mock_sql( - query=f"SHOW TABLES LIKE '{_DEPLOYMENTS_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}", - result=mock_data_frame.MockDataFrame( - self.get_show_tables_success(name=_DEPLOYMENTS_TABLE_NAME) - ).add_collect_result(self.get_show_tables_success(name=_DEPLOYMENTS_TABLE_NAME)), - ) - self._mock_show_version_table_exists({}) - self._mock_select_from_version_table({}, _schema._CURRENT_SCHEMA_VERSION) - self.setup_create_views_call() - - def setup_list_model_call(self) -> mock_data_frame.MockDataFrame: - """Setup the expected calls originating from list_model.""" - result_df = mock_data_frame.MockDataFrame() - - self.add_session_mock_sql( - query=(f"SELECT * FROM {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME}_VIEW"), - result=result_df, - ) - # Returning result_df to allow the caller to add more expected operations. 
- return result_df - - def setup_create_views_call(self) -> None: - """Setup the expected calls originating from _create_views.""" - self.add_session_mock_sql( - query=( - f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_DEPLOYMENTS_TABLE_NAME}_VIEW - COPY GRANTS AS - SELECT - DEPLOYMENT_NAME, - MODEL_ID, - {_REGISTRY_TABLE_NAME}.NAME as MODEL_NAME, - {_REGISTRY_TABLE_NAME}.VERSION as MODEL_VERSION, - {_DEPLOYMENTS_TABLE_NAME}.CREATION_TIME as CREATION_TIME, - TARGET_METHOD, - TARGET_PLATFORM, - SIGNATURE, - OPTIONS, - STAGE_PATH, - ROLE - FROM {_DEPLOYMENTS_TABLE_NAME} - LEFT JOIN {_REGISTRY_TABLE_NAME} - ON {_DEPLOYMENTS_TABLE_NAME}.MODEL_ID = {_REGISTRY_TABLE_NAME}.ID - """ - ), - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"View {_DEPLOYMENTS_TABLE_NAME}_VIEW successfully created.")] - ), - ) - self.add_session_mock_sql( - query=( - f"""CREATE OR REPLACE TEMPORARY VIEW - {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_DESCRIPTION - COPY GRANTS AS - SELECT DISTINCT - MODEL_ID, - (LAST_VALUE(VALUE) OVER ( - PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['DESCRIPTION'] - as DESCRIPTION - FROM {_METADATA_TABLE_NAME} WHERE ATTRIBUTE_NAME = 'DESCRIPTION'""" - ), - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"View {_METADATA_TABLE_NAME}_LAST_DESCRIPTION successfully created.")] - ), - ) - self.add_session_mock_sql( - query=( - f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_METRICS - COPY GRANTS AS - SELECT DISTINCT - MODEL_ID, - (LAST_VALUE(VALUE) OVER ( - PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['METRICS'] - as METRICS - FROM {_METADATA_TABLE_NAME} WHERE ATTRIBUTE_NAME = 'METRICS'""" - ), - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"View {_METADATA_TABLE_NAME}_LAST_METRICS successfully created.")] - ), - ) - self.add_session_mock_sql( - query=( - f"""CREATE OR REPLACE TEMPORARY VIEW 
{_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_TAGS - COPY GRANTS AS - SELECT DISTINCT - MODEL_ID, - (LAST_VALUE(VALUE) OVER ( - PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['TAGS'] - as TAGS - FROM {_METADATA_TABLE_NAME} WHERE ATTRIBUTE_NAME = 'TAGS'""" - ), - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"View {_METADATA_TABLE_NAME}_LAST_TAGS successfully created.")] - ), - ) - self.add_session_mock_sql( - query=( - f"""CREATE OR REPLACE TEMPORARY VIEW - {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_REGISTRATION COPY GRANTS AS - SELECT DISTINCT - MODEL_ID, EVENT_TIMESTAMP as REGISTRATION_TIMESTAMP - FROM {_METADATA_TABLE_NAME} WHERE ATTRIBUTE_NAME = 'REGISTRATION'""" - ), - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"View {_METADATA_TABLE_NAME}_LAST_TAGS successfully created.")] - ), - ) - self.add_session_mock_sql( - query=( - f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME}_VIEW - COPY GRANTS AS - SELECT {_REGISTRY_TABLE_NAME}.*, {_METADATA_TABLE_NAME}_LAST_DESCRIPTION.DESCRIPTION - AS DESCRIPTION, - {_METADATA_TABLE_NAME}_LAST_METRICS.METRICS AS METRICS, - {_METADATA_TABLE_NAME}_LAST_TAGS.TAGS AS TAGS, - {_METADATA_TABLE_NAME}_LAST_REGISTRATION.REGISTRATION_TIMESTAMP AS REGISTRATION_TIMESTAMP - FROM {_REGISTRY_TABLE_NAME} - LEFT JOIN {_METADATA_TABLE_NAME}_LAST_DESCRIPTION ON - ({_METADATA_TABLE_NAME}_LAST_DESCRIPTION.MODEL_ID = {_REGISTRY_TABLE_NAME}.ID) - LEFT JOIN {_METADATA_TABLE_NAME}_LAST_METRICS - ON ({_METADATA_TABLE_NAME}_LAST_METRICS.MODEL_ID = {_REGISTRY_TABLE_NAME}.ID) - LEFT JOIN {_METADATA_TABLE_NAME}_LAST_TAGS - ON ({_METADATA_TABLE_NAME}_LAST_TAGS.MODEL_ID = {_REGISTRY_TABLE_NAME}.ID) - LEFT JOIN {_METADATA_TABLE_NAME}_LAST_REGISTRATION - ON ({_METADATA_TABLE_NAME}_LAST_REGISTRATION.MODEL_ID = {_REGISTRY_TABLE_NAME}.ID) - """ - ), - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"View {_REGISTRY_TABLE_NAME}_VIEW 
successfully created.")] - ), - ) - - def setup_open_existing(self) -> None: - self.add_session_mock_sql( - query=f"SHOW DATABASES LIKE '{_DATABASE_NAME}'", - result=mock_data_frame.MockDataFrame(self.get_show_databases_success(name=_DATABASE_NAME)), - ) - self.add_session_mock_sql( - query=f"SHOW SCHEMAS LIKE '{_SCHEMA_NAME}' IN DATABASE {_DATABASE_NAME}", - result=mock_data_frame.MockDataFrame(self.get_show_schemas_success(name=_SCHEMA_NAME)), - ) - self.add_session_mock_sql( - query=f"SHOW TABLES LIKE '{_REGISTRY_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}", - result=mock_data_frame.MockDataFrame(self.get_show_tables_success(name=_REGISTRY_TABLE_NAME)), - ) - self.add_session_mock_sql( - query=f"SHOW TABLES LIKE '{_METADATA_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}", - result=mock_data_frame.MockDataFrame(self.get_show_tables_success(name=_METADATA_TABLE_NAME)), - ) - self.add_session_mock_sql( - query=f"SHOW TABLES LIKE '{_DEPLOYMENTS_TABLE_NAME}' IN {_DATABASE_NAME}.{_SCHEMA_NAME}", - result=mock_data_frame.MockDataFrame(self.get_show_tables_success(name=_DEPLOYMENTS_TABLE_NAME)), - ) - self._mock_show_version_table_exists({}) - self._mock_select_from_version_table({}, _schema._CURRENT_SCHEMA_VERSION) - self.setup_create_views_call() - - def setup_schema_upgrade_calls(self, statement_params: Dict[str, str]) -> None: - self._mock_show_version_table_exists(statement_params) - self._mock_select_from_version_table(statement_params) - self._mock_desc_registry_table(statement_params) - # begin schema upgrade plans - reg_table_full_path = f"{_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME}" - self.add_session_mock_sql( - query=(f"ALTER TABLE {reg_table_full_path} ADD COLUMN TRAINING_DATASET_ID VARCHAR"), - result=mock_data_frame.MockDataFrame([snowpark.Row(status="Statement executed successfully.")]), - ) - self.add_session_mock_sql( - query=(f"ALTER TABLE {reg_table_full_path} DROP COLUMN TRAINING_DATASET_ID"), - 
result=mock_data_frame.MockDataFrame([snowpark.Row(status="Statement executed successfully.")]), - ) - self.add_session_mock_sql( - query=(f"ALTER TABLE {reg_table_full_path} ADD COLUMN ARTIFACT_IDS ARRAY"), - result=mock_data_frame.MockDataFrame([snowpark.Row(status="Statement executed successfully.")]), - ) - art_table_full_path = f"{_DATABASE_NAME}.{_SCHEMA_NAME}.{_ARTIFACTS_TABLE_NAME}" - self.add_session_mock_sql( - query=(f"ALTER TABLE {art_table_full_path} DROP COLUMN ARTIFACT_SPEC"), - result=mock_data_frame.MockDataFrame([snowpark.Row(status="Statement executed successfully.")]), - ) - self.add_session_mock_sql( - query=(f"ALTER TABLE {art_table_full_path} ADD COLUMN ARTIFACT_SPEC VARCHAR"), - result=mock_data_frame.MockDataFrame([snowpark.Row(status="Statement executed successfully.")]), - ) - self.add_session_mock_sql( - query=( - f"""COMMENT ON COLUMN {art_table_full_path}.ARTIFACT_SPEC IS - 'This column is VARCHAR but supposed to store a valid JSON object'""" - ), - result=mock_data_frame.MockDataFrame([snowpark.Row(status="Statement executed successfully.")]), - ) - - # end schema upgrade plans - self._mock_desc_registry_table(statement_params) - self._mock_desc_metadata_table(statement_params) - self._mock_desc_deployments_table(statement_params) - self._mock_desc_artifacts_table(statement_params) - self._mock_show_version_table_exists(statement_params) - self._mock_insert_into_version_table(statement_params) - - def template_test_get_attribute( - self, collection_res: List[snowpark.Row], use_id: bool = False - ) -> mock_data_frame.MockDataFrame: - expected_df = self.setup_list_model_call() - expected_df.add_operation("filter") - if not use_id: - expected_df.add_operation("filter") - expected_df.add_collect_result(collection_res) - return expected_df - - def template_test_set_attribute( - self, - attribute_name: str, - attribute_value: Union[str, Dict[Any, Any]], - result_num_inserted: int = 1, - use_id: bool = False, - ) -> None: - expected_df = 
self.setup_list_model_call() - expected_df.add_operation("filter") - if not use_id: - expected_df.add_operation("filter") - expected_df.add_collect_result( - [ - snowpark.Row( - ID=self.model_id, - NAME="name", - VERSION="abc", - URI=f"sfc://{_DATABASE_NAME}.{_SCHEMA_NAME}.model_stage", - ) - ] - ) - - self._session.add_operation("get_current_role", result="current_role") - - self.add_session_mock_sql( - query=( - f"""INSERT INTO {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME} - ({_METADATA_INSERT_COLUMNS_STRING}) - SELECT '{attribute_name}','{self.event_id}',CURRENT_TIMESTAMP(),'{self.model_id}','SET', - 'current_role',OBJECT_CONSTRUCT('{attribute_name}', - {formatting.format_value_for_select(attribute_value)})""" - ), - result=mock_data_frame.MockDataFrame([snowpark.Row(**{"number of rows inserted": result_num_inserted})]), - ) - - def test_create_new(self) -> None: - """Verify that we can create a new ModelRegistry database with the default names.""" - # "Create" calls. - combinations = list(itertools.product([True, False], repeat=7)) - for ( - database_exists, - schema_exists, - version_table_exists, - registry_table_exists, - metadata_table_exists, - deployments_table_exists, - artifacts_table_exists, - ) in combinations: - with self.subTest( - msg=( - f"database_exists={database_exists}, " - f"schema_exists={schema_exists}, " - f"version_table_exists={version_table_exists}, " - f"registry_table_exists={registry_table_exists}, " - f"metadata_table_exists={metadata_table_exists}, " - f"deployments_table_exists={deployments_table_exists}, " - f"artifacts_table_exists={artifacts_table_exists}" - ) - ): - statement_params = telemetry.get_function_usage_statement_params( - project="MLOps", - subproject="ModelRegistry", - function_name="snowflake.ml.registry.model_registry.create_model_registry", - ) - if database_exists: - self._mock_show_database_exists() - else: - self._mock_show_database_not_exists(statement_params) - 
self._mock_create_database_not_exists(statement_params) - - if schema_exists: - self._mock_show_schema_exists(statement_params) - else: - self._mock_show_schema_not_exists(statement_params) - self._mock_create_schema_not_exists(statement_params) - - # svm.get_deployed_version() - if version_table_exists: - self._mock_show_version_table_exists(statement_params) - self._mock_select_from_version_table(statement_params) - else: - self._mock_show_version_table_not_exists(statement_params) - - if registry_table_exists: - self._mock_create_registry_table_exists(statement_params) - else: - self._mock_create_registry_table_not_exists(statement_params) - - if metadata_table_exists: - self._mock_create_metadata_table_exists(statement_params) - else: - self._mock_create_metadata_table_not_exists(statement_params) - - if deployments_table_exists: - self._mock_create_deployment_table_exists(statement_params) - else: - self._mock_create_deployment_table_not_exists(statement_params) - - if artifacts_table_exists: - self._mock_create_artifacts_table_exists(statement_params) - else: - self._mock_create_artifacts_table_not_exists(statement_params) - - self.setup_schema_upgrade_calls(statement_params) - - model_registry.create_model_registry( - session=cast(snowpark.Session, self._session), - database_name=_DATABASE_NAME, - schema_name=_SCHEMA_NAME, - ) - - def test_create_if_not_exists(self) -> None: - statement_params = telemetry.get_function_usage_statement_params( - project="MLOps", - subproject="ModelRegistry", - function_name="snowflake.ml.registry.model_registry.create_model_registry", - ) - # SQL queries issued by create_model_registry: - self._mock_show_database_not_exists(statement_params) - self._mock_create_database_not_exists(statement_params) - self._mock_show_schema_not_exists(statement_params) - self._mock_create_schema_not_exists(statement_params) - self._mock_show_version_table_not_exists(statement_params) - 
self._mock_create_registry_table_not_exists(statement_params) - self._mock_create_metadata_table_not_exists(statement_params) - self._mock_create_deployment_table_not_exists(statement_params) - self._mock_create_artifacts_table_not_exists(statement_params) - - self.setup_schema_upgrade_calls(statement_params) - - # 2. SQL queries issued by ModelRegistry constructor. - self.setup_open_existing() - - registry = model_registry.ModelRegistry( - session=cast(snowpark.Session, self._session), - database_name=_DATABASE_NAME, - schema_name=_SCHEMA_NAME, - create_if_not_exists=True, - ) - self.assertIsNotNone(registry) - - def test_open_existing(self) -> None: - """Verify that we can open an existing ModelRegistry database with the default names.""" - self.setup_open_existing() - model_registry.ModelRegistry(session=cast(snowpark.Session, self._session)) - - def test_list_models(self) -> None: - """Test the normal operation of list_models. We create a view and return the model metadata.""" - model_registry = self.get_model_registry() - self.setup_list_model_call().add_collect_result([snowpark.Row(ID=self.model_id, NAME="model_name")]) - - model_list = model_registry.list_models().collect() - self.assertEqual(model_list, [snowpark.Row(ID=self.model_id, NAME="model_name")]) - - def test_set_model_description(self) -> None: - """Test that we can set the description for an existing model.""" - model_registry = self.get_model_registry() - self.template_test_set_attribute("DESCRIPTION", "new_description") - - # Mock unique identifier for event id. 
- with absltest.mock.patch.object( - model_registry, - "_get_new_unique_identifier", - return_value=self.event_id, - ): - model_registry.set_model_description( - model_name=self.model_name, model_version=self.model_version, description="new_description" - ) - - def test_get_model_description(self) -> None: - """Test that we can get the description of an existing model from the registry.""" - model_registry = self.get_model_registry() - self.template_test_get_attribute( - [ - snowpark.Row( - ID=self.model_id, - NAME=self.model_name, - VERSION=self.model_version, - DESCRIPTION='"model_description"', - ) - ] - ) - - model_description = model_registry.get_model_description( - model_name=self.model_name, model_version=self.model_version - ) - self.assertEqual(model_description, "model_description") - - def test_get_history(self) -> None: - """Test that we can retrieve the history for the model history.""" - model_registry = self.get_model_registry() - expected_collect_result = [ - snowpark.Row( - EVENT_TIMESTAMP="ts", - EVENT_ID=self.event_id, - MODEL_ID=self.model_id, - ROLE="role", - OPERATION="SET", - ATTRIBUTE_NAME="NAME", - VALUE={"NAME": "name"}, - ) - ] - - expected_df = mock_data_frame.MockDataFrame() - self._session.add_operation( - operation="table", - args=(f"{_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}",), - result=expected_df, - ) - expected_df.add_operation("order_by", args=("EVENT_TIMESTAMP",)) - expected_df.add_operation( - "select_expr", - args=( - "EVENT_TIMESTAMP", - "EVENT_ID", - "MODEL_ID", - "ROLE", - "OPERATION", - "ATTRIBUTE_NAME", - "VALUE[ATTRIBUTE_NAME]", - ), - ) - expected_df.add_collect_result(expected_collect_result) - - self.assertEqual(model_registry.get_history().collect(), expected_collect_result) - - def test_get_model_history(self) -> None: - """Test that we can retrieve the history for a specific model.""" - model_registry = self.get_model_registry() - self.template_test_get_attribute( - [snowpark.Row(ID=self.model_id, 
NAME=self.model_name, VERSION=self.model_version)] - ) - expected_collect_result = [ - snowpark.Row( - EVENT_TIMESTAMP="ts", - EVENT_ID=self.event_id, - MODEL_ID=self.model_id, - ROLE="role", - OPERATION="SET", - ATTRIBUTE_NAME="NAME", - VALUE={"NAME": "name"}, - ) - ] - - expected_df = mock_data_frame.MockDataFrame() - self._session.add_operation( - operation="table", - args=(f"{_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}",), - result=expected_df, - ) - expected_df.add_operation("order_by", args=("EVENT_TIMESTAMP",)) - expected_df.add_operation( - "select_expr", - args=( - "EVENT_TIMESTAMP", - "EVENT_ID", - "MODEL_ID", - "ROLE", - "OPERATION", - "ATTRIBUTE_NAME", - "VALUE[ATTRIBUTE_NAME]", - ), - ) - expected_df.add_operation(operation="filter", check_args=False, check_kwargs=False) - expected_df.add_collect_result(expected_collect_result) - - self.assertEqual( - model_registry.get_model_history(model_name=self.model_name, model_version=self.model_version).collect(), - expected_collect_result, - ) - - def test_set_metric_no_existing(self) -> None: - """Test that we can set a metric for an existing model that does not yet have any metrics set.""" - model_registry = self.get_model_registry() - self.template_test_get_attribute( - [snowpark.Row(ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS=None)] - ) - self.template_test_set_attribute("METRICS", {"voight-kampff": 0.9}) - - # Mock unique identifier for event id. 
- with absltest.mock.patch.object( - model_registry, - "_get_new_unique_identifier", - return_value=self.event_id, - ): - model_registry.set_metric( - model_name=self.model_name, - model_version=self.model_version, - metric_name="voight-kampff", - metric_value=0.9, - ) - - def test_set_metric_with_existing(self) -> None: - """Test that we can set a metric for an existing model that already has metrics.""" - model_registry = self.get_model_registry() - self.template_test_get_attribute( - [ - snowpark.Row( - ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS='{"human-factor": 1.1}' - ) - ] - ) - self.template_test_set_attribute("METRICS", {"human-factor": 1.1, "voight-kampff": 0.9}) - - # Mock unique identifier for event id. - with absltest.mock.patch.object( - model_registry, - "_get_new_unique_identifier", - return_value=self.event_id, - ): - model_registry.set_metric( - model_name=self.model_name, - model_version=self.model_version, - metric_name="voight-kampff", - metric_value=0.9, - ) - - def test_get_metrics(self) -> None: - """Test that we can get the metrics for an existing model.""" - metrics_dict = {"human-factor": 1.1, "voight-kampff": 0.9} - model_registry = self.get_model_registry() - self.template_test_get_attribute( - [ - snowpark.Row( - ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS=json.dumps(metrics_dict) - ) - ] - ) - self.assertEqual( - model_registry.get_metrics(model_name=self.model_name, model_version=self.model_version), metrics_dict - ) - - def test_get_metric_value(self) -> None: - """Test that we can get a single metric value for an existing model.""" - metrics_dict = {"human-factor": 1.1, "voight-kampff": 0.9} - model_registry = self.get_model_registry() - self.template_test_get_attribute( - [ - snowpark.Row( - ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS=json.dumps(metrics_dict) - ) - ] - ) - self.assertEqual( - model_registry.get_metric_value( - 
model_name=self.model_name, model_version=self.model_version, metric_name="human-factor" - ), - 1.1, - ) - - def test_private_insert_registry_entry(self) -> None: - model_registry = self.get_model_registry() - - self.add_session_mock_sql( - query=f""" - INSERT INTO {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME} ( ID,NAME,TYPE,URI,VERSION ) - SELECT 'id','name','type','uri','abc' - """, - result=mock_data_frame.MockDataFrame([snowpark.Row(**{"number of rows inserted": 1})]), - ) - - model_properties = {"ID": "id", "NAME": "name", "TYPE": "type", "URI": "uri"} - - model_registry._insert_registry_entry(id="id", name="name", version="abc", properties=model_properties) - - def test_get_tags(self) -> None: - """Test that get_tags is working correctly with various types.""" - model_registry = self.get_model_registry() - self.template_test_get_attribute( - [ - snowpark.Row( - TAGS=""" - { - "top_level": "string", - "nested": { - "float": 0.9, - "int": 23, - "nested_string": "string", - "empty_string": "", - "bool_true": true, - "bool_false": "false", - "1d_array": [ - 1, - 2, - 3 - ], - "2d_array": [ - [ - 90, - 0 - ], - [ - 3, - 7 - ] - ] - } - }""", - ) - ] - ) - tags = model_registry.get_tags(model_name=self.model_name, model_version=self.model_version) - self.assertEqual(tags["top_level"], "string") - self.assertEqual(tags["nested"]["float"], 0.9) - - def test_log_model_path(self) -> None: - """Test _log_model_path(). - - Validate _log_model_path() can perform stage file put operation with the expected stage path and call - register_model() with the expected arguments. 
- """ - model_registry = self.get_model_registry() - - model_name = "name" - model_version = "abc" - expected_stage_postfix = f"{self.model_id}".upper() - - self.add_session_mock_sql( - query=f"CREATE OR REPLACE STAGE {_DATABASE_NAME}.{_SCHEMA_NAME}.SNOWML_MODEL_{expected_stage_postfix} " - f"ENCRYPTION = (TYPE= 'SNOWFLAKE_SSE')", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(**{"status": f"Stage area SNOWML_MODEL_{expected_stage_postfix} successfully created."})] - ), - ) - - expected_stage_path = ( - f"{identifier.get_inferred_name(_DATABASE_NAME)}" - + "." - + f"{identifier.get_inferred_name(_SCHEMA_NAME)}" - + "." - + f"SNOWML_MODEL_{expected_stage_postfix}" - ) - - with absltest.mock.patch.object( - model_registry, - "_get_new_unique_identifier", - return_value=self.model_id, - ): - model_id, stage_path = model_registry._log_model_path( - model_name=model_name, - model_version=model_version, - ) - self.assertEqual(model_id, self.model_id) - self.assertEqual(stage_path, expected_stage_path) - - def test_log_model(self) -> None: - """Test log_model()""" - model_registry = self.get_model_registry() - self._mock_show_version_table_exists({}) - self._mock_select_from_version_table({}, _schema._CURRENT_SCHEMA_VERSION) - - model_name = "name" - model_version = "abc" - expected_stage_postfix = f"{self.model_id}".upper() - - expected_stage_path = ( - f"{identifier.get_inferred_name(_DATABASE_NAME)}" - + "." - + f"{identifier.get_inferred_name(_SCHEMA_NAME)}" - + "." 
- + f"SNOWML_MODEL_{expected_stage_postfix}" - ) - model_path = f"@{expected_stage_path}" - with absltest.mock.patch.object( - model_registry, - "_list_selected_models", - return_value=absltest.mock.MagicMock(count=absltest.mock.MagicMock(return_value=0)), - ): - with absltest.mock.patch.object( - model_registry, - "_log_model_path", - return_value=(self.model_id, expected_stage_path), - ) as mock_path: - mock_model = absltest.mock.MagicMock() - mock_type = absltest.mock.MagicMock() - mock_model_composer = absltest.mock.MagicMock( - packager=absltest.mock.MagicMock(meta=absltest.mock.MagicMock(model_type=mock_type)) - ) - with absltest.mock.patch.object( - target=_api, attribute="save_model", return_value=mock_model_composer - ) as mock_save: - with absltest.mock.patch.object( - target=model_registry, attribute="_register_model_with_id", return_value=None - ) as mock_register: - with absltest.mock.patch.object(model_registry, "_get_model_id", return_value=self.model_id): - m_signatures = {"predict": None} - model_registry.log_model( - model_name=model_name, - model_version=model_version, - model=mock_model, - signatures=m_signatures, - description="description", - tags=None, - ) - mock_path.assert_called_once_with(model_name=model_name, model_version=model_version) - mock_save.assert_called_once_with( - name=model_name, - session=self._session, - stage_path=model_path, - model=mock_model, - signatures=m_signatures, - metadata=None, - conda_dependencies=None, - pip_requirements=None, - sample_input_data=None, - code_paths=None, - options=None, - ) - mock_register.assert_called_once_with( - model_name=model_name, - model_version=model_version, - model_id=self.model_id, - type=mock_type, - uri=uri.get_uri_from_snowflake_stage_path(model_path), - description="description", - tags=None, - ) - - self._mock_show_version_table_exists({}) - self._mock_select_from_version_table({}, _schema._CURRENT_SCHEMA_VERSION) - - with absltest.mock.patch.object( - model_registry, - 
"_list_selected_models", - return_value=absltest.mock.MagicMock(count=absltest.mock.MagicMock(return_value=1)), - ): - with self.assertRaises(connector.DataError): - model_registry.log_model( - model_name=model_name, - model_version=model_version, - model=mock_model, - signatures=m_signatures, - description="description", - tags=None, - ) - - self._mock_show_version_table_exists({}) - self._mock_select_from_version_table({}, _schema._CURRENT_SCHEMA_VERSION) - self.add_session_mock_sql( - query=f"DROP STAGE {_DATABASE_NAME}.{_SCHEMA_NAME}.SNOWML_MODEL_{expected_stage_postfix}", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(**{"status": f"Stage area SNOWML_MODEL_{expected_stage_postfix} successfully dropped."})] - ), - ) - - with absltest.mock.patch.object( - model_registry, - "_list_selected_models", - return_value=absltest.mock.MagicMock(count=absltest.mock.MagicMock(return_value=0)), - ): - with absltest.mock.patch.object( - model_registry, - "_log_model_path", - return_value=(self.model_id, expected_stage_path), - ) as mock_path: - mock_model = absltest.mock.MagicMock() - mock_type = absltest.mock.MagicMock() - with absltest.mock.patch.object(target=_api, attribute="save_model") as mock_save: - mock_save.side_effect = ValueError("Mock Error") - with self.assertRaises(ValueError): - model_registry.log_model( - model_name=model_name, - model_version=model_version, - model=mock_model, - signatures=m_signatures, - description="description", - tags=None, - ) - - def test_delete_model_with_artifact(self) -> None: - """Test deleting a model and artifact from the registry.""" - model_registry = self.get_model_registry() - self.setup_list_model_call().add_operation(operation="filter").add_operation( - operation="filter" - ).add_collect_result( - [ - snowpark.Row( - ID=self.model_id, - NAME=self.model_name, - VERSION=self.model_version, - URI=f"sfc://{_DATABASE_NAME}.{_SCHEMA_NAME}.model_stage", - ) - ], - ) - self.add_session_mock_sql( - query=f""" - DELETE FROM 
{_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME} WHERE ID='{self.model_id}' - """, - result=mock_data_frame.MockDataFrame([snowpark.Row(**{"number of rows deleted": 1})]), - ) - self.add_session_mock_sql( - query=f"DROP STAGE {_DATABASE_NAME}.{_SCHEMA_NAME}.model_stage", - result=mock_data_frame.MockDataFrame( - [snowpark.Row(**{"status": f"'{_DATABASE_NAME}.{_SCHEMA_NAME}.model_stage' successfully dropped."})] - ), - ) - self.template_test_set_attribute( - "DELETION", - { - "URI": f"sfc://{_DATABASE_NAME}.{_SCHEMA_NAME}.model_stage", - "delete_artifact": True, - }, - use_id=True, - ) - - with absltest.mock.patch.object( - model_registry, - "_get_new_unique_identifier", - return_value=self.event_id, - ): - model_registry.delete_model(model_name="name", model_version="abc") - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/registry/notebooks/Deployment to Snowpark Container Service Demo.ipynb b/snowflake/ml/registry/notebooks/Deployment to Snowpark Container Service Demo.ipynb deleted file mode 100644 index 06ce6bed..00000000 --- a/snowflake/ml/registry/notebooks/Deployment to Snowpark Container Service Demo.ipynb +++ /dev/null @@ -1,747 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "a45960e1", - "metadata": {}, - "source": [ - "# Deployment to Snowpark Container Service Demo" - ] - }, - { - "cell_type": "markdown", - "id": "aa7a329a", - "metadata": {}, - "source": [ - "### Snowflake-ML-Python Installation" - ] - }, - { - "cell_type": "markdown", - "id": "cb3d7a96", - "metadata": {}, - "source": [ - "- Please refer to our [landing page](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index) to install `snowflake-ml-python`." 
- ] - }, - { - "cell_type": "markdown", - "id": "3b50d774", - "metadata": {}, - "source": [ - "## Train a model with Snowpark ML API " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "18a75d71", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Tuple\n", - "from snowflake.ml.modeling import linear_model\n", - "from sklearn import datasets\n", - "import pandas as pd\n", - "import numpy as np\n", - "\n", - "def prepare_logistic_model() -> Tuple[linear_model.LogisticRegression, pd.DataFrame]:\n", - " iris = datasets.load_iris()\n", - " df = pd.DataFrame(data=np.c_[iris[\"data\"], iris[\"target\"]], columns=iris[\"feature_names\"] + [\"target\"])\n", - " df.columns = [s.replace(\" (CM)\", \"\").replace(\" \", \"\") for s in df.columns.str.upper()]\n", - "\n", - " input_cols = [\"SEPALLENGTH\", \"SEPALWIDTH\", \"PETALLENGTH\", \"PETALWIDTH\"]\n", - " label_cols = \"TARGET\"\n", - " output_cols = \"PREDICTED_TARGET\"\n", - "\n", - " estimator = linear_model.LogisticRegression(\n", - " input_cols=input_cols, output_cols=output_cols, label_cols=label_cols, random_state=0, max_iter=1000\n", - " ).fit(df)\n", - "\n", - " return estimator, df.drop(columns=label_cols).head(10)" - ] - }, - { - "cell_type": "markdown", - "id": "db6734fa", - "metadata": {}, - "source": [ - "## Start Snowpark Session" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "58dd3604", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 
\n" - ] - } - ], - "source": [ - "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", - "from snowflake.snowpark import Session\n", - "\n", - "session = Session.builder.configs(SnowflakeLoginOptions()).create()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "27dfbc42", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:create_model_registry() is in private preview since 0.2.0. Do not use it in production. \n", - "WARNING:absl:The database SHULIN_DB already exists. Skipping creation.\n", - "WARNING:absl:The schema SHULIN_DB.SHULIN_SCHEMA already exists. Skipping creation.\n" - ] - } - ], - "source": [ - "from snowflake.ml.registry import model_registry\n", - "from snowflake.ml._internal.utils import identifier\n", - "\n", - "db = identifier._get_unescaped_name(session.get_current_database())\n", - "schema = identifier._get_unescaped_name(session.get_current_schema())\n", - "\n", - "# will be a no-op if registry already exists\n", - "model_registry.create_model_registry(session=session, database_name=db, schema_name=schema) \n", - "registry = model_registry.ModelRegistry(session=session, database_name=db, schema_name=schema)" - ] - }, - { - "cell_type": "markdown", - "id": "38e0a975", - "metadata": {}, - "source": [ - "## Register SnowML Model" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "574e7a43", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:ModelRegistry.log_model() is in private preview since 0.2.0. Do not use it in production. \n", - "WARNING:snowflake.snowpark:ModelRegistry.list_models() is in private preview since 0.2.0. Do not use it in production. 
\n", - "/Users/shchen/micromamba/envs/snowml_1.0.12/lib/python3.10/site-packages/snowflake/ml/_internal/env_utils.py:217: UserWarning: Package requirement snowflake-snowpark-python<2,>=1.8.0 specified, while version 1.6.1 is installed. Local version will be ignored to conform to package requirement.\n", - " warnings.warn(\n", - "/Users/shchen/micromamba/envs/snowml_1.0.12/lib/python3.10/site-packages/snowflake/ml/_internal/env_utils.py:217: UserWarning: Package requirement snowflake-snowpark-python<2,>=1.8.0 specified, while version 1.6.1 is installed. Local version will be ignored to conform to package requirement.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/plain": [ - "'\\nIf your model has been logged and you want to reference it, you can use:\\nmodel_ref = model_registry.ModelReference(\\n registry=registry, model_name=model_name, model_version=model_version\\n)\\n'" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "logistic_model, test_features = prepare_logistic_model()\n", - "model_name = \"snowpark_ml_logistic\"\n", - "model_version = \"v1\"\n", - "\n", - "model_ref = registry.log_model(\n", - " model_name=model_name,\n", - " model_version=model_version,\n", - " model=logistic_model,\n", - " sample_input_data=test_features,\n", - ")\n", - "\n", - "\"\"\"\n", - "If your model has been logged and you want to reference it, you can use:\n", - "model_ref = model_registry.ModelReference(\n", - " registry=registry, model_name=model_name, model_version=model_version\n", - ")\n", - "\"\"\"" - ] - }, - { - "cell_type": "markdown", - "id": "054a3862", - "metadata": {}, - "source": [ - "## Model Deployment to Snowpark Container Service" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "64d286fb-bc80-4ce4-85e6-bcc46b38ce7c", - "metadata": {}, - "outputs": [], - "source": [ - "# Optionally enable INFO log level to show more logging during model deployment.\n", - "import 
logging\n", - "logging.basicConfig()\n", - "logging.getLogger().setLevel(logging.INFO)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "72ff114f", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:ModelRegistry.deploy() is in private preview since 0.2.0. Do not use it in production. \n", - "INFO:snowflake.connector.cursor:query: [SHOW TABLES LIKE '_SYSTEM_REGISTRY_SCHEMA_VERSION' IN SHULIN_DB.SHULIN_SCHEMA]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 1\n", - "INFO:snowflake.connector.cursor:query: [SELECT MAX(VERSION) AS MAX_VERSION FROM SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 1\n", - "INFO:snowflake.connector.cursor:query: [CREATE STAGE IF NOT EXISTS SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_DEPLOYMENTS_...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 1\n", - "INFO:snowflake.connector.cursor:query: [SELECT * FROM SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_MODELS_VIEW]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 0\n", - "INFO:snowflake.connector.cursor:query: [SELECT * FROM (SELECT * FROM SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_MODELS_V...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 1\n", - "INFO:snowflake.connector.cursor:query: [LIST @SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 9\n", - "INFO:snowflake.connector.cursor:query: 
[SELECT * FROM SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_MODELS_VIEW]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 0\n", - "INFO:snowflake.connector.cursor:query: [SELECT * FROM (SELECT * FROM SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_MODELS_V...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 1\n", - "INFO:snowflake.connector.cursor:query: [ls @SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 9\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/MANI...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/mode...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/modu...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/modu...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/modu...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/modu...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET 
'@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/runt...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/runt...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [GET '@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/runt...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "WARNING:snowflake.ml.model._deploy_client.snowservice.deploy:Similar environment detected. Using existing image sfengineering-mlplatformtest.registry.snowflakecomputing.com/shulin_db/shulin_schema/snowml_repo/260e5812c5d0c81981b30ab72d53a291a894a505:latest to skip image build. To disable this feature, set 'force_image_build=True' in deployment options\n", - "INFO:snowflake.ml.model._deploy_client.utils.snowservice_client:Creating service SHULIN_DB.SHULIN_SCHEMA.service_704eb1ee858011ee9dc05ac3f3b698e0\n", - "INFO:snowflake.ml.model._deploy_client.snowservice.deploy:Wait for service SHULIN_DB.SHULIN_SCHEMA.service_704eb1ee858011ee9dc05ac3f3b698e0 to become ready...\n", - "WARNING:snowflake.ml.model._deploy_client.utils.snowservice_client:Best-effort log streaming from SPCS will be enabled when python logging level is set to INFO.Alternatively, you can also query the logs by running the query 'CALL SYSTEM$GET_SERVICE_LOGS('SHULIN_DB.SHULIN_SCHEMA.service_704eb1ee858011ee9dc05ac3f3b698e0', '0', 'inference-server')'\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:Number of CPU cores: 4\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:Setting number of workers to 9\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [1] [INFO] Starting gunicorn 21.2.0\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [1] [INFO] Listening at: http://0.0.0.0:5000 
(1)\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [1] [INFO] Using worker: uvicorn.workers.UvicornWorker\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [20] [INFO] Booting worker with pid: 20\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [21] [INFO] Booting worker with pid: 21\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [22] [INFO] Booting worker with pid: 22\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [23] [INFO] Booting worker with pid: 23\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [24] [INFO] Booting worker with pid: 24\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [26] [INFO] Booting worker with pid: 26\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [32] [INFO] Booting worker with pid: 32\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:41 +0000] [37] [INFO] Booting worker with pid: 37\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [39] [INFO] Booting worker with pid: 39\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [20] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': 
'/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [20] [INFO] Started server process [20]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [20] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [20] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [20] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmpb0qnwx1b/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [22] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 
'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': '/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [21] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 
'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': '/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [23] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': 
'/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [22] [INFO] Started server process [22]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [22] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [22] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [21] [INFO] Started server process [21]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [21] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [21] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [21] [INFO] 
Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmp7gqnk72j/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [22] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmpybmydbd8/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [23] [INFO] Started server process [23]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [23] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [23] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [23] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmplvgh7e5w/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [32] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': '/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 
'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [20] [INFO] Loading model from /tmp/tmpb0qnwx1b/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [37] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': '/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 
'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [32] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmpmfhs5e6d/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [23] [INFO] Loading model from /tmp/tmplvgh7e5w/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [37] [INFO] Started server process [37]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [37] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [32] [INFO] Started server process [32]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [32] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [32] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:42 +0000] [37] [INFO] Application startup 
complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [37] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmp756rou0h/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [21] [INFO] Loading model from /tmp/tmp7gqnk72j/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [22] [INFO] Loading model from /tmp/tmpybmydbd8/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [26] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': '/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 
'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [37] [INFO] Loading model from /tmp/tmp756rou0h/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [26] [INFO] Started server process [26]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [26] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [26] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [26] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmps97gvf92/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [32] [INFO] Loading model from /tmp/tmpmfhs5e6d/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:43 +0000] [26] [INFO] Loading model from /tmp/tmps97gvf92/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [24] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 
'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': '/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [39] [INFO] ENV: environ({'SERVICE_SERVICE_HOST': '10.102.169.166', 'KUBERNETES_SERVICE_PORT_HTTPS': '443', 'KUBERNETES_SERVICE_PORT': '443', 'ENV_NAME': 'base', 'MAMBA_USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PROTO': 'tcp', 'HOSTNAME': 'statefulset-0', 'stage_uid': '1000', 'NUM_WORKERS': 'None', 'SNOWFLAKE_PORT': '443', 'PWD': '/tmp', 'CONDA_PREFIX': '/opt/conda', 'SERVICE_SERVICE_PORT_PREDICT': '5000', 'MAMBA_ROOT_PREFIX': '/opt/conda', 'SNOWFLAKE_ACCOUNT': 'FAB02971', 'SNOWFLAKE_DATABASE': 'SHULIN_DB', 'TARGET_METHOD': 'predict', 'vol1_gid': '0', 'vol1_uid': '0', 'HOME': 
'/home/mambauser', 'SERVICE_PORT_5000_TCP_ADDR': '10.102.169.166', 'LANG': 'C.UTF-8', 'KUBERNETES_PORT_443_TCP': 'tcp://10.96.0.1:443', 'CONDA_PROMPT_MODIFIER': '(base) ', 'SNOWML_USE_GPU': 'false', 'SNOWFLAKE_SCHEMA': 'SHULIN_SCHEMA', 'SERVICE_PORT_5000_TCP': 'tcp://10.102.169.166:5000', 'stage_gid': '1000', 'MAMBA_EXE': '/bin/micromamba', 'SNOWFLAKE_HOST': 'snowflake.prod3.us-west-2.aws.snowflakecomputing.com', 'USER': 'mambauser', 'SERVICE_PORT_5000_TCP_PORT': '5000', 'CONDA_SHLVL': '1', 'SHLVL': '0', 'SERVICE_SERVICE_PORT': '5000', 'KUBERNETES_PORT_443_TCP_PROTO': 'tcp', 'KUBERNETES_PORT_443_TCP_ADDR': '10.96.0.1', 'CONDA_DEFAULT_ENV': 'base', 'KUBERNETES_SERVICE_HOST': '10.96.0.1', 'LC_ALL': 'C.UTF-8', 'KUBERNETES_PORT': 'tcp://10.96.0.1:443', 'KUBERNETES_PORT_443_TCP_PORT': '443', 'SERVICE_PORT': 'tcp://10.102.169.166:5000', 'PATH': '/opt/conda/bin:/opt/conda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', 'MODEL_ZIP_STAGE_PATH': '/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip', 'SERVER_SOFTWARE': 'gunicorn/21.2.0'})\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [24] [INFO] Started server process [24]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [24] [INFO] Waiting for application startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [24] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [24] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmp0xbrnt1_/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [39] [INFO] Started server process [39]\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [39] [INFO] Waiting for application 
startup.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [39] [INFO] Application startup complete.\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [39] [INFO] Extracting model zip from /SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip to /tmp/tmp_f0jegpx/extracted_model_dir\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [39] [INFO] Loading model from /tmp/tmp_f0jegpx/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:44 +0000] [24] [INFO] Loading model from /tmp/tmp0xbrnt1_/extracted_model_dir into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:46 +0000] [26] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:46 +0000] [21] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:46 +0000] [23] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:46 +0000] [20] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:46 +0000] [32] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:46 +0000] [22] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:46 +0000] [37] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:47 +0000] [24] [INFO] Successfully loaded model into memory\n", - "INFO:snowflake.ml._internal.utils.log_stream_processor:[2023-11-17 19:35:47 +0000] [39] [INFO] Successfully loaded model into memory\n", - 
"INFO:snowflake.ml._internal.utils.log_stream_processor:\n", - "INFO:snowflake.ml.model._deploy_client.snowservice.deploy:Service SHULIN_DB.SHULIN_SCHEMA.service_704eb1ee858011ee9dc05ac3f3b698e0 is ready. Creating service function...\n", - "INFO:snowflake.ml.model._deploy_client.snowservice.deploy:Service function SHULIN_DB.SHULIN_SCHEMA.LOGISTIC_FUNC is created. Deployment completed successfully!\n", - "INFO:snowflake.connector.cursor:query: [INSERT INTO SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_DEPLOYMENTS ( CREATION_TIME...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:query: [SELECT * FROM SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_MODELS_VIEW]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 0\n", - "INFO:snowflake.connector.cursor:query: [SELECT * FROM (SELECT * FROM SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_MODELS_V...]\n", - "INFO:snowflake.connector.cursor:query execution done\n", - "INFO:snowflake.connector.cursor:Number of results in first chunk: 1\n", - "INFO:snowflake.connector.cursor:query: [INSERT INTO SHULIN_DB.SHULIN_SCHEMA._SYSTEM_REGISTRY_METADATA ( ATTRIBUTE_NAME,E...]\n", - "INFO:snowflake.connector.cursor:query execution done\n" - ] - }, - { - "data": { - "text/plain": [ - "{'name': 'SHULIN_DB.SHULIN_SCHEMA.LOGISTIC_FUNC',\n", - " 'platform': ,\n", - " 'target_method': 'predict',\n", - " 'signature': ModelSignature(\n", - " inputs=[\n", - " FeatureSpec(dtype=DataType.DOUBLE, name='SEPALLENGTH'),\n", - " \t\tFeatureSpec(dtype=DataType.DOUBLE, name='SEPALWIDTH'),\n", - " \t\tFeatureSpec(dtype=DataType.DOUBLE, name='PETALLENGTH'),\n", - " \t\tFeatureSpec(dtype=DataType.DOUBLE, name='PETALWIDTH')\n", - " ],\n", - " outputs=[\n", - " FeatureSpec(dtype=DataType.DOUBLE, name='SEPALLENGTH'),\n", - " \t\tFeatureSpec(dtype=DataType.DOUBLE, name='SEPALWIDTH'),\n", - " \t\tFeatureSpec(dtype=DataType.DOUBLE, 
name='PETALLENGTH'),\n", - " \t\tFeatureSpec(dtype=DataType.DOUBLE, name='PETALWIDTH'),\n", - " \t\tFeatureSpec(dtype=DataType.DOUBLE, name='PREDICTED_TARGET')\n", - " ]\n", - " ),\n", - " 'options': {'compute_pool': 'DEV_INFERENCE_CPU_POOL', 'enable_ingress': True},\n", - " 'details': {'service_info': {'name': 'SERVICE_704EB1EE858011EE9DC05AC3F3B698E0',\n", - " 'database_name': 'SHULIN_DB',\n", - " 'schema_name': 'SHULIN_SCHEMA',\n", - " 'owner': 'ENGINEER',\n", - " 'compute_pool': 'DEV_INFERENCE_CPU_POOL',\n", - " 'spec': '---\\nspec:\\n containers:\\n - name: \"inference-server\"\\n image: \"sfengineering-mlplatformtest.registry.snowflakecomputing.com/shulin_db/shulin_schema/snowml_repo/260e5812c5d0c81981b30ab72d53a291a894a505:latest\"\\n env:\\n MODEL_ZIP_STAGE_PATH: \"/SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0/model.zip\"\\n NUM_WORKERS: \"None\"\\n SNOWML_USE_GPU: \"false\"\\n TARGET_METHOD: \"predict\"\\n readinessProbe:\\n port: 5000\\n path: \"/health\"\\n volumeMounts:\\n - name: \"vol1\"\\n mountPath: \"/local/user/vol1\"\\n - name: \"stage\"\\n mountPath: \"SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0\"\\n volumes:\\n - name: \"vol1\"\\n source: \"local\"\\n - name: \"stage\"\\n source: \"@SHULIN_DB.SHULIN_SCHEMA.SNOWML_MODEL_704EB1EE858011EE9DC05AC3F3B698E0\"\\n uid: 1000\\n gid: 1000\\n endpoints:\\n - name: \"predict\"\\n port: 5000\\n public: true\\n',\n", - " 'dns_name': 'service-704eb1ee858011ee9dc05ac3f3b698e0.shulin-schema.shulin-db.snowflakecomputing.internal',\n", - " 'public_endpoints': 'Endpoints provisioning in progress... 
check back in a few minutes',\n", - " 'min_instances': 1,\n", - " 'max_instances': 1,\n", - " 'auto_resume': 'true',\n", - " 'created_on': datetime.datetime(2023, 11, 17, 11, 35, 39, 754000, tzinfo=),\n", - " 'updated_on': datetime.datetime(2023, 11, 17, 11, 35, 40, 289000, tzinfo=),\n", - " 'comment': None},\n", - " 'service_function_sql': \"\\nCREATE OR REPLACE FUNCTION SHULIN_DB.SHULIN_SCHEMA.LOGISTIC_FUNC(input OBJECT)\\n RETURNS OBJECT\\n SERVICE=SHULIN_DB.SHULIN_SCHEMA.service_704eb1ee858011ee9dc05ac3f3b698e0\\n ENDPOINT=predict\\n\\n AS '/predict'\\n\"}}" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from snowflake.ml.model import deploy_platforms\n", - "from snowflake import snowpark\n", - "\n", - "compute_pool = \"DEV_INFERENCE_CPU_POOL\" # Pre-created compute pool\n", - "deployment_name = \"LOGISTIC_FUNC\" # Name of the resulting UDF\n", - "\n", - "deployment_info = model_ref.deploy(\n", - " deployment_name=deployment_name, \n", - " platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES,\n", - " target_method=\"predict\",\n", - " options={\n", - " \"compute_pool\": compute_pool,\n", - " \"enable_ingress\": True,\n", - " #num_gpus: 1 # Specify the number of GPUs for GPU inferenc\n", - " }\n", - ")\n", - "\n", - "deployment_info" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "8709ee24-f7c0-458a-bc54-a2b78d5cc2cb", - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", - "logging.basicConfig()\n", - "logging.getLogger().setLevel(logging.WARNING)" - ] - }, - { - "cell_type": "markdown", - "id": "1c754e72", - "metadata": {}, - "source": [ - "## Batch Prediction on Snowpark Container Service" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a5c02328", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:ModelReference.predict() is in private preview since 
0.2.0. Do not use it in production. \n", - "WARNING:snowflake.snowpark:ModelRegistry.get_deployment() is in private preview since 1.0.1. Do not use it in production. \n", - "WARNING:snowflake.snowpark:ModelRegistry.list_deployments() is in private preview since 1.0.1. Do not use it in production. \n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
SEPALLENGTHSEPALWIDTHPETALLENGTHPETALWIDTHPREDICTED_TARGET
05.13.51.40.20.0
14.93.01.40.20.0
24.73.21.30.20.0
34.63.11.50.20.0
45.03.61.40.20.0
55.43.91.70.40.0
64.63.41.40.30.0
75.03.41.50.20.0
84.42.91.40.20.0
94.93.11.50.10.0
\n", - "
" - ], - "text/plain": [ - " SEPALLENGTH SEPALWIDTH PETALLENGTH PETALWIDTH PREDICTED_TARGET\n", - "0 5.1 3.5 1.4 0.2 0.0\n", - "1 4.9 3.0 1.4 0.2 0.0\n", - "2 4.7 3.2 1.3 0.2 0.0\n", - "3 4.6 3.1 1.5 0.2 0.0\n", - "4 5.0 3.6 1.4 0.2 0.0\n", - "5 5.4 3.9 1.7 0.4 0.0\n", - "6 4.6 3.4 1.4 0.3 0.0\n", - "7 5.0 3.4 1.5 0.2 0.0\n", - "8 4.4 2.9 1.4 0.2 0.0\n", - "9 4.9 3.1 1.5 0.1 0.0" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model_ref.predict(deployment_name, test_features)" - ] - }, - { - "cell_type": "markdown", - "id": "add87e4c-986c-4757-a3f5-5109f66d94c6", - "metadata": {}, - "source": [ - "## Invoke Service Public Endpoint on Snowpark Container Service" - ] - }, - { - "cell_type": "markdown", - "id": "829eac06-b760-443f-aaee-0915b0208005", - "metadata": {}, - "source": [ - "### Prerequisites:\n", - "- For Limited Private Preview, the ACCOUNTADMIN of your Snowflake account must execute the following command:\n", - "```\n", - "CREATE SECURITY INTEGRATION SNOWSERVICES_INGRESS_OAUTH\n", - "TYPE=oauth\n", - "OAUTH_CLIENT=snowservices_ingress\n", - "ENABLED=true;\n", - "```\n", - "\n", - "### Notes:\n", - "- Because Snowpark Containers uses Snowflake OAuth to enable ingress, the default role of the user cannot be any of the privileged roles, including ACCOUNTADMIN, SECURITYADMIN, and ORGADMIN. For more information, see Blocking Specific Roles from Using the Integration.\n", - "\n", - "- Not everyone can access the public endpoints a service exposes. 
Only users in the same Snowflake account having a role with USAGE privilege on a service can access the public endpoints of the service.\n", - "\n", - "For more details, please refers to https://docs.snowflake.com/LIMITEDACCESS/snowpark-containers/working-with-services#ingress-using-a-service-from-outside-snowflake for detailed setup to enable public endpoint on Snowpark Container Service.\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "5d35ef90-0a5d-4c2f-80a9-ca6d6c2eb60c", - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "def get_service_endpoint():\n", - " service_info = deployment_info[\"details\"][\"service_info\"]\n", - " rows = session.sql(f'DESCRIBE SERVICE {service_info[\"database_name\"]}.{service_info[\"schema_name\"]}.{service_info[\"name\"]}').collect()\n", - " res = rows[0][\"public_endpoints\"]\n", - " if \"provisioning in progress\" in res:\n", - " raise Valuere(\"Endpoints provisioning in progress. Please retry in a few seconds\") \n", - " res_json = json.loads(res)\n", - " target_method = deployment_info[\"target_method\"]\n", - " return res_json[target_method]\n", - "\n", - "def get_session_token(session) -> str:\n", - " \"\"\"\n", - " Gets session token from Snowflake client.\n", - " \"\"\"\n", - " return session._conn._conn._rest._token_request(\"ISSUE\")[\"data\"][\"sessionToken\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "5f2c6442-1026-4761-b1cd-64a6a8bf3fa8", - "metadata": {}, - "outputs": [], - "source": [ - "data = {\"data\": [[index, {\"_ID\": index, **row.to_dict()}] for index, row in test_features.iterrows()]}" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "5d90c42c-678b-449c-9b11-4ca34fa23d6b", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'Valuere' is not defined", - "output_type": "error", - "traceback": [ - 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[11], line 12\u001b[0m\n\u001b[1;32m 6\u001b[0m session_token \u001b[38;5;241m=\u001b[39m get_session_token(session)\n\u001b[1;32m 7\u001b[0m headers \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAuthorization\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mSnowflake Token=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msession_token\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 9\u001b[0m }\n\u001b[0;32m---> 12\u001b[0m api_endpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[43mget_service_endpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/predict\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 14\u001b[0m res \u001b[38;5;241m=\u001b[39m requests\u001b[38;5;241m.\u001b[39mpost(api_endpoint, json\u001b[38;5;241m=\u001b[39mdata, headers\u001b[38;5;241m=\u001b[39mheaders)\n\u001b[1;32m 16\u001b[0m session\u001b[38;5;241m.\u001b[39msql(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124marrow\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mcollect()\n", - "Cell \u001b[0;32mIn[9], line 7\u001b[0m, in \u001b[0;36mget_service_endpoint\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m res \u001b[38;5;241m=\u001b[39m rows[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpublic_endpoints\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 6\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprovisioning in progress\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m res:\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[43mValuere\u001b[49m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEndpoints provisioning in progress. Please retry in a few seconds\u001b[39m\u001b[38;5;124m\"\u001b[39m) \n\u001b[1;32m 8\u001b[0m res_json \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(res)\n\u001b[1;32m 9\u001b[0m target_method \u001b[38;5;241m=\u001b[39m deployment_info[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtarget_method\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", - "\u001b[0;31mNameError\u001b[0m: name 'Valuere' is not defined" - ] - } - ], - "source": [ - "import requests\n", - "\n", - "# Temporarily reset PYTHON_CONNECTOR_QUERY_RESULT_FORMAT needed for obtaining session token. \n", - "session.sql(\"ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'json'\").collect()\n", - "\n", - "session_token = get_session_token(session)\n", - "headers = {\n", - " \"Authorization\": f'Snowflake Token=\"{session_token}\"',\n", - "}\n", - "\n", - "\n", - "api_endpoint = f\"https://{get_service_endpoint()}/predict\"\n", - "\n", - "res = requests.post(api_endpoint, json=data, headers=headers)\n", - "\n", - "session.sql(\"ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'\").collect()\n", - "\n", - "res.json()[\"da\"]\n" - ] - }, - { - "cell_type": "markdown", - "id": "2b45f922-7b36-4555-a148-e78af0a4cf5d", - "metadata": {}, - "source": [ - "## Cleanup " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12991f07", - "metadata": {}, - "outputs": [], - "source": [ - "model_ref.delete_deployment(deployment_name=deployment_name)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09f337d2", - "metadata": {}, - "outputs": [], - "source": [ - "model_ref.delete_model()" - ] - } - ], - 
"metadata": { - "kernelspec": { - "display_name": "Python [conda env:micromamba-snowml_1.0.12] *", - "language": "python", - "name": "conda-env-micromamba-snowml_1.0.12-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/snowflake/ml/registry/notebooks/Finetune_Registry.ipynb b/snowflake/ml/registry/notebooks/Finetune_Registry.ipynb deleted file mode 100644 index 2a65e89d..00000000 --- a/snowflake/ml/registry/notebooks/Finetune_Registry.ipynb +++ /dev/null @@ -1,565 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "a91e831d-5778-4321-87c2-2a4f3550b189", - "metadata": {}, - "source": [ - "# LLM Pretrain or Finetune Model Workflow for Model Registry" - ] - }, - { - "cell_type": "markdown", - "id": "024a8eb0-8306-4220-b25d-209aac880586", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "id": "fa0e355f", - "metadata": {}, - "source": [ - "* Create a python3.8 conda env\n", - "`conda create --name {your_preferred_env_name} python=3.8`\n", - "* And, then install the latest snowparkML python package(minimum 1.0.12)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ed66db9", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "!pip install --force-reinstall --no-deps /home/ubuntu/snowml/bazel-bin/snowflake/ml/snowflake_ml_python-1.0.12-py3-none-any.whl" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "292e9f48", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import display, HTML\n", - "display(HTML(\"\"))\n", - "\n", - "%load_ext 
autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "7585077b", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.snowpark import Session\n", - "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", - "import pandas as pd\n", - "from snowflake.ml.model.models import llm\n", - "from snowflake.ml.registry import model_registry\n", - "from IPython.display import JSON" - ] - }, - { - "cell_type": "markdown", - "id": "7a0294ba", - "metadata": {}, - "source": [ - "Connection config available at ~/.snowsql/config" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "f876232e", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. \n" - ] - } - ], - "source": [ - "session = Session.builder.configs(SnowflakeLoginOptions('connections.demo')).create()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c6aee8c9", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "session.get_current_database(), session.get_current_schema()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "72c16c14", - "metadata": {}, - "outputs": [], - "source": [ - "REGISTRY_DATABASE_NAME = \"HALU_MR\"\n", - "REGISTRY_SCHEMA_NAME = \"PUBLIC\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "c420807b", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:create_model_registry() is in private preview since 0.2.0. Do not use it in production. \n", - "WARNING:absl:The database HALU_MR already exists. Skipping creation.\n", - "WARNING:absl:The schema HALU_MR.PUBLIC already exists. 
Skipping creation.\n" - ] - } - ], - "source": [ - "\n", - "model_registry.create_model_registry(\n", - " session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME\n", - ")\n", - "registry = model_registry.ModelRegistry(\n", - " session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "7d104673-c6fa-4eff-bec1-230c1d783881", - "metadata": {}, - "source": [ - "# Registry opertions" - ] - }, - { - "cell_type": "markdown", - "id": "6956d692-92c2-474f-9b5b-d69c2ef4e364", - "metadata": {}, - "source": [ - "## Define the model" - ] - }, - { - "cell_type": "markdown", - "id": "83ca550b-971d-46ba-a332-8218bc75ae00", - "metadata": {}, - "source": [ - "### Case1: Local Lora Finetune Weights\n", - "Lora finetune weights by huggingface PEFT library is supported." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "3b1d6e92-cb09-4576-a534-18522a040390", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "adapter_config.json adapter_model.bin\thalu_peft_ft training_args.bin\n" - ] - } - ], - "source": [ - "!ls /home/ubuntu/projects/test_ft_weights" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "18323af6", - "metadata": {}, - "outputs": [], - "source": [ - "options = llm.LLMOptions(\n", - " token=\"...\",\n", - " max_batch_size=100,\n", - ")\n", - "model = llm.LLM(\n", - " model_id_or_path=\"/home/ubuntu/projects/test_ft_weights\",\n", - " options=options\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "01883785-576f-435e-a5cb-dfd3243b75c6", - "metadata": {}, - "source": [ - "### Case2: Pretrain models" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b55dd0d5-b76e-48c8-bd89-f1f2ebe559de", - "metadata": {}, - "outputs": [], - "source": [ - "options = llm.LLMOptions(\n", - " token=\"...\",\n", - " max_batch_size=100, \n", - ")\n", - "model = 
llm.LLM(\n", - " model_id_or_path=\"meta-llama/Llama-2-7b-chat-hf\",\n", - " options=options\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "23dcb9cf-975d-49b6-90c7-119969d94f9d", - "metadata": {}, - "source": [ - "## Log model" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "dac3fc56", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "svc_model = registry.log_model(\n", - " model_name='llm_notebook_ft',\n", - " model_version='v1',\n", - " model=model,\n", - " options={\"embed_local_ml_library\": True},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "c725ec11-8a30-467f-813f-261971ec65fd", - "metadata": {}, - "source": [ - "## Deploy" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "b17b1fbb", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/conda/envs/pytorch/lib/python3.10/site-packages/snowflake/ml/model/_packager/model_env/model_env.py:353: UserWarning: Found dependencies specified as pip requirements. This may prevent model deploying to Snowflake Warehouse.\n", - " warnings.warn(\n", - "WARNING:snowflake.ml.model._deploy_client.snowservice.deploy:Debug model is enabled, deployment artifacts will be available in /tmp/tmpyp2rz595\n", - "WARNING:snowflake.ml.model._deploy_client.snowservice.deploy:Similar environment detected. Using existing image sfengineering-servicesnow.registry.snowflakecomputing.com/halu_ft_db/public/haul_repo/c125a958091b70d924d69b379b55ee20cbd8157e:latest to skip image build. 
To disable this feature, set 'force_image_build=True' in deployment options\n", - "WARNING:snowflake.ml.model._deploy_client.utils.snowservice_client:Best-effort log streaming from SPCS will be enabled when python logging level is set to INFO.Alternatively, you can also query the logs by running the query 'CALL SYSTEM$GET_SERVICE_LOGS('HALU_MR.PUBLIC.service_1a0ec2427e5511eea17e06f9498c0da3', '0', 'inference-server')'\n" - ] - } - ], - "source": [ - "from snowflake.ml.model import deploy_platforms\n", - "\n", - "deployment_options = {\n", - " \"compute_pool\": 'BUILD_2023_POOL',\n", - " \"num_gpus\": 1,\n", - " \"image_repo\": 'sfengineering-servicesnow.registry.snowflakecomputing.com/halu_ft_db/public/haul_repo',\n", - " \"enable_remote_image_build\": True,\n", - " \"debug_mode\": True,\n", - "}\n", - " \n", - "deploy_info = svc_model.deploy(\n", - " deployment_name=\"llm_notebook_ft_1\",\n", - " platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES,\n", - " permanent=True,\n", - " options=deployment_options\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "4a80b8bb-2191-4342-acc4-f44f817271c3", - "metadata": {}, - "source": [ - "## Prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "9475d6cd-5222-4bcb-9883-9d8924354d6a", - "metadata": {}, - "outputs": [], - "source": [ - "PROMPT_TEMPLATE = \"\"\"\n", - "\n", - "[INST] <>\n", - "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n", - "<>\n", - "### Instruction:\n", - "Extract JSON response with 'location' and 'toy_list' as keys. Start response by \"{\".\n", - "'location': Location of the caller. 
Include city only.\n", - "'toy_list\": List of toy names from the caller.\n", - "\n", - "### Input:\n", - "\"\"\"\n", - "\n", - "def build_prompt(input):\n", - " return PROMPT_TEMPLATE + input + \"\\n[/INST]\"" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "d067a009-567d-4869-9e8a-44694e169cc0", - "metadata": {}, - "outputs": [], - "source": [ - "df = pd.read_json('/home/ubuntu/projects/v8.jsonl', lines=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "8ca2e858-fa83-44d9-aca5-8a64ccc78975", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'language': 'EN',\n", - " 'transcript': \"caller: Hello!\\nfrosty: Well, hello! Who's spreading holiday cheer with me today?\\ncaller: I'm Max from Sydney.\\nfrosty: Hello, Max! Can you tell me what's on your wish list this holiday?\\ncaller: Hmm, I am not sure. I guess I like cars.\\nfrosty: We have a fun Bluey car. It's very cool. And also, there's a Teenage Mutant Ninja Turtles pizza delivery van! It's really fun.\\ncaller: Oh, the bluey car sounds cool.\\nfrosty: Great choice, Max! By the way, how do you plan to celebrate the holiday season with your family?\\ncaller: We're going to the beach! It's summer here in Sydney.\\nfrosty: Oh, that sounds wonderful, Max. So, we will put the Bluey car on your holiday wish list, okay?\\ncaller: Yes, please!\\nfrosty: It’s all done. 
I hope your holiday is filled with joy and fun!\",\n", - " 'name': 'Max',\n", - " 'location': 'Sydney',\n", - " 'toy_list': ['Bluey Convertible and Figures']}" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.iloc[0].to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "fdbfa6ef-e179-44e1-898b-8c103cf09d4d", - "metadata": {}, - "outputs": [], - "source": [ - "dfl = df['transcript'].to_list()[:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "c484ec44-672d-4269-830b-42ec037cef13", - "metadata": {}, - "outputs": [], - "source": [ - "prompts = [build_prompt(t) for t in dfl]" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "e8377d97-a70d-4709-b1c7-ea638634a557", - "metadata": {}, - "outputs": [], - "source": [ - "input_df = pd.DataFrame({'input': prompts})" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "3477c070-f067-471e-83b1-302dfec392b9", - "metadata": {}, - "outputs": [], - "source": [ - "res = svc_model.predict(\n", - " deployment_name='llm_notebook_ft_1',\n", - " data=input_df\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "f32e6498", - "metadata": {}, - "outputs": [], - "source": [ - "pd.set_option('display.max_colwidth', None)" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "a467afb6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
generated_text
0{\"toy_list\": [\"Bluey Convertible and Figures\", \"Teenage Mutant Ninja Turtles: Mutant Mayhem Pizza Fire Delivery Van\"], \"location\": \"Sydney\"}
1{\"toy_list\": [\"Furby interactive plush toy\", \"Transformers Rise of the Beasts Beast-Mode Bumblebee\"], \"location\": \"London\"}
2{\"toy_list\": [\"Teenage Mutant Ninja Turtles: Mutant Mayhem Pizza Fire Delivery Van\"], \"location\": \"Auckland\"}
3{\"toy_list\": [\"Transformers Rise of the Beasts Beast-Mode Bumblebee\"], \"location\": \"Denver\"}
4{\"toy_list\": [\"Fingerlings\", \"Barbie Dreamhouse 2023\"], \"location\": \"Sydney\"}
5{\"toy_list\": [\"Barbie Science Lab Playset\", \"Furby interactive plush toy\"], \"location\": \"Houston, Texas\"}
6{\"toy_list\": [\"Star Wars LOLA animatronic droid\", \"Bluey Convertible and Figures\"], \"location\": \"Sydney\"}
7{\"toy_list\": [\"Teenage Mutant Ninja Turtles: Mutant Mayhem Pizza Fire Delivery Van\", \"Bitzee interactive pet\"], \"location\": \"Dublin\"}
8{\"toy_list\": [\"Barbie Science Lab Playset\"], \"location\": \"Melbourne, Australia\"}
9{\"toy_list\": [\"Sesame Street Monster Meditation Elmo\"], \"location\": \"Toronto\"}
\n", - "
" - ], - "text/plain": [ - " generated_text\n", - "0 {\"toy_list\": [\"Bluey Convertible and Figures\", \"Teenage Mutant Ninja Turtles: Mutant Mayhem Pizza Fire Delivery Van\"], \"location\": \"Sydney\"}\n", - "1 {\"toy_list\": [\"Furby interactive plush toy\", \"Transformers Rise of the Beasts Beast-Mode Bumblebee\"], \"location\": \"London\"}\n", - "2 {\"toy_list\": [\"Teenage Mutant Ninja Turtles: Mutant Mayhem Pizza Fire Delivery Van\"], \"location\": \"Auckland\"}\n", - "3 {\"toy_list\": [\"Transformers Rise of the Beasts Beast-Mode Bumblebee\"], \"location\": \"Denver\"}\n", - "4 {\"toy_list\": [\"Fingerlings\", \"Barbie Dreamhouse 2023\"], \"location\": \"Sydney\"}\n", - "5 {\"toy_list\": [\"Barbie Science Lab Playset\", \"Furby interactive plush toy\"], \"location\": \"Houston, Texas\"}\n", - "6 {\"toy_list\": [\"Star Wars LOLA animatronic droid\", \"Bluey Convertible and Figures\"], \"location\": \"Sydney\"}\n", - "7 {\"toy_list\": [\"Teenage Mutant Ninja Turtles: Mutant Mayhem Pizza Fire Delivery Van\", \"Bitzee interactive pet\"], \"location\": \"Dublin\"}\n", - "8 {\"toy_list\": [\"Barbie Science Lab Playset\"], \"location\": \"Melbourne, Australia\"}\n", - "9 {\"toy_list\": [\"Sesame Street Monster Meditation Elmo\"], \"location\": \"Toronto\"}" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "res" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "dcb3b9b1-16a2-4ed8-9f81-399660f2f530", - "metadata": {}, - "outputs": [], - "source": [ - "svc_model.delete_deployment(deployment_name='llm_notebook_ft_1')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - 
"version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/snowflake/ml/registry/notebooks/Model Packaging Example.ipynb b/snowflake/ml/registry/notebooks/Model Packaging Example.ipynb deleted file mode 100644 index 3a4f1e59..00000000 --- a/snowflake/ml/registry/notebooks/Model Packaging Example.ipynb +++ /dev/null @@ -1,1400 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "5de3eb26", - "metadata": {}, - "source": [ - "# Model Packaging Example" - ] - }, - { - "cell_type": "markdown", - "id": "197efd00", - "metadata": {}, - "source": [ - "## Before Everything" - ] - }, - { - "cell_type": "markdown", - "id": "6ce97b36", - "metadata": {}, - "source": [ - "### Snowflake-ML-Python Installation" - ] - }, - { - "cell_type": "markdown", - "id": "1117c596", - "metadata": {}, - "source": [ - "- Please refer to our [landing page](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index) to install `snowflake-ml-python`." - ] - }, - { - "cell_type": "markdown", - "id": "7ed8032a", - "metadata": {}, - "source": [ - "### Local Installation" - ] - }, - { - "cell_type": "markdown", - "id": "741c249e", - "metadata": {}, - "source": [ - "- transformers>=4.31.0 (For GPT-2 and LLAMA 2 model inference example)\n", - "- tokenizers>=0.13.3 (For LLAMA 2 model inference example)\n", - "- tensorflow (For GPT-2 Example)\n", - "- xgboost==1.7.6 (For XGBoost GPU inference example)" - ] - }, - { - "cell_type": "markdown", - "id": "2bde8397", - "metadata": {}, - "source": [ - "### Additional Requirements" - ] - }, - { - "cell_type": "markdown", - "id": "2647a880", - "metadata": {}, - "source": [ - "- SPCS compute pool with at least 1 GPU (For all GPU inference on SPCS examples below)\n", - "\n", - "- Requested access to use LLama 2 model through HuggingFace (For LLAMA 2 model inference example)\n", - "\n", - "- A HuggingFace token with read access (For LLAMA 2 model inference example)\n", - "\n", - "- Download the News Category Dataset from 
https://www.kaggle.com/datasets/rmisra/news-category-dataset (For LLAMA 2 model inference example)" - ] - }, - { - "cell_type": "markdown", - "id": "99e58d8c", - "metadata": {}, - "source": [ - "### Setup Notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "afd16ff5", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "d609ff44", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Scale cell width with the browser window to accommodate .show() commands for wider tables.\n", - "from IPython.display import display, HTML\n", - "\n", - "display(HTML(\"\"))" - ] - }, - { - "cell_type": "markdown", - "id": "1ac32c6f", - "metadata": {}, - "source": [ - "### Start Snowpark Session\n", - "\n", - "To avoid exposing credentials in Github, we use a small utility `SnowflakeLoginOptions`. It allows you to score your default credentials in `~/.snowsql/config` in the following format:\n", - "```\n", - "[connections]\n", - "accountname = # Account identifier to connect to Snowflake.\n", - "username = # User name in the account. Optional.\n", - "password = # User password. Optional.\n", - "dbname = # Default database. Optional.\n", - "schemaname = # Default schema. Optional.\n", - "warehousename = # Default warehouse. Optional.\n", - "#rolename = # Default role. Optional.\n", - "#authenticator = # Authenticator: 'snowflake', 'externalbrowser', etc\n", - "```\n", - "Please follow [this](https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings) for more details." 
- ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "b2efc0a8", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", - "from snowflake.snowpark import Session\n", - "\n", - "session = Session.builder.configs(SnowflakeLoginOptions()).create()" - ] - }, - { - "cell_type": "markdown", - "id": "dfa9ab88", - "metadata": {}, - "source": [ - "### Open/Create Model Registry" - ] - }, - { - "cell_type": "markdown", - "id": "b0a0c8a8", - "metadata": {}, - "source": [ - "A model registry needs to be created before it can be used. The creation will create a new database in the current account so the active role needs to have permissions to create a database. After the first creation, the model registry can be opened without the need to create it again." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "a95e3431", - "metadata": {}, - "outputs": [], - "source": [ - "REGISTRY_DATABASE_NAME = \"MODEL_REGISTRY\"\n", - "REGISTRY_SCHEMA_NAME = \"PUBLIC\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7fff21bc", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.registry import model_registry\n", - "\n", - "model_registry.create_model_registry(\n", - " session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME\n", - ")\n", - "registry = model_registry.ModelRegistry(\n", - " session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "d76e14a1", - "metadata": {}, - "source": [ - "## Use with scikit-learn model" - ] - }, - { - "cell_type": "markdown", - "id": "c592d46c", - "metadata": {}, - "source": [ - "### Train A Small Scikit-learn Model" - ] - }, - { - "cell_type": "markdown", - "id": "378eb3ba", - "metadata": {}, - "source": [ - "The cell below trains a small model for demonstration purposes. 
The nature of the model does not matter, it is purely used to demonstrate the usage of the Model Packaging and Registry." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "8cf44218", - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn import svm\n", - "from sklearn.datasets import load_digits\n", - "import numpy as np\n", - "\n", - "digits = load_digits()\n", - "target_digit = 6\n", - "num_training_examples = 10\n", - "svc_gamma = 0.001\n", - "svc_C = 10.0\n", - "\n", - "clf = svm.SVC(gamma=svc_gamma, C=svc_C, probability=True)\n", - "\n", - "\n", - "def one_vs_all(dataset, digit):\n", - " return [x == digit for x in dataset]\n", - "\n", - "\n", - "# Train a classifier using num_training_examples and use the last 100 examples for test.\n", - "train_features = digits.data[:num_training_examples]\n", - "train_labels = one_vs_all(digits.target[:num_training_examples], target_digit)\n", - "clf.fit(train_features, train_labels)\n", - "\n", - "test_features = digits.data[-100:]\n", - "test_labels = one_vs_all(digits.target[-100:], target_digit)\n", - "prediction = clf.predict(test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c25bd0d4", - "metadata": {}, - "outputs": [], - "source": [ - "print(prediction[:10])" - ] - }, - { - "cell_type": "markdown", - "id": "dda57d0b", - "metadata": {}, - "source": [ - "SVC has multiple method, for example, `predict_proba`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd7ee333", - "metadata": {}, - "outputs": [], - "source": [ - "prediction_proba = clf.predict_proba(test_features)\n", - "print(prediction_proba[:10])" - ] - }, - { - "cell_type": "markdown", - "id": "317e7843", - "metadata": {}, - "source": [ - "### Register Model" - ] - }, - { - "cell_type": "markdown", - "id": "3b482561", - "metadata": {}, - "source": [ - "The call to `log_model` executes a few steps:\n", - "1. 
The given model object is serialized and uploaded to a stage.\n", - "1. An entry in the Model Registry is created for the model, referencing the model stage location.\n", - "1. Additional metadata is updated for the model as provided in the call.\n", - "\n", - "For the serialization to work, the model object needs to be serializable in python.\n", - "\n", - "Aso, you have to provide a sample input data so that we could infer the model signature for you, or you can specify the model signature manually." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68705420", - "metadata": {}, - "outputs": [], - "source": [ - "SVC_MODEL_NAME = \"SIMPLE_SVC_MODEL\"\n", - "SVC_MODEL_VERSION = \"v1\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d8ad06e", - "metadata": {}, - "outputs": [], - "source": [ - "# A name and model tags can be added to the model at registration time.\n", - "svc_model = registry.log_model(\n", - " model_name=SVC_MODEL_NAME,\n", - " model_version=SVC_MODEL_VERSION,\n", - " model=clf,\n", - " tags={\"stage\": \"testing\", \"classifier_type\": \"svm.SVC\", \"svc_gamma\": svc_gamma, \"svc_C\": svc_C},\n", - " sample_input_data=test_features[:10],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "45c75e28", - "metadata": {}, - "source": [ - "### Deploy Model and Batch Inference" - ] - }, - { - "cell_type": "markdown", - "id": "a8d496db", - "metadata": {}, - "source": [ - "We can also deploy the model we saved to the registry to warehouse and predict it in the warehouse.\n", - "\n", - "Although the model may contain multiple methods, every deployment can only have one target method, and you need to specify that when you deploy the model." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ecab97c", - "metadata": {}, - "outputs": [], - "source": [ - "svc_model.deploy(\n", - " deployment_name=\"svc_model_predict\",\n", - " target_method=\"predict\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e150421", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction = svc_model.predict(deployment_name=\"svc_model_predict\", data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction[:10])\n", - "\n", - "print(\"Result comparison:\", np.array_equal(prediction, remote_prediction[\"output_feature_0\"].values))" - ] - }, - { - "cell_type": "markdown", - "id": "6c1f3c07", - "metadata": {}, - "source": [ - "We can also deploy another method to warehouse." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9c6f189c", - "metadata": {}, - "outputs": [], - "source": [ - "svc_model.deploy(\n", - " deployment_name=\"svc_model_predict_proba\",\n", - " target_method=\"predict_proba\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36a00e1e", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction_proba = svc_model.predict(deployment_name=\"svc_model_predict_proba\", data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction_proba[:10])\n", - "\n", - "print(\"Result comparison:\", np.allclose(prediction_proba, remote_prediction_proba.values))" - ] - }, - { - "cell_type": "markdown", - "id": "dc2e2f5e", - "metadata": {}, - "source": [ - "## Use with customize model" - ] - }, - { - "cell_type": "markdown", - "id": "9bc58b66", - "metadata": {}, - "source": [ - "### Download a GPT-2 model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0ce2cca3", - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "\n", - "model_name = \"gpt2-medium\"\n", - 
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n", - "model = AutoModelForCausalLM.from_pretrained(model_name)" - ] - }, - { - "cell_type": "markdown", - "id": "03454cba", - "metadata": {}, - "source": [ - "### Store GPT-2 Model components locally" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05a0e170", - "metadata": {}, - "outputs": [], - "source": [ - "ARTIFACTS_DIR = \"/tmp/gpt-2/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f60d49c4", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "os.makedirs(os.path.join(ARTIFACTS_DIR, \"model\"), exist_ok=True)\n", - "os.makedirs(os.path.join(ARTIFACTS_DIR, \"tokenizer\"), exist_ok=True)\n", - "\n", - "model.save_pretrained(os.path.join(ARTIFACTS_DIR, \"model\"))\n", - "tokenizer.save_pretrained(os.path.join(ARTIFACTS_DIR, \"tokenizer\"))" - ] - }, - { - "cell_type": "markdown", - "id": "333118b7", - "metadata": {}, - "source": [ - "### Create a custom model using GPT-2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "49c27920", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.model import custom_model\n", - "import pandas as pd\n", - "\n", - "\n", - "class GPT2Model(custom_model.CustomModel):\n", - " def __init__(self, context: custom_model.ModelContext) -> None:\n", - " super().__init__(context)\n", - "\n", - " self.model = AutoModelForCausalLM.from_pretrained(self.context.path(\"model\"))\n", - " self.tokenizer = AutoTokenizer.from_pretrained(self.context.path(\"tokenizer\"))\n", - "\n", - " @custom_model.inference_api\n", - " def predict(self, X: pd.DataFrame) -> pd.DataFrame:\n", - " def _generate(input_text: str) -> str:\n", - " input_ids = self.tokenizer.encode(input_text, return_tensors=\"pt\")\n", - "\n", - " output = self.model.generate(input_ids, max_length=50, do_sample=True, top_p=0.95, top_k=60)\n", - " generated_text = self.tokenizer.decode(output[0], 
skip_special_tokens=True)\n", - "\n", - " return generated_text\n", - "\n", - " res_df = pd.DataFrame({\"output\": pd.Series.apply(X[\"input\"], _generate)})\n", - " return res_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36438fd5", - "metadata": {}, - "outputs": [], - "source": [ - "gpt_model = GPT2Model(\n", - " custom_model.ModelContext(\n", - " models={},\n", - " artifacts={\n", - " \"model\": os.path.join(ARTIFACTS_DIR, \"model\"),\n", - " \"tokenizer\": os.path.join(ARTIFACTS_DIR, \"tokenizer\"),\n", - " },\n", - " )\n", - ")\n", - "\n", - "gpt_model.predict(pd.DataFrame({\"input\": [\"Hello, are you GPT?\"]}))" - ] - }, - { - "cell_type": "markdown", - "id": "e111b527", - "metadata": {}, - "source": [ - "### Register the custom model" - ] - }, - { - "cell_type": "markdown", - "id": "c27ed16a", - "metadata": {}, - "source": [ - "Here, how to specify dependencies and model signature manually is shown." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0c22ec2f", - "metadata": {}, - "outputs": [], - "source": [ - "GPT2_MODEL_NAME = \"GPT2_MODEL\"\n", - "GPT2_MODEL_VERSION = \"v1\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3a913530", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.model import model_signature\n", - "\n", - "gpt_model_ref = registry.log_model(\n", - " model_name=GPT2_MODEL_NAME,\n", - " model_version=GPT2_MODEL_VERSION,\n", - " model=gpt_model,\n", - " conda_dependencies=[\"tensorflow\", \"transformers\"],\n", - " signatures={\n", - " \"predict\": model_signature.ModelSignature(\n", - " inputs=[model_signature.FeatureSpec(name=\"input\", dtype=model_signature.DataType.STRING)],\n", - " outputs=[model_signature.FeatureSpec(name=\"output\", dtype=model_signature.DataType.STRING)],\n", - " )\n", - " },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e634f4c1", - "metadata": {}, - "source": [ - "### Deploy the model and predict" - ] - 
}, - { - "cell_type": "markdown", - "id": "fc0f289d", - "metadata": {}, - "source": [ - "Relax version is an option that allow the deployer tries to relax the version specifications when initial attempt to\n", - "resolve the dependencies in Snowflake Anaconda Channel fails." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f6d64cb0", - "metadata": {}, - "outputs": [], - "source": [ - "gpt_model_ref.deploy(\n", - " deployment_name=\"gpt_model_predict\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24702087", - "metadata": {}, - "outputs": [], - "source": [ - "gpt_model_ref.predict(deployment_name=\"gpt_model_predict\", data=pd.DataFrame({\"input\": [\"Hello, are you GPT?\"]}))" - ] - }, - { - "cell_type": "markdown", - "id": "b44a55b7", - "metadata": {}, - "source": [ - "## Use with XGBoost Model, Snowpark DataFrame and permanent deployment" - ] - }, - { - "cell_type": "markdown", - "id": "05e45630", - "metadata": {}, - "source": [ - "### Prepare dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16debd21", - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.datasets import fetch_kddcup99\n", - "\n", - "DATA_TABLE_NAME = \"KDDCUP99_DATASET\"\n", - "\n", - "kddcup99_data = fetch_kddcup99(as_frame=True)\n", - "kddcup99_sp_df = session.create_dataframe(kddcup99_data.frame)\n", - "kddcup99_sp_df.write.mode(\"overwrite\").save_as_table(DATA_TABLE_NAME)" - ] - }, - { - "cell_type": "markdown", - "id": "771cad94", - "metadata": {}, - "source": [ - "### Preprocessing Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "04b976c8", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.modeling.preprocessing import one_hot_encoder, ordinal_encoder, standard_scaler\n", - "import snowflake.snowpark.functions as F\n", - "\n", - "quote_fn = lambda x: f'\"{x}\"'\n", - "\n", - "ONE_HOT_ENCODE_COL_NAMES = [\"protocol_type\", \"service\", 
\"flag\"]\n", - "ORDINAL_ENCODE_COL_NAMES = [\"labels\"]\n", - "STANDARD_SCALER_COL_NAMES = [\n", - " \"duration\",\n", - " \"src_bytes\",\n", - " \"dst_bytes\",\n", - " \"wrong_fragment\",\n", - " \"urgent\",\n", - " \"hot\",\n", - " \"num_failed_logins\",\n", - " \"num_compromised\",\n", - " \"num_root\",\n", - " \"num_file_creations\",\n", - " \"num_shells\",\n", - " \"num_access_files\",\n", - " \"num_outbound_cmds\",\n", - " \"count\",\n", - " \"srv_count\",\n", - " \"dst_host_count\",\n", - " \"dst_host_srv_count\",\n", - "]\n", - "\n", - "TRAIN_SIZE_K = 0.2\n", - "kddcup99_data = session.table(DATA_TABLE_NAME)\n", - "kddcup99_data = kddcup99_data.with_columns(\n", - " list(map(quote_fn, ONE_HOT_ENCODE_COL_NAMES + ORDINAL_ENCODE_COL_NAMES)),\n", - " [\n", - " F.to_char(col_name, \"utf-8\")\n", - " for col_name in list(map(quote_fn, ONE_HOT_ENCODE_COL_NAMES + ORDINAL_ENCODE_COL_NAMES))\n", - " ],\n", - ")\n", - "kddcup99_sp_df_train, kddcup99_sp_df_test = tuple(\n", - " kddcup99_data.random_split([TRAIN_SIZE_K, 1 - TRAIN_SIZE_K], seed=2568)\n", - ")\n", - "\n", - "ft_one_hot_encoder = one_hot_encoder.OneHotEncoder(\n", - " handle_unknown=\"ignore\",\n", - " input_cols=list(map(quote_fn, ONE_HOT_ENCODE_COL_NAMES)),\n", - " output_cols=ONE_HOT_ENCODE_COL_NAMES,\n", - " drop_input_cols=True,\n", - ")\n", - "ft_one_hot_encoder = ft_one_hot_encoder.fit(kddcup99_sp_df_train)\n", - "kddcup99_sp_df_train = ft_one_hot_encoder.transform(kddcup99_sp_df_train)\n", - "kddcup99_sp_df_test = ft_one_hot_encoder.transform(kddcup99_sp_df_test)\n", - "\n", - "ft_ordinal_encoder = ordinal_encoder.OrdinalEncoder(\n", - " input_cols=list(map(quote_fn, ORDINAL_ENCODE_COL_NAMES)),\n", - " output_cols=list(map(quote_fn, ORDINAL_ENCODE_COL_NAMES)),\n", - " drop_input_cols=True,\n", - ")\n", - "ft_ordinal_encoder = ft_ordinal_encoder.fit(kddcup99_sp_df_train)\n", - "kddcup99_sp_df_train = ft_ordinal_encoder.transform(kddcup99_sp_df_train)\n", - "kddcup99_sp_df_test = 
ft_ordinal_encoder.transform(kddcup99_sp_df_test)\n", - "\n", - "ft_standard_scaler = standard_scaler.StandardScaler(\n", - " input_cols=list(map(quote_fn, STANDARD_SCALER_COL_NAMES)),\n", - " output_cols=list(map(quote_fn, STANDARD_SCALER_COL_NAMES)),\n", - " drop_input_cols=True,\n", - ")\n", - "ft_standard_scaler = ft_standard_scaler.fit(kddcup99_sp_df_train)\n", - "kddcup99_sp_df_train = ft_standard_scaler.transform(kddcup99_sp_df_train)\n", - "kddcup99_sp_df_test = ft_standard_scaler.transform(kddcup99_sp_df_test)" - ] - }, - { - "cell_type": "markdown", - "id": "d4d25ee7", - "metadata": {}, - "source": [ - "### Train an XGBoost model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ea1e3bee", - "metadata": {}, - "outputs": [], - "source": [ - "XGB_MODEL_NAME = \"XGB_MODEL_KDDCUP99\"\n", - "XGB_MODEL_VERSION = \"v1\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68bb0f77", - "metadata": {}, - "outputs": [], - "source": [ - "import xgboost\n", - "\n", - "regressor = xgboost.XGBClassifier(objective=\"multi:softprob\", n_estimators=500, reg_lambda=1, gamma=0, max_depth=5)\n", - "kddcup99_pd_df_train = kddcup99_sp_df_train.to_pandas()\n", - "regressor.fit(\n", - " kddcup99_pd_df_train.drop(columns=[\"labels\"]),\n", - " kddcup99_pd_df_train[\"labels\"],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "2e9446fc", - "metadata": {}, - "source": [ - "### Log the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1bf06733", - "metadata": {}, - "outputs": [], - "source": [ - "xgb_model = registry.log_model(\n", - " model_name=XGB_MODEL_NAME,\n", - " model_version=XGB_MODEL_VERSION,\n", - " model=regressor,\n", - " sample_input_data=kddcup99_sp_df_train.drop('\"labels\"'),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "5948b7c8", - "metadata": {}, - "source": [ - "### Deploy the model permanently" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": 
"b1f4cc21", - "metadata": {}, - "outputs": [], - "source": [ - "xgb_model.deploy(\n", - " deployment_name=\"xgb_model_predict\", target_method=\"predict\", permanent=True, options={\"relax_version\": True}\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e560bd8d", - "metadata": {}, - "source": [ - "### Predict with Snowpark DataFrame" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9578a89b", - "metadata": {}, - "outputs": [], - "source": [ - "sp_res = xgb_model.predict(deployment_name=\"xgb_model_predict\", data=kddcup99_sp_df_test)\n", - "sp_res.show()" - ] - }, - { - "cell_type": "markdown", - "id": "08614b16", - "metadata": {}, - "source": [ - "### Prepare another SQL connection and registry" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "421ff7e1", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", - "from snowflake.snowpark import Session\n", - "\n", - "another_session = Session.builder.configs(SnowflakeLoginOptions()).create()" - ] - }, - { - "cell_type": "markdown", - "id": "d1e99456", - "metadata": {}, - "source": [ - "### Call the deployed permanent UDF" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8523b768", - "metadata": {}, - "outputs": [], - "source": [ - "another_registry = model_registry.ModelRegistry(\n", - " session=another_session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME\n", - ")\n", - "xgb_model_ref = model_registry.ModelReference(\n", - " registry=another_registry,\n", - " model_name=XGB_MODEL_NAME,\n", - " model_version=XGB_MODEL_VERSION,\n", - ")\n", - "xgb_model_ref.list_deployments().show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ae714013", - "metadata": {}, - "outputs": [], - "source": [ - "sp_res = xgb_model_ref.predict(\n", - " deployment_name=\"xgb_model_predict\", 
data=another_session.create_dataframe(kddcup99_sp_df_test.to_pandas())\n", - ")\n", - "sp_res.show()" - ] - }, - { - "cell_type": "markdown", - "id": "6b4eabe1", - "metadata": {}, - "source": [ - "### Remove the deployed UDF" - ] - }, - { - "cell_type": "markdown", - "id": "be5ecdb5", - "metadata": {}, - "source": [ - "This would be done by calling delete_deployment in the registry." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8eceb2a", - "metadata": {}, - "outputs": [], - "source": [ - "xgb_model_ref.delete_deployment(deployment_name=\"xgb_model_predict\")" - ] - }, - { - "cell_type": "markdown", - "id": "5df0ed62", - "metadata": {}, - "source": [ - "### Deploy to SPCS and using GPU for inference" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f67748d", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.model import deploy_platforms\n", - "\n", - "xgb_model.deploy(\n", - " deployment_name=\"xgb_model_predict_spcs\",\n", - " target_method=\"predict\",\n", - " platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES,\n", - " permanent=True,\n", - " options={\"compute_pool\": \"...\", \"num_gpus\": 1, \"num_workers\": 24},\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bff75a87", - "metadata": {}, - "outputs": [], - "source": [ - "sp_res = xgb_model.predict(deployment_name=\"xgb_model_predict_spcs\", data=kddcup99_sp_df_test)\n", - "sp_res.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3372b222", - "metadata": {}, - "outputs": [], - "source": [ - "xgb_model.delete_deployment(deployment_name=\"xgb_model_predict_spcs\")" - ] - }, - { - "cell_type": "markdown", - "id": "2114bb8c", - "metadata": {}, - "source": [ - "## Using LLM with HuggingFace Pipeline" - ] - }, - { - "cell_type": "markdown", - "id": "07bb4d94", - "metadata": {}, - "source": [ - "### Preparing Data into Snowflake" - ] - }, - { - "cell_type": "code", - 
"execution_count": 18, - "id": "280c8644", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "news_dataset = pd.read_json(\"News_Category_Dataset_v3.json\", lines=True).convert_dtypes()" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "c8500b9a", - "metadata": {}, - "outputs": [], - "source": [ - "NEWS_DATA_TABLE_NAME = \"news_dataset\"\n", - "news_dataset_sp_df = session.create_dataframe(news_dataset)\n", - "news_dataset_sp_df.write.mode(\"overwrite\").save_as_table(NEWS_DATA_TABLE_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "533ef50c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "|\"headline\" |\"category\" |\"short_description\" |\n", - "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "|Where Do We Come From? |WEIRD NEWS |My dear readers, Denial is not a river in Egypt. What will it take to wake up? Or have we allowed ourselves to be so disempowered that we have thrown in the towel? If so, is self destruction imminent? I would hope not. |\n", - "|Sen. Mike Enzi Wins Primary, Will Face Democrat In November |POLITICS | |\n", - "|Start-Art |ARTS |For too long the contemporary art world has been the exclusive redoubt of insiders, tastemakers, and a privileged elite. 
Gertrude has exploded this paradigm, and fashioned a conversational forum that democratizes and demystifies contemporary art. |\n", - "|Tony Wagner's The Global Achievement Gap Is More Relevant Than Ever |EDUCATION |We have always had plenty of soul-killing, drill and kill instruction. In the past, however, it was seen as education malpractice. Now, it is imposed in the name of \"reform.\" |\n", - "|Why Are Shoes So Damn Expensive? |COMEDY | |\n", - "|Create an Online Gift Registry for Your Baby and Get the Items You Need and Want |PARENTS |Online gift registries are making it easy and convenient for expectant parents to get everything on their wish lists. |\n", - "|I'm an American Citizen. If You Want to Remain a Cop, Don't Violate My Human Rights |POLITICS |This idea that cops get to say when and where constitutional rights apply is so very, deeply misguided that I am shocked anyone could type it out without coming to their senses mid-sentence. |\n", - "|Jose Antonio Vargas Among Undocumented Immigrants Making Urgent Plea To Obama |POLITICS | |\n", - "|The 1 Minute Blog. Emotions Consuming You? |GOOD NEWS | |\n", - "|Beyond the Ice-Bucket: A Deeper Challenge |IMPACT |Life is full of challenges. Some are profoundly life-changing, and some are cold and wet. The ice-bucket challenge was not only great summer fun, it was also one of the most positive and productive viral campaigns in history. 
|\n", - "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "\n" - ] - } - ], - "source": [ - "news_dataset_sp = session.table(NEWS_DATA_TABLE_NAME).select('\"headline\"','\"category\"','\"short_description\"')\n", - "\n", - "news_dataset_sp.show(max_width=600)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "7d9c076b", - "metadata": {}, - "outputs": [], - "source": [ - "LLM_MODEL_NAME = \"llama-2-7b-chat\"\n", - "LLM_MODEL_VERSION = \"v1\"" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "d9d072bc", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.model.models import huggingface_pipeline\n", - "\n", - "llama_model = huggingface_pipeline.HuggingFacePipelineModel(\n", - " task=\"text-generation\",\n", - " model=\"meta-llama/Llama-2-7b-chat-hf\",\n", - " token=\"...\", # Put your HuggingFace token here.\n", - " return_full_text=False,\n", - " max_new_tokens=100,\n", - " batch_size=1,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "29b00bc2", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:ModelRegistry.log_model() is in private preview since 0.2.0. Do not use it in production. \n", - "WARNING:snowflake.snowpark:ModelRegistry.list_models() is in private preview since 0.2.0. Do not use it in production. 
\n" - ] - } - ], - "source": [ - "llama_model_ref = registry.log_model(\n", - " model_name=LLM_MODEL_NAME,\n", - " model_version=LLM_MODEL_VERSION,\n", - " model=llama_model,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "4d6c93ea", - "metadata": {}, - "outputs": [], - "source": [ - "DEPLOYMENT_NAME=\"llama_predict\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c464d3aa", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.model import deploy_platforms\n", - "\n", - "llama_model_ref.deploy(\n", - " deployment_name=DEPLOYMENT_NAME,\n", - " platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES,\n", - " permanent=True,\n", - " options={\n", - " \"compute_pool\": \"...\",\n", - " \"num_gpus\": 1,\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "e68f70b0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "|\"headline\" |\"category\" |\"short_description\" |\"inputs\" |\n", - 
"-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "|Where Do We Come From? |WEIRD NEWS |My dear readers, Denial is not a river in Egypt. What will it take to wake up? Or have we allowed ourselves to be so disempowered that we have thrown in the towel? If so, is self destruction imminent? I would hope not. |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. 
The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Where Do We Come From? My dear readers, Denial is not a river in Egypt. What will it take to wake up? Or have we allowed ourselves to be so disempowered that we have thrown in the towel? If so, is self destruction imminent? I would hope not. [/INST] |\n", - "|Sen. Mike Enzi Wins Primary, Will Face Democrat In November |POLITICS | |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Sen. 
Mike Enzi Wins Primary, Will Face Democrat In November [/INST] |\n", - "|Start-Art |ARTS |For too long the contemporary art world has been the exclusive redoubt of insiders, tastemakers, and a privileged elite. Gertrude has exploded this paradigm, and fashioned a conversational forum that democratizes and demystifies contemporary art. |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Start-Art For too long the contemporary art world has been the exclusive redoubt of insiders, tastemakers, and a privileged elite. Gertrude has exploded this paradigm, and fashioned a conversational forum that democratizes and demystifies contemporary art. [/INST] |\n", - "|Tony Wagner's The Global Achievement Gap Is More Relevant Than Ever |EDUCATION |We have always had plenty of soul-killing, drill and kill instruction. In the past, however, it was seen as education malpractice. Now, it is imposed in the name of \"reform.\" |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. 
Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Tony Wagner's The Global Achievement Gap Is More Relevant Than Ever We have always had plenty of soul-killing, drill and kill instruction. In the past, however, it was seen as education malpractice. Now, it is imposed in the name of \"reform.\" [/INST] |\n", - "|Why Are Shoes So Damn Expensive? |COMEDY | |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. 
The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Why Are Shoes So Damn Expensive? [/INST] |\n", - "|Create an Online Gift Registry for Your Baby and Get the Items You Need and Want |PARENTS |Online gift registries are making it easy and convenient for expectant parents to get everything on their wish lists. |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Create an Online Gift Registry for Your Baby and Get the Items You Need and Want Online gift registries are making it easy and convenient for expectant parents to get everything on their wish lists. [/INST] |\n", - "|I'm an American Citizen. 
If You Want to Remain a Cop, Don't Violate My Human Rights |POLITICS |This idea that cops get to say when and where constitutional rights apply is so very, deeply misguided that I am shocked anyone could type it out without coming to their senses mid-sentence. |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | I'm an American Citizen. If You Want to Remain a Cop, Don't Violate My Human Rights This idea that cops get to say when and where constitutional rights apply is so very, deeply misguided that I am shocked anyone could type it out without coming to their senses mid-sentence. [/INST] |\n", - "|Jose Antonio Vargas Among Undocumented Immigrants Making Urgent Plea To Obama |POLITICS | |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. 
Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Jose Antonio Vargas Among Undocumented Immigrants Making Urgent Plea To Obama [/INST] |\n", - "|The 1 Minute Blog. Emotions Consuming You? |GOOD NEWS | |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. 
The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | The 1 Minute Blog. Emotions Consuming You? [/INST] |\n", - "|Beyond the Ice-Bucket: A Deeper Challenge |IMPACT |Life is full of challenges. Some are profoundly life-changing, and some are cold and wet. The ice-bucket challenge was not only great summer fun, it was also one of the most positive and productive viral campaigns in history. |[INST] <> |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} |\n", - "| | | |<> |\n", - "| | | | Beyond the Ice-Bucket: A Deeper Challenge Life is full of challenges. Some are profoundly life-changing, and some are cold and wet. 
The ice-bucket challenge was not only great summer fun, it was also one of the most positive and productive viral campaigns in history. [/INST] |\n", - "-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "\n" - ] - } - ], - "source": [ - "import snowflake.snowpark.functions as F\n", - "\n", - "prompt_prefix = \"\"\"[INST] <>\n", - "Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. 
The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} \n", - "As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8}\n", - "<>\n", - "\"\"\"\n", - "prompt_suffix = \"[/INST]\"\n", - "\n", - "input_df = news_dataset_sp.with_column(\n", - " '\"inputs\"',\n", - " F.concat_ws(\n", - " F.lit(\" \"), F.lit(prompt_prefix), F.col('\"headline\"'), F.col('\"short_description\"'), F.lit(prompt_suffix)\n", - " ),\n", - ")\n", - "\n", - "input_df.show(max_width=600)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05aa68e8", - "metadata": {}, - "outputs": [], - "source": [ - "res = llama_model_ref.predict(\n", - " deployment_name=DEPLOYMENT_NAME,\n", - " data=input_df\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "783be746", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
"---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "|\"headline\" |\"category\" |\"short_description\" |\"inputs\" |\"outputs\" |\"PRED_CATEGORY\" |\"PRED_KEYWORDS\" |\"PRED_IMPORTANCE\" |\n", - 
"---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "|Where Do We Come From? |WEIRD NEWS |My dear readers, Denial is not a river in Egypt. What will it take to wake up? Or have we allowed ourselves to be so disempowered that we have thrown in the towel? If so, is self destruction imminent? I would hope not. 
|[INST] <> |[{\"generated_text\": \" Here is the JSON output for the given text:\\n{\\n\\\"category\\\": \\\"Society\\\",\\n\\\"keywords\\\": [\\n\\\"denial\\\",\\n\\\"Egypt\\\",\\n\\\"self-destruction\\\"\\n],\\n\\\"importance\\\": 7\\n\\n}\\n\\nNote: The JSON output is valid according to the provided schema, and includes the required properties and values.\"}] |\"Society\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"denial\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"Egypt\", | |\n", - "| | | |<> | | | \"self-destruction\" | |\n", - "| | | | Where Do We Come From? My dear readers, Denial is not a river in Egypt. What will it take to wake up? Or have we allowed ourselves to be so disempowered that we have thrown in the towel? If so, is self destruction imminent? I would hope not. [/INST] | | |] | |\n", - "|Sen. Mike Enzi Wins Primary, Will Face Democrat In November |POLITICS | |[INST] <> |[{\"generated_text\": \" Sure, here's the JSON output for the input \\\"Sen. 
Mike Enzi Wins Primary, Will Face Democrat In November\\\":\\n{\\n\\\"category\\\": \\\"Politics\\\",\\n\\\"keywords\\\": [\\\"Mike Enzi\\\", \\\"primary\\\", \\\"election\\\", \\\"Democrat\\\"],\\n\\\"importance\\\": 7\\n\\n}\\n\\nNote that I've included the required fields \\\"category\\\", \\\"keywords\\\", and \\\"importance\\\" with the appropriate\"}] |\"Politics\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"Mike Enzi\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"primary\", | |\n", - "| | | |<> | | | \"election\", | |\n", - "| | | | Sen. Mike Enzi Wins Primary, Will Face Democrat In November [/INST] | | | \"Democrat\" | |\n", - "| | | | | | |] | |\n", - "|Start-Art |ARTS |For too long the contemporary art world has been the exclusive redoubt of insiders, tastemakers, and a privileged elite. Gertrude has exploded this paradigm, and fashioned a conversational forum that democratizes and demystifies contemporary art. 
|[INST] <> |[{\"generated_text\": \" Here is the JSON output for the given input:\\n{\\n\\\"category\\\": \\\"Art\\\",\\n\\\"keywords\\\": [\\n\\\"Gertrude\\\",\\n\\\"contemporary art\\\",\\n\\\"democratization\\\",\\n\\\"demystification\\\"\\n],\\n\\\"importance\\\": 9\\n\\n}\\n\\nNote that I have included the required fields \\\"category\\\", \\\"keywords\\\", and \\\"importance\\\" in the JSON output, and have also included the specified types for each\"}] |\"Art\" |[ |9 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"Gertrude\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"contemporary art\", | |\n", - "| | | |<> | | | \"democratization\", | |\n", - "| | | | Start-Art For too long the contemporary art world has been the exclusive redoubt of insiders, tastemakers, and a privileged elite. Gertrude has exploded this paradigm, and fashioned a conversational forum that democratizes and demystifies contemporary art. 
[/INST] | | | \"demystification\" | |\n", - "| | | | | | |] | |\n", - "|Tony Wagner's The Global Achievement Gap Is More Relevant Than Ever |EDUCATION |We have always had plenty of soul-killing, drill and kill instruction. In the past, however, it was seen as education malpractice. Now, it is imposed in the name of \"reform.\" |[INST] <> |[{\"generated_text\": \" Sure, here is the JSON output for the given input:\\n{\\n\\\"category\\\": \\\"Education\\\",\\n\\\"keywords\\\": [\\\"soul-killing\\\", \\\"drill and kill\\\", \\\"education malpractice\\\", \\\"reform\\\"],\\n\\\"importance\\\": 7\\n}\\nNote that I have included the required properties and their respective values as per the provided JSON schema.\"}] |\"Education\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"soul-killing\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"drill and kill\", | |\n", - "| | | |<> | | | \"education malpractice\", | |\n", - "| | | | Tony Wagner's The Global Achievement Gap Is More Relevant Than Ever We have always had plenty of soul-killing, drill and kill instruction. 
In the past, however, it was seen as education malpractice. Now, it is imposed in the name of \"reform.\" [/INST] | | | \"reform\" | |\n", - "| | | | | | |] | |\n", - "|Why Are Shoes So Damn Expensive? |COMEDY | |[INST] <> |[{\"generated_text\": \" Here is the JSON output for the given text:\\n{\\n\\\"category\\\": \\\"Society\\\",\\n\\\"keywords\\\": [\\\"shoes\\\", \\\"expensive\\\", \\\"prices\\\", \\\"consumerism\\\"],\\n\\\"importance\\\": 7\\n}\\n\\nNote: The output is in compliance with the provided JSON schema, and includes the required properties and values.\"}] |\"Society\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"shoes\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"expensive\", | |\n", - "| | | |<> | | | \"prices\", | |\n", - "| | | | Why Are Shoes So Damn Expensive? [/INST] | | | \"consumerism\" | |\n", - "| | | | | | |] | |\n", - "|Create an Online Gift Registry for Your Baby and Get the Items You Need and Want |PARENTS |Online gift registries are making it easy and convenient for expectant parents to get everything on their wish lists. 
|[INST] <> |[{\"generated_text\": \" Sure! Here is the JSON output for the given input:\\n{\\n\\\"category\\\": \\\"Parenting\\\",\\n\\\"keywords\\\": [\\n\\\"gift registry\\\",\\n\\\"baby gifts\\\",\\n\\\"online registry\\\",\\n\\\"registry ideas\\\",\\n\\\"baby shower gifts\\\"\\n],\\n\\\"importance\\\": 7\\n\\n}\\n\\nNote that I have included the \\\"category\\\" and \\\"keywords\\\" properties as requested, and have also included the \\\"\"}] |\"Parenting\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"gift registry\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"baby gifts\", | |\n", - "| | | |<> | | | \"online registry\", | |\n", - "| | | | Create an Online Gift Registry for Your Baby and Get the Items You Need and Want Online gift registries are making it easy and convenient for expectant parents to get everything on their wish lists. [/INST] | | | \"registry ideas\", | |\n", - "| | | | | | | \"baby shower gifts\" | |\n", - "| | | | | | |] | |\n", - "|I'm an American Citizen. 
If You Want to Remain a Cop, Don't Violate My Human Rights |POLITICS |This idea that cops get to say when and where constitutional rights apply is so very, deeply misguided that I am shocked anyone could type it out without coming to their senses mid-sentence. |[INST] <> |[{\"generated_text\": \" Sure, here is the JSON output for the given text:\\n{\\n\\\"category\\\": \\\"Society\\\",\\n\\\"keywords\\\": [\\n\\\"police\\\",\\n\\\"human rights\\\",\\n\\\"constitution\\\",\\n\\\"violation\\\"\\n],\\n\\\"importance\\\": 9\\n\\n}\\n\\nNote that I have included the required properties and values according to the JSON schema you provided.\"}] |\"Society\" |[ |9 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"police\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"human rights\", | |\n", - "| | | |<> | | | \"constitution\", | |\n", - "| | | | I'm an American Citizen. 
If You Want to Remain a Cop, Don't Violate My Human Rights This idea that cops get to say when and where constitutional rights apply is so very, deeply misguided that I am shocked anyone could type it out without coming to their senses mid-sentence. [/INST] | | | \"violation\" | |\n", - "| | | | | | |] | |\n", - "|Jose Antonio Vargas Among Undocumented Immigrants Making Urgent Plea To Obama |POLITICS | |[INST] <> |[{\"generated_text\": \" Sure, here is the JSON output for the given input:\\n{\\n\\\"category\\\": \\\"Politics\\\",\\n\\\"keywords\\\": [\\\"undocumented immigrants\\\", \\\"Barack Obama\\\", \\\"plea\\\"],\\n\\\"importance\\\": 7\\n\\n}\\n\\nPlease note that the output is in the format of a JSON object, with the properties, keywords, and importance values corresponding to the fields in the JSON schema you provided.\"}] |\"Politics\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. 
The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"undocumented immigrants\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"Barack Obama\", | |\n", - "| | | |<> | | | \"plea\" | |\n", - "| | | | Jose Antonio Vargas Among Undocumented Immigrants Making Urgent Plea To Obama [/INST] | | |] | |\n", - "|The 1 Minute Blog. Emotions Consuming You? |GOOD NEWS | |[INST] <> |[{\"generated_text\": \" Sure, I can provide a valid JSON response that conforms to the schema you provided. Here's an example:\\n{\\n\\\"category\\\": \\\"Local News\\\",\\n\\\"keywords\\\": [\\n\\\"crime\\\",\\n\\\"arrest\\\",\\n\\\"suspect\\\",\\n\\\"police\\\",\\n\\\"investigation\\\"\\n],\\n\\\"importance\\\": 7\\n\\n}\\n\\nPlease let me know if you have any other questions or if you'd like me to\"}] |\"Local News\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. 
The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"crime\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"arrest\", | |\n", - "| | | |<> | | | \"suspect\", | |\n", - "| | | | The 1 Minute Blog. Emotions Consuming You? [/INST] | | | \"police\", | |\n", - "| | | | | | | \"investigation\" | |\n", - "| | | | | | |] | |\n", - "|Beyond the Ice-Bucket: A Deeper Challenge |IMPACT |Life is full of challenges. Some are profoundly life-changing, and some are cold and wet. The ice-bucket challenge was not only great summer fun, it was also one of the most positive and productive viral campaigns in history. |[INST] <> |[{\"generated_text\": \" Here is the JSON output for the provided text:\\n{\\n\\\"category\\\": \\\"Society and Culture\\\",\\n\\\"keywords\\\": [\\n\\\"challenge\\\",\\n\\\"life\\\",\\n\\\"viral campaign\\\",\\n\\\"positive\\\",\\n\\\"productive\\\"\\n],\\n\\\"importance\\\": 7\\n\\n}\\n\\nNote: The output is formatted according to the provided JSON schema, with the required properties and values. The \\\"keywords\\\" array includes the five keywords mentioned\"}] |\"Society and Culture\" |[ |7 |\n", - "| | | |Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\":\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"A integer from 1 to 10 to show if the new is important. 
The higher the more important the news is.\"}},\"required\": [\"properties\",\"keywords\",\"importance\"]} | | | \"challenge\", | |\n", - "| | | |As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8} | | | \"life\", | |\n", - "| | | |<> | | | \"viral campaign\", | |\n", - "| | | | Beyond the Ice-Bucket: A Deeper Challenge Life is full of challenges. Some are profoundly life-changing, and some are cold and wet. The ice-bucket challenge was not only great summer fun, it was also one of the most positive and productive viral campaigns in history. [/INST] | | | \"positive\", | |\n", - "| | | | | | | \"productive\" | |\n", - "| | | | | | |] | |\n", - "------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "\n" - ] - } - ], - "source": [ - "json_capture_regexp = r'[{\\[]{1}([,:{}\\[\\]0-9.\\-+Eaeflnr-u \\n\\r\\t]|\".*?\")+[}\\]]{1}'\n", - "\n", - "output_json_col = F.parse_json(\n", - " F.regexp_extract(\n", - " F.replace(F.get(F.get(F.parse_json(F.col('\"outputs\"')), 0), F.lit(\"generated_text\")), r\"\\\"\", '\"'),\n", - " json_capture_regexp,\n", - " 0,\n", - " )\n", - ")\n", - "\n", - "output_df = res.with_columns(\n", - " [\"pred_category\", \"pred_keywords\", \"pred_importance\"],\n", - " [\n", - " F.get(output_json_col, F.lit(\"category\")),\n", - " F.get(output_json_col, F.lit(\"keywords\")),\n", - " F.get(output_json_col, F.lit(\"importance\")),\n", - " ],\n", - ")\n", - "\n", - "output_df.show(max_width=600)" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "9baccb60", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:ModelRegistry.delete_deployment() is in private preview since 1.0.1. Do not use it in production. \n" - ] - } - ], - "source": [ - "llama_model_ref.delete_deployment(deployment_name=DEPLOYMENT_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "968f8571", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:snowflake.snowpark:ModelRegistry.delete_model() is in private preview since 0.2.0. Do not use it in production. 
\n" - ] - } - ], - "source": [ - "llama_model_ref.delete_model()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - }, - "vscode": { - "interpreter": { - "hash": "fb0a62cbfaa59af7646af5a6672c5c3e72ec75fbadf6ff0336b6769523f221a5" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/snowflake/ml/registry/package_visibility_test.py b/snowflake/ml/registry/package_visibility_test.py index 8ae69986..bf1cbf46 100644 --- a/snowflake/ml/registry/package_visibility_test.py +++ b/snowflake/ml/registry/package_visibility_test.py @@ -1,9 +1,6 @@ -from types import ModuleType - from absl.testing import absltest from snowflake.ml import registry -from snowflake.ml.registry import model_registry class PackageVisibilityTest(absltest.TestCase): @@ -12,9 +9,6 @@ class PackageVisibilityTest(absltest.TestCase): def test_class_visible(self) -> None: self.assertIsInstance(registry.Registry, type) - def test_module_visible(self) -> None: - self.assertIsInstance(model_registry, ModuleType) - if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/registry/registry.py b/snowflake/ml/registry/registry.py index ae36e430..8253eddf 100644 --- a/snowflake/ml/registry/registry.py +++ b/snowflake/ml/registry/registry.py @@ -4,6 +4,7 @@ import pandas as pd +from snowflake import snowpark from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import ( @@ -12,6 +13,13 @@ model_signature, type_hints as model_types, ) +from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.monitoring._client import ( + model_monitor, + model_monitor_manager, + 
model_monitor_version, +) +from snowflake.ml.monitoring.entities import model_monitor_config from snowflake.ml.registry._manager import model_manager from snowflake.snowpark import session @@ -26,6 +34,7 @@ def __init__( *, database_name: Optional[str] = None, schema_name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, ) -> None: """Opens a registry within a pre-created Snowflake schema. @@ -35,6 +44,9 @@ def __init__( will be used. Defaults to None. schema_name: The name of the schema. If None, the current schema of the session will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None. + options: Optional set of configurations to modify registry. + Registry Options include: + - enable_monitoring: Feature flag to indicate whether registry can be used for monitoring. Raises: ValueError: When there is no specified or active database in the session. @@ -64,6 +76,21 @@ def __init__( session, database_name=self._database_name, schema_name=self._schema_name ) + self.enable_monitoring = options.get("enable_monitoring", False) if options else False + if self.enable_monitoring: + monitor_statement_params = telemetry.get_statement_params( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + + self._model_monitor_manager = model_monitor_manager.ModelMonitorManager( + session=session, + database_name=self._database_name, + schema_name=self._schema_name, + create_if_not_exists=True, # TODO: Support static setup method to configure schema for monitoring. 
+ statement_params=monitor_statement_params, + ) + @property def location(self) -> str: """Get the location (database.schema) of the registry.""" @@ -93,7 +120,7 @@ def log_model( Args: model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML, PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, - Sentence Transformers, Peft-finetuned LLM, or Custom Model. + Sentence Transformers, or Custom Model. model_name: Name to identify the model. version_name: Version identifier for the model. Combination of model_name and version_name must be unique. If not specified, a random name will be generated. @@ -119,8 +146,8 @@ def log_model( - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder. Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda Channel. Otherwise, defaults to False - - relax_version: Whether or not relax the version constraints of the dependencies. - It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True. + - relax_version: Whether or not relax the version constraints of the dependencies when running in the + Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True. - function_type: Set the method function type globally. To set method function types individually see function_type in model_options. - method_options: Per-method saving options including: @@ -182,6 +209,7 @@ def log_model( sample_input_data: Optional[model_types.SupportedDataType] = None, code_paths: Optional[List[str]] = None, ext_modules: Optional[List[ModuleType]] = None, + task: model_types.Task = model_types.Task.UNKNOWN, options: Optional[model_types.ModelSaveOption] = None, ) -> ModelVersion: """ @@ -191,7 +219,7 @@ def log_model( model: Supported model or ModelVersion object. 
- Supported model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML, PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers, - Peft-finetuned LLM, or Custom Model. + or Custom Model. - ModelVersion: Source ModelVersion object used to create the new ModelVersion object. model_name: Name to identify the model. version_name: Version identifier for the model. Combination of model_name and version_name must be unique. @@ -213,6 +241,9 @@ def log_model( ext_modules: List of external modules to pickle with the model object. Only supported when logging the following types of model: Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None. + task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION, + TABULAR_BINARY_CLASSIFICATION, TABULAR_MULTI_CLASSIFICATION, TABULAR_RANKING, or UNKNOWN. By default, + it is set to Task.UNKNOWN and may be overridden by inferring from the Model Object. options (Dict[str, Any], optional): Additional model saving options. Model Saving Options include: @@ -220,8 +251,8 @@ def log_model( - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder. Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda Channel. Otherwise, defaults to False - - relax_version: Whether or not relax the version constraints of the dependencies. - It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True. + - relax_version: Whether or not relax the version constraints of the dependencies when running in the + Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True. - function_type: Set the method function type globally. To set method function types individually see function_type in model_options. 
- method_options: Per-method saving options including: @@ -261,6 +292,7 @@ def log_model( sample_input_data=sample_input_data, code_paths=code_paths, ext_modules=ext_modules, + task=task, options=options, statement_params=statement_params, ) @@ -333,3 +365,130 @@ def delete_model(self, model_name: str) -> None: ) self._model_manager.delete_model(model_name=model_name, statement_params=statement_params) + + @telemetry.send_api_usage_telemetry( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION) + def add_monitor( + self, + name: str, + table_config: model_monitor_config.ModelMonitorTableConfig, + model_monitor_config: model_monitor_config.ModelMonitorConfig, + *, + add_dashboard_udtfs: bool = False, + ) -> model_monitor.ModelMonitor: + """Add a Model Monitor to the Registry + + Args: + name: Name of Model Monitor to create + table_config: Configuration options of table for ModelMonitor. + model_monitor_config: Configuration options of ModelMonitor. + add_dashboard_udtfs: Add UDTFs useful for creating a dashboard. + + Returns: + The newly added ModelMonitor object. + + Raises: + ValueError: If monitoring feature flag is not enabled. + """ + if not self.enable_monitoring: + raise ValueError( + "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" + ) + + # TODO: Change to fully qualified source table reference to allow table to live in different DB. + return self._model_monitor_manager.add_monitor( + name, table_config, model_monitor_config, add_dashboard_udtfs=add_dashboard_udtfs + ) + + @overload + def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor: + """Get a Model Monitor on a ModelVersion from the Registry + + Args: + model_version: ModelVersion for which to retrieve the ModelMonitor. 
+ """ + ... + + @overload + def get_monitor(self, name: str) -> model_monitor.ModelMonitor: + """Get a Model Monitor from the Registry + + Args: + name: Name of Model Monitor to retrieve. + """ + ... + + @telemetry.send_api_usage_telemetry( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION) + def get_monitor( + self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None + ) -> model_monitor.ModelMonitor: + """Get a Model Monitor from the Registry + + Args: + name: Name of Model Monitor to retrieve. + model_version: ModelVersion for which to retrieve the ModelMonitor. + + Returns: + The fetched ModelMonitor. + + Raises: + ValueError: If monitoring feature flag is not enabled. + ValueError: If neither name nor model_version specified. + """ + if not self.enable_monitoring: + raise ValueError( + "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" + ) + if name is not None: + return self._model_monitor_manager.get_monitor(name=name) + elif model_version is not None: + return self._model_monitor_manager.get_monitor_by_model_version(model_version=model_version) + else: + raise ValueError("Must provide either `name` or `model_version` to get ModelMonitor") + + @telemetry.send_api_usage_telemetry( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION) + def show_model_monitors(self) -> List[snowpark.Row]: + """Show all model monitors in the registry. + + Returns: + List of snowpark.Row containing metadata for each model monitor. + + Raises: + ValueError: If monitoring feature flag is not enabled. 
+ """ + if not self.enable_monitoring: + raise ValueError( + "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" + ) + return self._model_monitor_manager.show_model_monitors() + + @telemetry.send_api_usage_telemetry( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION) + def delete_monitor(self, name: str) -> None: + """Delete a Model Monitor from the Registry + + Args: + name: Name of the Model Monitor to delete. + + Raises: + ValueError: If monitoring feature flag is not enabled. + """ + if not self.enable_monitoring: + raise ValueError( + "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" + ) + self._model_monitor_manager.delete_monitor(name) diff --git a/snowflake/ml/registry/registry_test.py b/snowflake/ml/registry/registry_test.py index f562e1d8..a9dc4d7c 100644 --- a/snowflake/ml/registry/registry_test.py +++ b/snowflake/ml/registry/registry_test.py @@ -1,11 +1,17 @@ from typing import cast from unittest import mock +from unittest.mock import patch from absl.testing import absltest +from snowflake.ml.model import model_signature, type_hints +from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.ml.monitoring._client import model_monitor +from snowflake.ml.monitoring.entities import model_monitor_config from snowflake.ml.registry import registry -from snowflake.ml.test_utils import mock_session -from snowflake.snowpark import Session +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Row, Session, types class RegistryNameTest(absltest.TestCase): @@ -152,6 +158,7 @@ def test_log_model(self) -> None: ext_modules=m_ext_modules, 
options=m_options, statement_params=mock.ANY, + task=type_hints.Task.UNKNOWN, ) def test_log_model_from_model_version(self) -> None: @@ -177,6 +184,7 @@ def test_log_model_from_model_version(self) -> None: ext_modules=None, options=None, statement_params=mock.ANY, + task=type_hints.Task.UNKNOWN, ) def test_delete_model(self) -> None: @@ -190,5 +198,184 @@ def test_delete_model(self) -> None: ) +class MonitorRegistryTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.test_monitor_name = "TEST" + self.test_source_table_name = "MODEL_OUTPUTS" + self.test_db_name = "SNOWML_OBSERVABILITY" + self.test_schema_name = "METADATA" + self.test_model_name = "test_model" + self.test_model_name_sql = "TEST_MODEL" + self.test_model_version_name = "test_model_version" + self.test_model_version_name_sql = "TEST_MODEL_VERSION" + self.test_fq_model_name = f"{self.test_db_name}.{self.test_schema_name}.{self.test_model_name}" + self.test_warehouse = "TEST_WAREHOUSE" + self.test_timestamp_column = "TIMESTAMP" + self.test_prediction_column_name = "PREDICTION" + self.test_label_column_name = "LABEL" + self.test_id_column_name = "ID" + self.test_baseline_table_name_sql = "_SNOWML_OBS_BASELINE_TEST_MODEL_TEST_MODEL_VERSION" + + model_version = mock.MagicMock() + model_version.version_name = self.test_model_version_name + model_version.model_name = self.test_model_name + model_version.fully_qualified_model_name = self.test_fq_model_name + model_version.show_functions.return_value = [ + model_manifest_schema.ModelFunctionInfo( + name="PREDICT", + target_method="predict", + target_method_function_type="FUNCTION", + signature=model_signature.ModelSignature(inputs=[], outputs=[]), + is_partitioned=False, + ) + ] + model_version.get_model_task.return_value = type_hints.Task.TABULAR_REGRESSION + self.m_model_version: model_version_impl.ModelVersion = model_version + self.test_monitor_config = 
model_monitor_config.ModelMonitorConfig( + model_version=self.m_model_version, + model_function_name="predict", + background_compute_warehouse_name=self.test_warehouse, + ) + self.test_table_config = model_monitor_config.ModelMonitorTableConfig( + prediction_columns=[self.test_prediction_column_name], + label_columns=[self.test_label_column_name], + id_columns=[self.test_id_column_name], + timestamp_column=self.test_timestamp_column, + source_table=self.test_source_table_name, + ) + + mock_struct_fields = [] + for col in ["NUM_0"]: + mock_struct_fields.append(types.StructField(col, types.FloatType(), True)) + for col in ["CAT_0"]: + mock_struct_fields.append(types.StructField(col, types.StringType(), True)) + self.mock_schema = types.StructType._from_attributes(mock_struct_fields) + + mock_struct_fields = [] + for col in ["NUM_0"]: + mock_struct_fields.append(types.StructField(col, types.FloatType(), True)) + for col in ["CAT_0"]: + mock_struct_fields.append(types.StructField(col, types.StringType(), True)) + self.mock_schema = types.StructType._from_attributes(mock_struct_fields) + + def _add_expected_monitoring_init_calls(self, model_monitor_create_if_not_exists: bool = False) -> None: + self.m_session.add_mock_sql( + query="""CREATE TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.METADATA._SYSTEM_MONITORING_METADATA + (MONITOR_NAME VARCHAR, SOURCE_TABLE_NAME VARCHAR, FULLY_QUALIFIED_MODEL_NAME VARCHAR, + MODEL_VERSION_NAME VARCHAR, FUNCTION_NAME VARCHAR, TASK VARCHAR, IS_ENABLED BOOLEAN, + TIMESTAMP_COLUMN_NAME VARCHAR, PREDICTION_COLUMN_NAMES ARRAY, + LABEL_COLUMN_NAMES ARRAY, ID_COLUMN_NAMES ARRAY) + """, + result=mock_data_frame.MockDataFrame([Row(status="Table successfully created.")]), + ) + + if not model_monitor_create_if_not_exists: # this code path does validation on whether tables exist. 
+ self.m_session.add_mock_sql( + query="""SHOW TABLES LIKE '_SYSTEM_MONITORING_METADATA' IN SNOWML_OBSERVABILITY.METADATA""", + result=mock_data_frame.MockDataFrame([Row(name="_SYSTEM_MONITORING_METADATA")]), + ) + + def test_init(self) -> None: + self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) + session = cast(Session, self.m_session) + r1 = registry.Registry( + session, + database_name=self.test_db_name, + schema_name=self.test_schema_name, + options={"enable_monitoring": True}, + ) + self.assertEqual(r1.enable_monitoring, True) + + r2 = registry.Registry( + session, + database_name=self.test_db_name, + schema_name=self.test_schema_name, + ) + self.assertEqual(r2.enable_monitoring, False) + self.m_session.finalize() + + def test_add_monitor(self) -> None: + self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) + + session = cast(Session, self.m_session) + m_r = registry.Registry( + session, + database_name=self.test_db_name, + schema_name=self.test_schema_name, + options={"enable_monitoring": True}, + ) + m_monitor = mock.Mock() + m_monitor.name = self.test_monitor_name + + with mock.patch.object(m_r._model_monitor_manager, "add_monitor", return_value=m_monitor) as mock_add_monitor: + monitor: model_monitor.ModelMonitor = m_r.add_monitor( + self.test_monitor_name, + self.test_table_config, + self.test_monitor_config, + ) + mock_add_monitor.assert_called_once_with( + self.test_monitor_name, self.test_table_config, self.test_monitor_config, add_dashboard_udtfs=False + ) + self.assertEqual(monitor.name, self.test_monitor_name) + self.m_session.finalize() + + def test_get_monitor(self) -> None: + self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) + + session = cast(Session, self.m_session) + m_r = registry.Registry( + session, + database_name=self.test_db_name, + schema_name=self.test_schema_name, + options={"enable_monitoring": True}, + ) + m_model_monitor: 
model_monitor.ModelMonitor = mock.MagicMock() + with mock.patch.object( + m_r._model_monitor_manager, "get_monitor", return_value=m_model_monitor + ) as mock_get_monitor: + m_r.get_monitor(name=self.test_monitor_name) + mock_get_monitor.assert_called_once_with(name=self.test_monitor_name) + self.m_session.finalize() + + def test_get_monitor_by_model_version(self) -> None: + self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) + session = cast(Session, self.m_session) + m_r = registry.Registry( + session, + database_name=self.test_db_name, + schema_name=self.test_schema_name, + options={"enable_monitoring": True}, + ) + m_model_monitor: model_monitor.ModelMonitor = mock.MagicMock() + with mock.patch.object( + m_r._model_monitor_manager, "get_monitor_by_model_version", return_value=m_model_monitor + ) as mock_get_monitor: + m_r.get_monitor(model_version=self.m_model_version) + mock_get_monitor.assert_called_once_with(model_version=self.m_model_version) + self.m_session.finalize() + + @patch("snowflake.ml.monitoring._client.model_monitor_manager.ModelMonitorManager", autospec=True) + def test_show_model_monitors(self, m_model_monitor_manager_class: mock.MagicMock) -> None: + # Dont need to call self._add_expected_monitoring_init_calls since ModelMonitorManager.__init__ is + # auto mocked. 
+ m_model_monitor_manager = m_model_monitor_manager_class.return_value + sql_result = [ + Row( + col1="val1", + col2="val2", + ) + ] + m_model_monitor_manager.show_model_monitors.return_value = sql_result + session = cast(Session, self.m_session) + m_r = registry.Registry( + session, + database_name=self.test_db_name, + schema_name=self.test_schema_name, + options={"enable_monitoring": True}, + ) + self.assertEqual(m_r.show_model_monitors(), sql_result) + + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/test_utils/mock_session.py b/snowflake/ml/test_utils/mock_session.py index fe6bbcef..8baefc7d 100644 --- a/snowflake/ml/test_utils/mock_session.py +++ b/snowflake/ml/test_utils/mock_session.py @@ -1,6 +1,6 @@ from __future__ import annotations # for return self methods -from typing import Any, Type +from typing import Any, Optional, Type from unittest import TestCase from snowflake import snowpark @@ -110,3 +110,9 @@ def query_history(self) -> Any: check_kwargs=False, ) return mo.result + + def get_current_database(self) -> Optional[str]: + return None + + def get_current_schema(self) -> Optional[str]: + return None diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index d69b564f..e1060d5b 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. 
-VERSION = "1.6.2" +VERSION = "1.6.3" diff --git a/tests/integ/snowflake/cortex/embed_text_test.py b/tests/integ/snowflake/cortex/embed_text_test.py new file mode 100644 index 00000000..b4128137 --- /dev/null +++ b/tests/integ/snowflake/cortex/embed_text_test.py @@ -0,0 +1,40 @@ +from typing import List + +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.cortex import EmbedText768, EmbedText1024 +from snowflake.ml.utils import connection_params +from snowflake.snowpark import Session, functions + +_TEXT = "Text to embed" + + +class EmbedTextTest(absltest.TestCase): + def setUp(self) -> None: + self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() + + def tearDown(self) -> None: + self._session.close() + + def text_embed_text_768(self) -> None: + df_in = self._session.create_dataframe([snowpark.Row(model="e5-base-v2", text=_TEXT)]) + df_out = df_in.select(EmbedText768(functions.col("model"), functions.col("text"))) + res = df_out.collect()[0][0] + self.assertIsInstance(res, List) + self.assertEqual(len(res), 768) + # Check a subset. + self.assertEqual(res[:4], [-0.001, 0.002, -0.003, 0.004]) + + def text_embed_text_1024(self) -> None: + df_in = self._session.create_dataframe([snowpark.Row(model="multilingual-e5-large", text=_TEXT)]) + df_out = df_in.select(EmbedText1024(functions.col("model"), functions.col("text"))) + res = df_out.collect()[0][0] + self.assertIsInstance(res, List) + self.assertEqual(len(res), 1024) + # Check a subset. 
+ self.assertEqual(res[:4], [-0.001, 0.002, -0.003, 0.004]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py index 33dd0d60..b3b9909b 100644 --- a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py +++ b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py @@ -144,6 +144,13 @@ def test_fit_and_compare_results(self) -> None: np.testing.assert_allclose(results.flatten(), sk_results.flatten(), rtol=1.0e-1, atol=1.0e-2) + @pytest.mark.skipif( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + reason=( + "Skipping this test on Container Runtimes. " + "See: https://snowflakecomputing.atlassian.net/browse/SNOW-1648870" + ), + ) def test_fit_predict_proba_and_compare_results(self) -> None: pd_data = self._test_data pd_data["ROW_INDEX"] = pd_data.reset_index().index diff --git a/tests/integ/snowflake/ml/extra_tests/quoted_identifier_test.py b/tests/integ/snowflake/ml/extra_tests/quoted_identifier_test.py index 42dfa41e..cdef3ba2 100644 --- a/tests/integ/snowflake/ml/extra_tests/quoted_identifier_test.py +++ b/tests/integ/snowflake/ml/extra_tests/quoted_identifier_test.py @@ -1,4 +1,5 @@ import os +from unittest import skipIf import numpy as np from absl.testing import absltest, parameterized @@ -29,6 +30,10 @@ def setUp(self): def tearDown(self): self._session.close() + @skipIf( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + "Skipping this test on Container Runtimes. 
See: https://snowflakecomputing.atlassian.net/browse/SNOW-1633651", + ) @parameterized.parameters(False, True) def test_sp_quoted_identifier_modeling(self, test_within_sproc) -> None: if test_within_sproc: diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_test.py index 0927ea69..710414b0 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_test.py @@ -2612,6 +2612,38 @@ def register(fs: FeatureStore, name: str, refresh_mode: Optional[str] = None) -> self.assertEqual("FULL", register(fs, "fv2", "FULL").refresh_mode) self.assertEqual("INCREMENTAL", register(fs, "fv3", "INCREMENTAL").refresh_mode) + def test_specified_initialize(self) -> None: + fs = self._create_feature_store() + sql = f"SELECT id, name, title FROM {self._mock_table}" + + e = Entity("foo", ["id"]) + fs.register_entity(e) + + def register( + fs: FeatureStore, name: str, refresh_freq: Optional[str] = None, initialize: str = "" + ) -> FeatureView: + fv = FeatureView( + name=name, + entities=[e], + feature_df=self._session.sql(sql), + refresh_freq=refresh_freq, + initialize=initialize, + ) + return fs.register_feature_view(feature_view=fv, version="v1") + + self.assertEqual("ON_CREATE", register(fs, "fv1", None, "ON_CREATE").initialize) + self.assertEqual("ON_CREATE", register(fs, "fv2", "1d", "ON_CREATE").initialize) + self.assertEqual("ON_SCHEDULE", register(fs, "fv3", "5d", "ON_SCHEDULE").initialize) + + with self.assertRaisesRegex(ValueError, "'initialize' only supports ON_CREATE or ON_SCHEDULE"): + FeatureView( + name="fv4", + entities=[e], + feature_df=self._session.sql(sql), + refresh_freq="1d", + initialize="RANDOM_INIT_VALUE", + ) + def test_feature_view_list_columns(self) -> None: fs = self._create_feature_store() diff --git a/tests/integ/snowflake/ml/image_builds/BUILD.bazel b/tests/integ/snowflake/ml/image_builds/BUILD.bazel deleted 
file mode 100644 index 8e3e15f9..00000000 --- a/tests/integ/snowflake/ml/image_builds/BUILD.bazel +++ /dev/null @@ -1,14 +0,0 @@ -load("//bazel:py_rules.bzl", "py_test") - -py_test( - name = "image_registry_client_integ_test", - timeout = "long", - srcs = ["image_registry_client_integ_test.py"], - deps = [ - "//snowflake/ml/_internal/container_services/image_registry:registry_client", - "//snowflake/ml/_internal/utils:identifier", - "//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/model/_deploy_client/utils:snowservice_client", - "//tests/integ/snowflake/ml/test_utils:spcs_integ_test_base", - ], -) diff --git a/tests/integ/snowflake/ml/image_builds/image_registry_client_integ_test.py b/tests/integ/snowflake/ml/image_builds/image_registry_client_integ_test.py deleted file mode 100644 index e502c894..00000000 --- a/tests/integ/snowflake/ml/image_builds/image_registry_client_integ_test.py +++ /dev/null @@ -1,66 +0,0 @@ -from absl.testing import absltest - -from snowflake.ml._internal.container_services.image_registry import ( - registry_client as image_registry_client, -) -from snowflake.ml._internal.utils import identifier, query_result_checker -from snowflake.ml.model._deploy_client.utils import snowservice_client -from tests.integ.snowflake.ml.test_utils import spcs_integ_test_base - - -class ImageRegistryClientIntegTest(spcs_integ_test_base.SpcsIntegTestBase): - def setUp(self) -> None: - super().setUp() - self._TEST_REPO = "TEST_REPO" - client = snowservice_client.SnowServiceClient(self._session) - client.create_image_repo( - identifier.get_schema_level_object_identifier(self._test_db, self._test_schema, self._TEST_REPO) - ) - self._session.sql("ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'json'").collect() - - def tearDown(self) -> None: - self._session.sql("ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'arrow'").collect() - super().tearDown() - - def _get_repo_url(self) -> str: - """Retrieve repo url. 
- - Returns: repo url, sample repo url format: org-account.registry.snowflakecomputing.com/db/schema/repo. - """ - sql = ( - f"SHOW IMAGE REPOSITORIES LIKE '{self._TEST_REPO}' " - f"IN SCHEMA {'.'.join([self._test_db, self._test_schema])}" - ) - result = ( - query_result_checker.SqlResultValidator( - session=self._session, - query=sql, - ) - .has_column("repository_url") - .has_dimensions(expected_rows=1) - .validate() - ) - return result[0]["repository_url"] - - def test_copy_from_docker_hub_to_spcs_registry_and_add_tag(self) -> None: - repo_url = self._get_repo_url() - dest_image = "/".join([repo_url, "kaniko-project/executor:v1.16.0-debug"]) - client = image_registry_client.ImageRegistryClient(self._session, full_dest_image_name=dest_image) - self.assertFalse(client.image_exists(dest_image)) - client.copy_image( - "gcr.io/kaniko-project/executor@sha256:b8c0977f88f24dbd7cbc2ffe5c5f824c410ccd0952a72cc066efc4b6dfbb52b6", - dest_image, - ) - self.assertTrue(client.image_exists(dest_image)) - - parts = dest_image.split(":") - assert len(parts) == 2 - new_tag = "snowml-test-tag" - full_image_with_new_tag = ":".join([parts[0], new_tag]) - self.assertFalse(client.image_exists(full_image_with_new_tag)) - client.add_tag_to_remote_image(dest_image, new_tag=new_tag) - self.assertTrue(client.image_exists(full_image_with_new_tag)) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/lineage/lineage_integ_test.py b/tests/integ/snowflake/ml/lineage/lineage_integ_test.py index 3a49987e..3bbbd491 100644 --- a/tests/integ/snowflake/ml/lineage/lineage_integ_test.py +++ b/tests/integ/snowflake/ml/lineage/lineage_integ_test.py @@ -3,6 +3,7 @@ from absl.testing import absltest from sklearn import datasets +from snowflake.ml import dataset from snowflake.ml.feature_store.entity import Entity from snowflake.ml.feature_store.feature_store import CreationMode, FeatureStore from snowflake.ml.feature_store.feature_view import FeatureView @@ -10,7 +11,6 @@ 
from snowflake.ml.model import ModelVersion from snowflake.ml.modeling.linear_model import LogisticRegression from snowflake.ml.registry import Registry -from snowflake.ml.utils import connection_params from snowflake.snowpark import Session from tests.integ.snowflake.ml.test_utils import common_test_base, db_manager @@ -19,9 +19,8 @@ class TestSnowflakeLineage(common_test_base.CommonTestBase): def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" super().setUp() - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - self._db_manager = db_manager.DBManager(self._session) - self._current_db = self._session.get_current_database().replace('"', "") + self._db_manager = db_manager.DBManager(self.session) + self._current_db = self.session.get_current_database().replace('"', "") self._test_schema = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( uuid.uuid4().hex.upper()[:5], "schema" ).upper() @@ -55,13 +54,13 @@ def test_lineage(self): model_version = "V" + uuid.uuid4().hex.upper()[:5] self._create_iris_table(self.session, f"{self._test_schema}.{table_name}") - df = self._session.table(table_name) + df = self.session.table(table_name) fs = FeatureStore( - self._session, + self.session, self._current_db, self._test_schema, - default_warehouse=self._session.get_current_warehouse(), + default_warehouse=self.session.get_current_warehouse(), creation_mode=CreationMode.CREATE_IF_NOT_EXIST, ) @@ -112,8 +111,6 @@ def test_lineage(self): ) assert isinstance(fv_downstream[0], LineageNode) - from snowflake.ml import dataset - fv_downstream = fv.lineage(domain_filter=["dataset"]) self._check_lineage( fv_downstream, f"{self._current_db}.{self._test_schema}.{dataset_name}", dataset_version, "dataset" diff --git a/tests/integ/snowflake/ml/model/BUILD.bazel b/tests/integ/snowflake/ml/model/BUILD.bazel index 3f49e60e..e69de29b 100644 --- a/tests/integ/snowflake/ml/model/BUILD.bazel +++ 
b/tests/integ/snowflake/ml/model/BUILD.bazel @@ -1,233 +0,0 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") - -py_library( - name = "warehouse_model_integ_test_utils", - testonly = True, - srcs = ["warehouse_model_integ_test_utils.py"], - deps = [ - "//snowflake/ml/model:_api", - "//snowflake/ml/model:deploy_platforms", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_signatures:snowpark_handler", - "//tests/integ/snowflake/ml/test_utils:dataframe_utils", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_test( - name = "warehouse_catboost_model_integ_test", - timeout = "long", - srcs = ["warehouse_catboost_model_integ_test.py"], - shard_count = 2, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/model:custom_model", - "//snowflake/ml/model:deploy_platforms", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:dataframe_utils", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_test( - name = "warehouse_custom_model_integ_test", - timeout = "long", - srcs = ["warehouse_custom_model_integ_test.py"], - shard_count = 6, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/model:custom_model", - "//snowflake/ml/model:deploy_platforms", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:dataframe_utils", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_test( - name = "warehouse_pytorch_model_integ_test", - timeout = "long", - srcs = ["warehouse_pytorch_model_integ_test.py"], - shard_count = 6, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_signatures:pytorch_handler", - 
"//snowflake/ml/model/_signatures:snowpark_handler", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:dataframe_utils", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:model_factory", - ], -) - -py_test( - name = "warehouse_tensorflow_model_integ_test", - timeout = "long", - srcs = ["warehouse_tensorflow_model_integ_test.py"], - shard_count = 6, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_signatures:snowpark_handler", - "//snowflake/ml/model/_signatures:tensorflow_handler", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:dataframe_utils", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:model_factory", - ], -) - -py_test( - name = "warehouse_sklearn_xgboost_model_integ_test", - timeout = "long", - srcs = ["warehouse_sklearn_xgboost_model_integ_test.py"], - shard_count = 6, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:dataframe_utils", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "warehouse_lightgbm_model_integ_test", - timeout = "long", - srcs = ["warehouse_lightgbm_model_integ_test.py"], - shard_count = 2, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:dataframe_utils", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "warehouse_snowml_model_integ_test", - timeout = "long", - srcs = ["warehouse_snowml_model_integ_test.py"], - shard_count = 4, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/modeling/lightgbm:lgbm_regressor", - 
"//snowflake/ml/modeling/linear_model:logistic_regression", - "//snowflake/ml/modeling/xgboost:xgb_regressor", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "model_badcase_integ_test", - timeout = "long", - srcs = ["model_badcase_integ_test.py"], - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/model:_api", - "//snowflake/ml/model:custom_model", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "warehouse_mlflow_model_integ_test", - timeout = "long", - srcs = ["warehouse_mlflow_model_integ_test.py"], - shard_count = 4, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/_internal:env", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_signatures:numpy_handler", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "deployment_to_snowservice_integ_test", - timeout = "long", - srcs = ["deployment_to_snowservice_integ_test.py"], - deps = [ - "//snowflake/ml/model:_api", - "//snowflake/ml/model:custom_model", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_deploy_client/snowservice:deploy", - "//snowflake/ml/model/_deploy_client/utils:constants", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_test( - name = "warehouse_huggingface_pipeline_model_integ_test", - timeout = "long", - srcs = ["warehouse_huggingface_pipeline_model_integ_test.py"], - shard_count = 8, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/_internal:env_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], 
-) - -py_test( - name = "warehouse_sentence_transformers_model_integ_test", - timeout = "long", - srcs = ["warehouse_sentence_transformers_model_integ_test.py"], - shard_count = 4, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/_internal:env_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "spcs_llm_model_integ_test", - timeout = "eternal", # 3600s, GPU image takes very long to build.. - srcs = ["spcs_llm_model_integ_test.py"], - compatible_with_snowpark = False, - deps = [ - ":warehouse_model_integ_test_utils", - "//snowflake/ml/_internal:env_utils", - "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/models:llm_model", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:spcs_integ_test_base", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_test( - name = "warehouse_model_compat_v1_test", - timeout = "long", - srcs = ["warehouse_model_compat_v1_test.py"], - shard_count = 8, - deps = [ - "//snowflake/ml/_internal:env", - "//snowflake/ml/model:_api", - "//snowflake/ml/model:deploy_platforms", - "//tests/integ/snowflake/ml/test_utils:common_test_base", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) diff --git a/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel index 779d6f24..9005d060 100644 --- a/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel @@ -41,15 +41,3 @@ py_test( "//tests/integ/snowflake/ml/test_utils:db_manager", ], ) - -py_test( - name = "model_deployment_test", - timeout = "long", - srcs = ["model_deployment_test.py"], - shard_count = 2, - deps = [ - "//snowflake/ml/registry", - "//snowflake/ml/utils:connection_params", - 
"//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) diff --git a/tests/integ/snowflake/ml/model/_client/model/model_deployment_test.py b/tests/integ/snowflake/ml/model/_client/model/model_deployment_test.py deleted file mode 100644 index 51ab822c..00000000 --- a/tests/integ/snowflake/ml/model/_client/model/model_deployment_test.py +++ /dev/null @@ -1,132 +0,0 @@ -import inspect -import time -import uuid - -import numpy as np -from absl.testing import absltest -from sklearn import datasets, linear_model, svm - -from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.registry import registry -from snowflake.ml.utils import connection_params -from snowflake.snowpark import Session -from tests.integ.snowflake.ml.test_utils import db_manager - - -class ModelDeploymentTest(absltest.TestCase): - """Test model container services deployment.""" - - _TEST_CPU_COMPUTE_POOL = "REGTEST_INFERENCE_CPU_POOL" - _SPCS_EAI = "SPCS_EGRESS_ACCESS_INTEGRATION" - - def setUp(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - login_options = connection_params.SnowflakeLoginOptions() - - self._run_id = uuid.uuid4().hex[:2] - self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() - self._test_schema = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self._run_id, "schema" - ).upper() - self._test_image_repo = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self._run_id, "image_repo" - ).upper() - - self._session = Session.builder.configs( - { - **login_options, - **{"database": self._test_db, "schema": self._test_schema}, - } - ).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.create_database(self._test_db) - self._db_manager.create_schema(self._test_schema) - self._db_manager.create_image_repo(self._test_image_repo) - self._db_manager.cleanup_databases(expire_hours=6) - self.registry = 
registry.Registry(self._session) - - def tearDown(self) -> None: - self._db_manager.drop_database(self._test_db) - self._session.close() - - @absltest.skip - def test_create_service(self) -> None: - iris_X, iris_y = datasets.load_iris(return_X_y=True) - # LogisticRegression is for classfication task, such as iris - regr = linear_model.LogisticRegression() - regr.fit(iris_X, iris_y) - - model_name = f"model_{inspect.stack()[1].function}" - version_name = f"ver_{self._run_id}" - mv = self.registry.log_model( - model=regr, - model_name=model_name, - version_name=version_name, - sample_input_data=iris_X, - ) - - service = f"service_{self._run_id}" - mv.create_service( - service_name=service, - image_build_compute_pool=self._TEST_CPU_COMPUTE_POOL, - service_compute_pool=self._TEST_CPU_COMPUTE_POOL, - image_repo=self._test_image_repo, - force_rebuild=True, - build_external_access_integration=self._SPCS_EAI, - ) - self.assertTrue(self._wait_for_service(service)) - - @absltest.skip - def test_inference(self) -> None: - iris_X, iris_y = datasets.load_iris(return_X_y=True) - svc = svm.LinearSVC() - svc.fit(iris_X, iris_y) - - model_name = f"model_{inspect.stack()[1].function}" - version_name = f"ver_{self._run_id}" - mv = self.registry.log_model( - model=svc, - model_name=model_name, - version_name=version_name, - sample_input_data=iris_X, - ) - - service = f"service_{self._run_id}" - mv.create_service( - service_name=service, - image_build_compute_pool=self._TEST_CPU_COMPUTE_POOL, - service_compute_pool=self._TEST_CPU_COMPUTE_POOL, - image_repo=self._test_image_repo, - force_rebuild=True, - build_external_access_integration=self._SPCS_EAI, - ) - self.assertTrue(self._wait_for_service(service)) - - res = mv.run(iris_X, function_name="predict", service_name=service) - np.testing.assert_allclose(res["output_feature_0"].values, svc.predict(iris_X)) - - def _wait_for_service(self, service: str) -> bool: - service_identifier = sql_identifier.SqlIdentifier(service).identifier() - 
- # wait for service creation - while True: - services = [serv["name"] for serv in self._session.sql("SHOW SERVICES").collect()] - if service_identifier not in services: - time.sleep(10) - else: - break - - # wait for service to run - while True: - status = self._session.sql(f"DESC SERVICE {service_identifier}").collect()[0]["status"] - if status == "RUNNING": - return True - elif status == "PENDING": - time.sleep(10) - else: - return False - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py index 10113d89..7c516976 100644 --- a/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py +++ b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py @@ -78,11 +78,11 @@ def test_metrics(self) -> None: def test_export(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: self._mv.export(tmpdir) - self.assertLen(list(glob.iglob(os.path.join(tmpdir, "**", "*"), recursive=True)), 14) + self.assertLen(list(glob.iglob(os.path.join(tmpdir, "**", "*"), recursive=True)), 16) with tempfile.TemporaryDirectory() as tmpdir: self._mv.export(tmpdir, export_mode=ExportMode.FULL) - self.assertLen(list(glob.iglob(os.path.join(tmpdir, "**", "*"), recursive=True)), 26) + self.assertLen(list(glob.iglob(os.path.join(tmpdir, "**", "*"), recursive=True)), 29) def test_load(self) -> None: loaded_model = self._mv.load() diff --git a/tests/integ/snowflake/ml/model/deployment_to_snowservice_integ_test.py b/tests/integ/snowflake/ml/model/deployment_to_snowservice_integ_test.py deleted file mode 100644 index 6f845e85..00000000 --- a/tests/integ/snowflake/ml/model/deployment_to_snowservice_integ_test.py +++ /dev/null @@ -1,121 +0,0 @@ -# TODO[shchen], SNOW-889081, re-enable once server-side image build is supported. 
- -# import uuid -# from unittest import SkipTest -# from typing import Tuple - -# import pandas as pd -# import sklearn.base -# import sklearn.datasets as datasets -from absl.testing import absltest - -# from sklearn import neighbors - -# from snowflake.ml.model import ( -# _api as model_api, -# _model_meta, -# custom_model, -# deploy_platforms, -# type_hints as model_types, -# ) -# from snowflake.ml.model._deploy_client.snowservice import deploy as snowservice_api -# from snowflake.ml.model._deploy_client.utils import constants -# from snowflake.ml.utils import connection_params -# from snowflake.snowpark import Session -# from tests.integ.snowflake.ml.test_utils import db_manager - -# _IRIS = datasets.load_iris(as_frame=True) -# _IRIS_X = _IRIS.data -# _IRIS_Y = _IRIS.target - - -# def _get_sklearn_model() -> "sklearn.base.BaseEstimator": -# knn_model = neighbors.KNeighborsClassifier() -# knn_model.fit(_IRIS_X, _IRIS_Y) -# return knn_model - - -# -# class DeploymentToSnowServiceIntegTest(absltest.TestCase): -# _RUN_ID = uuid.uuid4().hex[:2] -# # Upper is necessary for `db, schema and repo names for an image repo must be unquoted identifiers.` -# TEST_DB = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(_RUN_ID, "db").upper() -# TEST_SCHEMA = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(_RUN_ID, "schema").upper() -# TEST_STAGE = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(_RUN_ID, "stage").upper() -# TEST_IMAGE_REPO = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(_RUN_ID, "repo").upper() -# TEST_ROLE = "SYSADMIN" -# TEST_COMPUTE_POOL = "MODEL_DEPLOYMENT_INTEG_TEST_POOL_STANDARD_2" # PRE-CREATED -# CONNECTION_NAME = "snowservice" # PRE-CREATED AND STORED IN KEY VAULT - -# @classmethod -# def setUpClass(cls) -> None: -# try: -# login_options = connection_params.SnowflakeLoginOptions(connection_name=cls.CONNECTION_NAME) -# except KeyError: -# raise SkipTest("SnowService connection parameters not 
present: skipping SnowServicesIntegTest.") - -# cls._session = Session.builder.configs( -# { -# **login_options, -# **{"database": cls.TEST_DB, "schema": cls.TEST_SCHEMA}, -# } -# ).create() -# cls._db_manager = db_manager.DBManager(cls._session) -# cls._db_manager.set_role(cls.TEST_ROLE) -# cls._db_manager.create_stage(cls.TEST_STAGE, cls.TEST_SCHEMA, cls.TEST_DB, sse_encrypted=True) -# cls._db_manager.create_image_repo(cls.TEST_IMAGE_REPO) -# cls._db_manager.cleanup_databases(expire_hours=6) - -# @classmethod -# def tearDownClass(cls) -> None: -# cls._db_manager.drop_image_repo(cls.TEST_IMAGE_REPO) -# # Dropping the db/schema will implicitly terminate the service function and snowservice as well. -# cls._db_manager.drop_database(cls.TEST_DB) -# cls._session.close() - -# def setUp(self) -> None: -# # Set up a unique id for each artifact, in addition to the class-level prefix. This is particularly useful -# # when differentiating artifacts generated between different test cases, such as service function names. 
-# self.uid = uuid.uuid4().hex[:4] - -# def _save_model_to_stage( -# self, model: custom_model.CustomModel, sample_input_data: pd.DataFrame -# ) -> Tuple[str, _model_meta.ModelMetadata]: -# stage_path = f"@{self.TEST_STAGE}/{self.uid}" -# meta = model_api.save_model( # type: ignore[call-overload] -# name="model", -# session=self._session, -# stage_path=stage_path, -# model=model, -# sample_input_data=sample_input_data, -# options={"embed_local_ml_library": True}, -# ) -# return stage_path, meta - -# def test_deployment_workflow(self) -> None: -# stage_path, meta = self._save_model_to_stage(model=_get_sklearn_model(), sample_input_data=_IRIS_X) -# service_func_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( -# self._RUN_ID, f"func_{self.uid}" -# ) -# deployment_options: model_types.SnowparkContainerServiceDeployOptions = { -# "compute_pool": self.TEST_COMPUTE_POOL, -# # image_repo is optional for user, pass in full image repo for test purposes only -# "image_repo": self._db_manager.get_snowservice_image_repo( -# subdomain=constants.DEV_IMAGE_REGISTRY_SUBDOMAIN, repo=self.TEST_IMAGE_REPO -# ), -# } -# model_api.deploy( -# name=service_func_name, -# session=self._session, -# stage_path=stage_path, -# platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, -# target_method="predict", -# model_id=uuid.uuid4().hex, -# options={ -# **deployment_options, -# }, # type: ignore[call-overload] -# ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/model_badcase_integ_test.py b/tests/integ/snowflake/ml/model/model_badcase_integ_test.py deleted file mode 100644 index 26cbb9cb..00000000 --- a/tests/integ/snowflake/ml/model/model_badcase_integ_test.py +++ /dev/null @@ -1,182 +0,0 @@ -import posixpath -import uuid - -import numpy as np -import pandas as pd -from absl.testing import absltest - -from snowflake.ml._internal.exceptions import exceptions as snowml_exceptions -from snowflake.ml.model import 
( - _api as model_api, - custom_model, - deploy_platforms, - type_hints as model_types, -) -from snowflake.ml.utils import connection_params -from snowflake.snowpark import Session, exceptions as snowpark_exceptions -from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils - - -class DemoModel(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": input["c1"]}) - - -class TestModelBadCaseInteg(absltest.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - # To create different UDF names among different runs - self._db_manager = db_manager.DBManager(self._session) - - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_bad_case_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def test_bad_model_deploy(self) -> None: - lm = DemoModel(custom_model.ModelContext()) - 
arr = np.array([[1, 2, 3], [4, 2, 5]]) - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - tmp_stage = self._session.get_session_stage() - model_api.save_model( - name="custom_bad_model", - session=self._session, - stage_path=posixpath.join(tmp_stage, "custom_bad_model"), - model=lm, - sample_input_data=pd_df, - metadata={"author": "halu", "version": "1"}, - conda_dependencies=["invalidnumpy==1.22.4"], - options=model_types.CustomModelSaveOption({"embed_local_ml_library": True}), - ) - function_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "custom_bad_model") - with self.assertRaises(snowml_exceptions.SnowflakeMLException) as e: - _ = model_api.deploy( - session=self._session, - name=function_name, - stage_path=posixpath.join(tmp_stage, "custom_bad_model"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - target_method="predict", - options=model_types.WarehouseDeployOptions({"relax_version": False}), - ) - self.assertIsInstance(e.exception.original_exception, RuntimeError) - - def test_custom_demo_model(self) -> None: - tmp_stage = self._session.get_session_stage() - lm = DemoModel(custom_model.ModelContext()) - arr = np.random.randint(100, size=(10000, 3)) - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - - model_composer = model_api.save_model( - name="custom_demo_model", - session=self._session, - stage_path=posixpath.join(tmp_stage, "custom_demo_model"), - model=lm, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server( - self._session, "snowflake-snowpark-python!=1.12.0" - ) - ], - sample_input_data=pd_df, - metadata={"author": "halu", "version": "1"}, - ) - - self.assertIsNotNone(model_composer.packager.meta.env._snowpark_ml_version.local) - - function_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "custom_demo_model") - with self.assertRaises(snowml_exceptions.SnowflakeMLException) as e: - deploy_info = model_api.deploy( - 
session=self._session, - name=function_name, - stage_path=posixpath.join(tmp_stage, "custom_demo_model"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - target_method="predict", - options=model_types.WarehouseDeployOptions( - { - "permanent_udf_stage_location": f"{self.full_qual_stage}/", - # Test stage location validation - } - ), - ) - self.assertIsInstance(e.exception.original_exception, ValueError) - - deploy_info = model_api.deploy( - session=self._session, - name=function_name, - stage_path=posixpath.join(tmp_stage, "custom_demo_model"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - target_method="predict", - options=model_types.WarehouseDeployOptions( - { - "permanent_udf_stage_location": f"@{self.full_qual_stage}/", - } - ), - ) - assert deploy_info is not None - res = model_api.predict(session=self._session, deployment=deploy_info, X=pd_df) - - pd.testing.assert_frame_equal( - res, - pd.DataFrame(arr[:, 0], columns=["output"]), - ) - - with self.assertRaises(snowpark_exceptions.SnowparkSQLException): - deploy_info = model_api.deploy( - session=self._session, - name=function_name, - stage_path=posixpath.join(tmp_stage, "custom_demo_model"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - target_method="predict", - options=model_types.WarehouseDeployOptions( - { - "permanent_udf_stage_location": f"@{self.full_qual_stage}/", - } - ), - ) - - self._db_manager.drop_function(function_name=function_name, args=["OBJECT"]) - - deploy_info = model_api.deploy( - session=self._session, - name=function_name, - stage_path=posixpath.join(tmp_stage, "custom_demo_model"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - target_method="predict", - options=model_types.WarehouseDeployOptions( - { - "permanent_udf_stage_location": f"@{self.full_qual_stage}/", - "replace_udf": True, - } - ), - ) - - self._db_manager.drop_function(function_name=function_name, args=["OBJECT"]) - - -if __name__ == "__main__": - absltest.main() diff --git 
a/tests/integ/snowflake/ml/model/spcs_llm_model_integ_test.py b/tests/integ/snowflake/ml/model/spcs_llm_model_integ_test.py deleted file mode 100644 index d3fbb1ce..00000000 --- a/tests/integ/snowflake/ml/model/spcs_llm_model_integ_test.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import tempfile - -import pandas as pd -import pytest -from absl.testing import absltest - -from snowflake.ml.model import ( - _api as model_api, - deploy_platforms, - type_hints as model_types, -) -from snowflake.ml.model.models import llm -from tests.integ.snowflake.ml.test_utils import ( - db_manager, - spcs_integ_test_base, - test_env_utils, -) - - -@pytest.mark.conda_incompatible -class TestSPCSLLMModelInteg(spcs_integ_test_base.SpcsIntegTestBase): - def setUp(self) -> None: - super().setUp() - self.cache_dir = tempfile.TemporaryDirectory() - self._original_hf_home = os.getenv("HF_HOME", None) - os.environ["HF_HOME"] = self.cache_dir.name - - def tearDown(self) -> None: - super().tearDown() - if self._original_hf_home: - os.environ["HF_HOME"] = self._original_hf_home - else: - del os.environ["HF_HOME"] - self.cache_dir.cleanup() - - def test_text_generation_pipeline( - self, - ) -> None: - model = llm.LLM( - model_id_or_path="facebook/opt-350m", - ) - - x_df = pd.DataFrame( - [["Hello world"]], - ) - - stage_path = f"@{self._test_stage}/{self._run_id}" - deployment_stage_path = f"@{self._test_stage}/{self._run_id}" - model_api.save_model( # type: ignore[call-overload] - name="model", - session=self._session, - stage_path=stage_path, - model=model, - options={"embed_local_ml_library": True}, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_conda("snowflake-snowpark-python"), - ], - ) - svc_func_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self._run_id, - f"func_{self._run_id}", - ) - - deployment_options: model_types.SnowparkContainerServiceDeployOptions = { - "compute_pool": self._TEST_GPU_COMPUTE_POOL, - "num_gpus": 1, - 
"model_in_image": True, - "external_access_integrations": self._SPCS_EAIS, - } - - deploy_info = model_api.deploy( - name=svc_func_name, - session=self._session, - stage_path=stage_path, - deployment_stage_path=deployment_stage_path, - model_id=svc_func_name, - platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, - options={ - **deployment_options, # type: ignore[arg-type] - }, # type: ignore[call-overload] - ) - assert deploy_info is not None - res = model_api.predict(session=self._session, deployment=deploy_info, X=x_df) - self.assertIn("generated_text", res) - self.assertEqual(len(res["generated_text"]), 1) - self.assertNotEmpty(res["generated_text"][0]) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_catboost_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_catboost_model_integ_test.py deleted file mode 100644 index ac8f1db6..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_catboost_model_integ_test.py +++ /dev/null @@ -1,159 +0,0 @@ -import uuid -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import catboost -import inflection -import numpy as np -import pandas as pd -from absl.testing import absltest, parameterized -from sklearn import datasets, model_selection - -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import dataframe_utils, db_manager - - -class TestWarehouseLightGBMModelInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - 
self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_catboost_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_catboost_classifier( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] - cal_X_train, cal_X_test, 
cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - - classifier = catboost.CatBoostClassifier() - classifier.fit(cal_X_train, cal_y_train) - - self.base_test_case( - name="catboost_model", - model=classifier, - sample_input_data=cal_X_test, - test_input=cal_X_test, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) - ), - ), - "predict_proba": ( - {}, - lambda res: np.testing.assert_allclose( - res.values, - classifier.predict_proba(cal_X_test), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_catboost_classifier_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - - classifier = catboost.CatBoostClassifier() - classifier.fit(cal_X_train, cal_y_train) - - y_df_expected = pd.concat( - [ - cal_X_test.reset_index(drop=True), - pd.DataFrame(classifier.predict(cal_X_test), columns=["output_feature_0"]), - ], - axis=1, - ) - y_df_expected_proba = pd.concat( - [ - cal_X_test.reset_index(drop=True), - pd.DataFrame(classifier.predict_proba(cal_X_test), columns=["output_feature_0", "output_feature_1"]), - ], - axis=1, - ) - - cal_data_sp_df_train = self._session.create_dataframe(cal_X_train) - cal_data_sp_df_test = self._session.create_dataframe(cal_X_test) - self.base_test_case( - name="lightgbm_model_sp", - model=classifier, - sample_input_data=cal_data_sp_df_train, - test_input=cal_data_sp_df_test, - deploy_params={ - "predict": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - "predict_proba": ( - {}, - lambda 
res: dataframe_utils.check_sp_df_res(res, y_df_expected_proba, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_custom_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_custom_model_integ_test.py deleted file mode 100644 index d193049b..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_custom_model_integ_test.py +++ /dev/null @@ -1,512 +0,0 @@ -import asyncio -import os -import tempfile -import uuid -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import numpy as np -import pandas as pd -from absl.testing import absltest, parameterized - -from snowflake.ml.model import custom_model, type_hints as model_types -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import dataframe_utils, db_manager - - -class DemoModel(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": input["c1"]}) - - -class DemoModelSPQuote(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({'"output"': input['"c1"']}) - - -class DemoModelArray(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": input.values.tolist()}) - - -class AsyncComposeModel(custom_model.CustomModel): - def __init__(self, 
context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - async def predict(self, input: pd.DataFrame) -> pd.DataFrame: - res1 = await self.context.model_ref("m1").predict.async_run(input) - res_sum = res1["output"] + self.context.model_ref("m2").predict(input)["output"] - return pd.DataFrame({"output": res_sum / 2}) - - -class DemoModelWithArtifacts(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - with open(context.path("bias"), encoding="utf-8") as f: - v = int(f.read()) - self.bias = v - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": (input["c1"] + self.bias) > 12}) - - -class TestWarehouseCustomModelInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_custom_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - 
self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_async_model_composition( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - async def _test(self: "TestWarehouseCustomModelInteg") -> None: - arr = np.random.randint(100, size=(10000, 3)) - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - clf = DemoModel(custom_model.ModelContext()) - model_context = custom_model.ModelContext( - models={ - "m1": clf, - "m2": clf, - } - ) - acm = AsyncComposeModel(model_context) - self.base_test_case( - name="async_model_composition", - model=acm, - sample_input_data=pd_df, - test_input=pd_df, - deploy_params={ - "": ( - {}, - lambda res: pd.testing.assert_frame_equal( - res, - pd.DataFrame(arr[:, 0], columns=["output"], dtype=float), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - asyncio.get_event_loop().run_until_complete(_test(self)) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModel(custom_model.ModelContext()) - arr = [[1, 2, 3], [4, 2, 5]] - sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) - y_df_expected = pd.DataFrame([[1, 
2, 3, 1], [4, 2, 5, 4]], columns=["c1", "c2", "c3", "output"]) - self.base_test_case( - name="custom_demo_model_sp0", - model=lm, - sample_input_data=sp_df, - test_input=sp_df, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_sp_quote( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModelSPQuote(custom_model.ModelContext()) - arr = [[1, 2, 3], [4, 2, 5]] - sp_df = self._session.create_dataframe(arr, schema=['"""c1"""', '"""c2"""', '"""c3"""']) - pd_df = pd.DataFrame(arr, columns=['"c1"', '"c2"', '"c3"']) - self.base_test_case( - name="custom_demo_model_sp_quote", - model=lm, - sample_input_data=sp_df, - test_input=pd_df, - deploy_params={ - "": ( - {}, - lambda res: pd.testing.assert_frame_equal( - res, - pd.DataFrame([1, 4], columns=['"output"'], dtype=np.int8), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_sp_quote_norm_1( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModelSPQuote(custom_model.ModelContext()) - arr = [[1, 2, 3], [4, 2, 5]] - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - sp_df = self._session.create_dataframe(arr, schema=['"""c1"""', '"""c2"""', '"""c3"""']) - sp_df_1 = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) - y_df_expected = pd.concat([pd_df, pd_df[["c1"]].rename(columns={"c1": "output"})], axis=1) - self.base_test_case( - name="custom_demo_model_sp_quote", - model=lm, - sample_input_data=sp_df, - test_input=sp_df_1, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - 
@parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_sp_mix_1( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModel(custom_model.ModelContext()) - arr = [[1, 2, 3], [4, 2, 5]] - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) - y_df_expected = pd.concat([pd_df, pd_df[["c1"]].rename(columns={"c1": "output"})], axis=1) - self.base_test_case( - name="custom_demo_model_sp1", - model=lm, - sample_input_data=pd_df, - test_input=sp_df, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_sp_mix_1_norm( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModel(custom_model.ModelContext()) - arr = [[1, 2, 3], [4, 2, 5]] - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - sp_df = self._session.create_dataframe(arr, schema=["c1", "c2", "c3"]) - y_df_expected = pd.concat( - [pd_df.rename(columns=str.upper), pd_df[["c1"]].rename(columns={"c1": "OUTPUT"})], axis=1 - ) - self.base_test_case( - name="custom_demo_model_sp1", - model=lm, - sample_input_data=pd_df, - test_input=sp_df, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_sp_mix_2( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModel(custom_model.ModelContext()) - arr = [[1, 2, 3], [4, 2, 5]] - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) - self.base_test_case( - 
name="custom_demo_model_sp2", - model=lm, - sample_input_data=sp_df, - test_input=pd_df, - deploy_params={ - "": ( - {}, - lambda res: pd.testing.assert_frame_equal( - res, - pd.DataFrame([1, 4], columns=["output"], dtype=np.int8), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_sp_mix_2_norm( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModel(custom_model.ModelContext()) - arr = [[1, 2, 3], [4, 2, 5]] - sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) - sp_df_1 = self._session.create_dataframe(arr, schema=["c1", "c2", "c3"]) - pd_df = pd.DataFrame(arr, columns=["C1", "C2", "C3"]) - y_df_expected = pd.concat([pd_df, pd_df[["C1"]].rename(columns={"C1": "OUTPUT"})], axis=1) - self.base_test_case( - name="custom_demo_model_sp2", - model=lm, - sample_input_data=sp_df, - test_input=sp_df_1, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_array( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModelArray(custom_model.ModelContext()) - arr = np.array([[1, 2, 3], [4, 2, 5]]) - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - self.base_test_case( - name="custom_demo_model_array", - model=lm, - sample_input_data=pd_df, - test_input=pd_df, - deploy_params={ - "": ( - {}, - lambda res: pd.testing.assert_frame_equal( - res, - pd.DataFrame(data={"output": [[1, 2, 3], [4, 2, 5]]}), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_str( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = 
DemoModel(custom_model.ModelContext()) - pd_df = pd.DataFrame([["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]], columns=["c1", "c2", "c3"]) - self.base_test_case( - name="custom_demo_model_str", - model=lm, - sample_input_data=pd_df, - test_input=pd_df, - deploy_params={ - "": ( - {}, - lambda res: pd.testing.assert_frame_equal( - res, - pd.DataFrame(data={"output": ["Yogiri", "Artia"]}), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_array_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModelArray(custom_model.ModelContext()) - arr = np.array([[1, 2, 3], [4, 2, 5]]) - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - sp_df = self._session.create_dataframe(pd_df) - y_df_expected = pd.concat([pd_df, pd.DataFrame(data={"output": [[1, 2, 3], [4, 2, 5]]})], axis=1) - self.base_test_case( - name="custom_demo_model_array_sp", - model=lm, - sample_input_data=sp_df, - test_input=sp_df, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ) - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_str_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModel(custom_model.ModelContext()) - pd_df = pd.DataFrame([["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]], columns=["c1", "c2", "c3"]) - sp_df = self._session.create_dataframe(pd_df) - y_df_expected = pd.concat([pd_df, pd.DataFrame(data={"output": ["Yogiri", "Artia"]})], axis=1) - self.base_test_case( - name="custom_demo_model_str_sp", - model=lm, - sample_input_data=sp_df, - test_input=sp_df, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), - ) - }, - permanent_deploy=permanent_deploy, - ) - - 
@parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_demo_model_array_str( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - lm = DemoModelArray(custom_model.ModelContext()) - pd_df = pd.DataFrame([["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]], columns=["c1", "c2", "c3"]) - self.base_test_case( - name="custom_demo_model_array_str", - model=lm, - sample_input_data=pd_df, - test_input=pd_df, - deploy_params={ - "": ( - {}, - lambda res: pd.testing.assert_frame_equal( - res, - pd.DataFrame(data={"output": [["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]]}), - ), - ) - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_model_with_artifacts( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with open(os.path.join(tmpdir, "bias"), "w", encoding="utf-8") as f: - f.write("10") - lm = DemoModelWithArtifacts( - custom_model.ModelContext(models={}, artifacts={"bias": os.path.join(tmpdir, "bias")}) - ) - arr = np.array([[1, 2, 3], [4, 2, 5]]) - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - self.base_test_case( - name="custom_model_with_artifacts", - model=lm, - sample_input_data=pd_df, - test_input=pd_df, - deploy_params={ - "": ( - {}, - lambda res: pd.testing.assert_frame_equal( - res, - pd.DataFrame([False, True], columns=["output"]), - ), - ) - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_custom_model_bool_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - with open(os.path.join(tmpdir, "bias"), "w", encoding="utf-8") as f: - f.write("10") - lm = DemoModelWithArtifacts( - custom_model.ModelContext(models={}, artifacts={"bias": os.path.join(tmpdir, "bias")}) - ) - arr = 
np.array([[1, 2, 3], [4, 2, 5]]) - pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - sp_df = self._session.create_dataframe(pd_df) - y_df_expected = pd.concat([pd_df, pd.DataFrame([False, True], columns=["output"])], axis=1) - self.base_test_case( - name="custom_model_bool_sp", - model=lm, - sample_input_data=sp_df, - test_input=sp_df, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ) - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_huggingface_pipeline_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_huggingface_pipeline_model_integ_test.py deleted file mode 100644 index 2242bd59..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_huggingface_pipeline_model_integ_test.py +++ /dev/null @@ -1,675 +0,0 @@ -import json -import os -import tempfile -import uuid -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -from absl.testing import absltest, parameterized -from packaging import requirements, version - -from snowflake.ml._internal import env_utils -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import db_manager - - -class TestWarehouseHuggingFacehModelInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - 
self.cache_dir = tempfile.TemporaryDirectory() - self._original_cache_dir = os.getenv("TRANSFORMERS_CACHE", None) - os.environ["TRANSFORMERS_CACHE"] = self.cache_dir.name - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_huggingface_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, - schema_name=self._test_schema_name, - sse_encrypted=False, - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - if self._original_cache_dir: - os.environ["TRANSFORMERS_CACHE"] = self._original_cache_dir - self.cache_dir.cleanup() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[ - str, - Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]], - ], - permanent_deploy: Optional[bool] = False, - additional_dependencies: Optional[List[str]] = None, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=None, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - additional_dependencies=additional_dependencies, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_conversational_pipeline( - self, - permanent_deploy: 
Optional[bool] = False, - ) -> None: - # We have to import here due to cache location issue. - # Only by doing so can we make the cache dir setting effective. - import transformers - - if version.parse(transformers.__version__) >= version.parse("4.42.0"): - self.skipTest("This test is not compatible with transformers>=4.42.0") - - model = transformers.pipeline(task="conversational", model="ToddGoldfarb/Cadet-Tiny") - - x_df = pd.DataFrame( - [ - { - "user_inputs": [ - "Do you speak French?", - "Do you know how to say Snowflake in French?", - ], - "generated_responses": ["Yes I do."], - }, - ] - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["generated_responses"])) - - for row in res["generated_responses"]: - self.assertIsInstance(row, list) - for resp in row: - self.assertIsInstance(resp, str) - - self.base_test_case( - name="huggingface_conversational_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_fill_mask_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline( - task="fill-mask", - model="sshleifer/tiny-distilroberta-base", - top_k=1, - ) - - x_df = pd.DataFrame( - [ - ["LynYuu is the of the Grand Duchy of Yu."], - ] - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) - - for row in res["outputs"]: - self.assertIsInstance(row, str) - resp = json.loads(row) - self.assertIsInstance(resp, list) - self.assertIn("score", resp[0]) - self.assertIn("token", resp[0]) - self.assertIn("token_str", resp[0]) - self.assertIn("sequence", resp[0]) - - self.base_test_case( - name="huggingface_fill_mask_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - 
permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_ner_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline(task="ner", model="hf-internal-testing/tiny-bert-for-token-classification") - - x_df = pd.DataFrame( - [ - ["My name is Izumi and I live in Tokyo, Japan."], - ] - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) - - for row in res["outputs"]: - self.assertIsInstance(row, str) - resp = json.loads(row) - self.assertIsInstance(resp, list) - self.assertIn("entity", resp[0]) - self.assertIn("score", resp[0]) - self.assertIn("index", resp[0]) - self.assertIn("word", resp[0]) - self.assertIn("start", resp[0]) - self.assertIn("end", resp[0]) - - self.base_test_case( - name="huggingface_ner_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_question_answering_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline( - task="question-answering", - model="sshleifer/tiny-distilbert-base-cased-distilled-squad", - top_k=1, - ) - - x_df = pd.DataFrame( - [ - { - "question": "What did Doris want to do?", - "context": ( - "Doris is a cheerful mermaid from the ocean depths. She transformed into a bipedal creature " - 'and came to see everyone because she wanted to "learn more about the world of athletics."' - " She dislikes cuisines with seafood." 
- ), - } - ], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["score", "start", "end", "answer"])) - - self.assertEqual(res["score"].dtype.type, np.float64) - self.assertEqual(res["start"].dtype.type, np.int64) - self.assertEqual(res["end"].dtype.type, np.int64) - self.assertEqual(res["answer"].dtype.type, np.object_) - - self.base_test_case( - name="huggingface_question_answering_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_question_answering_pipeline_multiple_output( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline( - task="question-answering", - model="sshleifer/tiny-distilbert-base-cased-distilled-squad", - top_k=3, - ) - - x_df = pd.DataFrame( - [ - { - "question": "What did Doris want to do?", - "context": ( - "Doris is a cheerful mermaid from the ocean depths. She transformed into a bipedal creature " - 'and came to see everyone because she wanted to "learn more about the world of athletics."' - " She dislikes cuisines with seafood." 
- ), - } - ], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) - - for row in res["outputs"]: - self.assertIsInstance(row, str) - resp = json.loads(row) - self.assertIsInstance(resp, list) - self.assertIn("score", resp[0]) - self.assertIn("start", resp[0]) - self.assertIn("end", resp[0]) - self.assertIn("answer", resp[0]) - - self.base_test_case( - name="huggingface_question_answering_pipeline_multiple_output", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_summarization_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline(task="summarization", model="sshleifer/tiny-mbart") - - x_df = pd.DataFrame( - [ - [ - ( - "Neuro-sama is a chatbot styled after a female VTuber that hosts live streams on the Twitch " - 'channel "vedal987". Her speech and personality are generated by an artificial intelligence' - " (AI) system which utilizes a large language model, allowing her to communicate with " - "viewers in a live chat. She was created by a computer programmer and AI-developer named " - "Jack Vedal, who decided to build upon the concept of an AI VTuber by combining interactions " - "between AI game play and a computer-generated avatar. She debuted on Twitch on December 19, " - "2022 after four years of development." 
- ) - ] - ], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["summary_text"])) - - self.assertEqual(res["summary_text"].dtype.type, np.object_) - - self.base_test_case( - name="huggingface_summarization_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - additional_dependencies=[ - str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement("sentencepiece"))) - ], - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_table_question_answering_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline(task="table-question-answering", model="google/tapas-tiny-finetuned-wtq") - - x_df = pd.DataFrame( - [ - { - "query": "Which channel has the most subscribers?", - "table": json.dumps( - { - "Channel": [ - "A.I.Channel", - "Kaguya Luna", - "Mirai Akari", - "Siro", - ], - "Subscribers": [ - "3,020,000", - "872,000", - "694,000", - "660,000", - ], - "Videos": ["1,200", "113", "639", "1,300"], - "Created At": [ - "Jun 30 2016", - "Dec 4 2017", - "Feb 28 2014", - "Jun 23 2017", - ], - } - ), - } - ], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["answer", "coordinates", "cells", "aggregator"])) - - self.assertEqual(res["answer"].dtype.type, np.object_) - self.assertEqual(res["coordinates"].dtype.type, np.object_) - self.assertIsInstance(res["coordinates"][0], list) - self.assertEqual(res["cells"].dtype.type, np.object_) - self.assertIsInstance(res["cells"][0], list) - self.assertEqual(res["aggregator"].dtype.type, np.object_) - - self.base_test_case( - name="huggingface_table_question_answering_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - 
@parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_text_classification_pair_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline(task="text-classification", model="cross-encoder/ms-marco-MiniLM-L-12-v2") - - x_df = pd.DataFrame( - [{"text": "I like you.", "text_pair": "I love you, too."}], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["label", "score"])) - - self.assertEqual(res["label"].dtype.type, np.object_) - self.assertEqual(res["score"].dtype.type, np.float64) - - self.base_test_case( - name="huggingface_text_classification_pair_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_text_classification_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline( - task="text-classification", - model="hf-internal-testing/tiny-random-distilbert", - top_k=1, - ) - - x_df = pd.DataFrame( - [ - { - "text": "I am wondering if I should have udon or rice for lunch", - "text_pair": "", - } - ], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) - - for row in res["outputs"]: - self.assertIsInstance(row, str) - resp = json.loads(row) - self.assertIsInstance(resp, list) - self.assertIn("label", resp[0]) - self.assertIn("score", resp[0]) - - self.base_test_case( - name="huggingface_text_classification_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_text_generation_pipeline( - self, - permanent_deploy: 
Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline( - task="text-generation", - model="sshleifer/tiny-ctrl", - ) - - x_df = pd.DataFrame( - [['A descendant of the Lost City of Atlantis, who swam to Earth while saying, "']], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) - - for row in res["outputs"]: - self.assertIsInstance(row, str) - resp = json.loads(row) - self.assertIsInstance(resp, list) - self.assertIn("generated_text", resp[0]) - - self.base_test_case( - name="huggingface_text_generation_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_text2text_generation_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline( - task="text2text-generation", - model="patrickvonplaten/t5-tiny-random", - ) - - x_df = pd.DataFrame( - [['A descendant of the Lost City of Atlantis, who swam to Earth while saying, "']], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["generated_text"])) - self.assertEqual(res["generated_text"].dtype.type, np.object_) - - self.base_test_case( - name="huggingface_text2text_generation_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_translation_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline(task="translation_en_to_ja", model="patrickvonplaten/t5-tiny-random") - - x_df = pd.DataFrame( - [ - [ - ( - "Snowflake's Data Cloud is powered by an advanced data platform 
provided as a self-managed " - "service. Snowflake enables data storage, processing, and analytic solutions that are faster, " - "easier to use, and far more flexible than traditional offerings. The Snowflake data platform " - "is not built on any existing database technology or “big data” software platforms such as " - "Hadoop. Instead, Snowflake combines a completely new SQL query engine with an innovative " - "architecture natively designed for the cloud. To the user, Snowflake provides all of the " - "functionality of an enterprise analytic database, along with many additional special features " - "and unique capabilities." - ) - ] - ], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["translation_text"])) - self.assertEqual(res["translation_text"].dtype.type, np.object_) - - self.base_test_case( - name="huggingface_translation_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_zero_shot_classification_pipeline( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - import transformers - - model = transformers.pipeline( - task="zero-shot-classification", - model="sshleifer/tiny-distilbert-base-cased-distilled-squad", - ) - - x_df = pd.DataFrame( - [ - { - "sequences": "I have a problem with Snowflake that needs to be resolved asap!!", - "candidate_labels": ["urgent", "not urgent"], - }, - { - "sequences": "I have a problem with Snowflake that needs to be resolved asap!!", - "candidate_labels": ["English", "Japanese"], - }, - ], - ) - - def check_res(res: pd.DataFrame) -> None: - pd.testing.assert_index_equal(res.columns, pd.Index(["sequence", "labels", "scores"])) - self.assertEqual(res["sequence"].dtype.type, np.object_) - self.assertEqual( - res["sequence"][0], - "I have a problem with Snowflake that needs to be 
resolved asap!!", - ) - self.assertEqual( - res["sequence"][1], - "I have a problem with Snowflake that needs to be resolved asap!!", - ) - self.assertEqual(res["labels"].dtype.type, np.object_) - self.assertListEqual(sorted(res["labels"][0]), sorted(["urgent", "not urgent"])) - self.assertListEqual(sorted(res["labels"][1]), sorted(["English", "Japanese"])) - self.assertEqual(res["scores"].dtype.type, np.object_) - self.assertIsInstance(res["labels"][0], list) - self.assertIsInstance(res["labels"][1], list) - - self.base_test_case( - name="huggingface_zero_shot_classification_pipeline", - model=model, - test_input=x_df, - deploy_params={ - "": ( - {}, - check_res, - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_lightgbm_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_lightgbm_model_integ_test.py deleted file mode 100644 index 0986fce3..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_lightgbm_model_integ_test.py +++ /dev/null @@ -1,223 +0,0 @@ -import uuid -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import inflection -import lightgbm -import numpy as np -import pandas as pd -from absl.testing import absltest, parameterized -from sklearn import datasets, model_selection - -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import dataframe_utils, db_manager - - -class TestWarehouseLightGBMModelInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = 
db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_lightgbm_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_lightgbm_classifier( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in 
cal_X.columns] - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - - classifier = lightgbm.LGBMClassifier() - classifier.fit(cal_X_train, cal_y_train) - - self.base_test_case( - name="lightgbm_model", - model=classifier, - sample_input_data=cal_X_test, - test_input=cal_X_test, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) - ), - ), - "predict_proba": ( - {}, - lambda res: np.testing.assert_allclose( - res.values, - classifier.predict_proba(cal_X_test), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_lightgbm_classifier_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - - classifier = lightgbm.LGBMClassifier() - classifier.fit(cal_X_train, cal_y_train) - - y_df_expected = pd.concat( - [ - cal_X_test.reset_index(drop=True), - pd.DataFrame(classifier.predict(cal_X_test), columns=["output_feature_0"]), - ], - axis=1, - ) - y_df_expected_proba = pd.concat( - [ - cal_X_test.reset_index(drop=True), - pd.DataFrame(classifier.predict_proba(cal_X_test), columns=["output_feature_0", "output_feature_1"]), - ], - axis=1, - ) - - cal_data_sp_df_train = self._session.create_dataframe(cal_X_train) - cal_data_sp_df_test = self._session.create_dataframe(cal_X_test) - self.base_test_case( - name="lightgbm_model_sp", - model=classifier, - sample_input_data=cal_data_sp_df_train, - test_input=cal_data_sp_df_test, - deploy_params={ - "predict": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - 
"predict_proba": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected_proba, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_lightgbm_booster( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - - regressor = lightgbm.train({"objective": "regression"}, lightgbm.Dataset(cal_X_train, label=cal_y_train)) - y_pred = regressor.predict(cal_X_test) - - self.base_test_case( - name="lightgbm_booster", - model=regressor, - sample_input_data=cal_X_test, - test_input=cal_X_test, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose(res.values, np.expand_dims(y_pred, axis=1), rtol=1e-6), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_lightgbm_booster_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - - regressor = lightgbm.train({"objective": "regression"}, lightgbm.Dataset(cal_X_train, label=cal_y_train)) - y_df_expected = pd.concat( - [ - cal_X_test.reset_index(drop=True), - pd.DataFrame(regressor.predict(cal_X_test), columns=["output_feature_0"]), - ], - axis=1, - ) - - cal_data_sp_df_train = self._session.create_dataframe(cal_X_train) - cal_data_sp_df_test = self._session.create_dataframe(cal_X_test) - self.base_test_case( - 
name="lightgbm_booster_sp", - model=regressor, - sample_input_data=cal_data_sp_df_train, - test_input=cal_data_sp_df_test, - deploy_params={ - "predict": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_mlflow_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_mlflow_model_integ_test.py deleted file mode 100644 index d7c2b874..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_mlflow_model_integ_test.py +++ /dev/null @@ -1,185 +0,0 @@ -import uuid -from importlib import metadata as importlib_metadata -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import mlflow -import numpy as np -import pandas as pd -from absl.testing import absltest, parameterized -from sklearn import datasets, ensemble, model_selection - -from snowflake.ml._internal import env -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.model._signatures import numpy_handler -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import db_manager - - -class TestWarehouseMLFlowModelInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - 
self.run_id, "model_deployment_mlflow_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_mlflow_model_deploy_sklearn_df( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - db = datasets.load_diabetes(as_frame=True) - X_train, X_test, y_train, y_test = model_selection.train_test_split(db.data, db.target) - with mlflow.start_run() as run: - rf = ensemble.RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3) - rf.fit(X_train, y_train) - - # Use the model to make predictions on the test dataset. 
- predictions = rf.predict(X_test) - signature = mlflow.models.signature.infer_signature(X_test, predictions) - mlflow.sklearn.log_model( - rf, - "model", - signature=signature, - metadata={"author": "halu", "version": "1"}, - conda_env={ - "dependencies": [f"python=={env.PYTHON_VERSION}"] - + list( - map( - lambda pkg: f"{pkg}=={importlib_metadata.distribution(pkg).version}", - [ - "mlflow", - "cloudpickle", - "numpy", - "scikit-learn", - "scipy", - "typing-extensions", - ], - ) - ), - "name": "mlflow-env", - }, - ) - - run_id = run.info.run_id - - self.base_test_case( - name="mlflow_model_sklearn_df", - model=mlflow.pyfunc.load_model(f"runs:/{run_id}/model"), - sample_input_data=None, - test_input=X_test, - deploy_params={ - "": ( - {}, - lambda res: np.testing.assert_allclose(np.expand_dims(predictions, axis=1), res.to_numpy()), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_mlflow_model_deploy_sklearn( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - db = datasets.load_diabetes() - X_train, X_test, y_train, y_test = model_selection.train_test_split(db.data, db.target) - with mlflow.start_run() as run: - rf = ensemble.RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3) - rf.fit(X_train, y_train) - - # Use the model to make predictions on the test dataset. 
- predictions = rf.predict(X_test) - signature = mlflow.models.signature.infer_signature(X_test, predictions) - mlflow.sklearn.log_model( - rf, - "model", - signature=signature, - metadata={"author": "halu", "version": "1"}, - conda_env={ - "dependencies": [f"python=={env.PYTHON_VERSION}"] - + list( - map( - lambda pkg: f"{pkg}=={importlib_metadata.distribution(pkg).version}", - [ - "mlflow", - "cloudpickle", - "numpy", - "scikit-learn", - "scipy", - "typing-extensions", - ], - ) - ), - "name": "mlflow-env", - }, - ) - - run_id = run.info.run_id - - X_test_df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df([X_test]) - - self.base_test_case( - name="mlflow_model_sklearn", - model=mlflow.pyfunc.load_model(f"runs:/{run_id}/model"), - sample_input_data=None, - test_input=X_test_df, - deploy_params={ - "": ( - {}, - lambda res: np.testing.assert_allclose(np.expand_dims(predictions, axis=1), res.to_numpy()), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_model_compat_v1_test.py b/tests/integ/snowflake/ml/model/warehouse_model_compat_v1_test.py deleted file mode 100644 index 3263e5c9..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_model_compat_v1_test.py +++ /dev/null @@ -1,701 +0,0 @@ -import posixpath -import uuid -from typing import Callable, Tuple - -import numpy as np -import pandas as pd -from absl.testing import absltest -from packaging import version -from sklearn import datasets - -from snowflake.ml._internal import env -from snowflake.ml.model import _api as model_api, deploy_platforms -from snowflake.snowpark import session -from tests.integ.snowflake.ml.test_utils import common_test_base, db_manager - - -@absltest.skipIf( - version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), - "Skip compat test for Python higher than 3.11 since we previously does not support it.", -) -class 
TestWarehouseCustomModelCompat(common_test_base.CommonTestBase): - def setUp(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - super().setUp() - self.run_id = uuid.uuid4().hex - self.session_stage = self.session.get_session_stage() - self.model_stage_path = posixpath.join(self.session_stage, self.run_id) - self.model_stage_file_path = posixpath.join(self.session_stage, self.run_id, f"{self.run_id}.zip") - - def _log_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - import pandas as pd - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - custom_model, - ) - - class DemoModel(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": input["c1"]}) - - lm = DemoModel(custom_model.ModelContext()) - pd_df = pd.DataFrame([[1, 2, 3], [4, 2, 5]], columns=["c1", "c2", "c3"]) - - model_api.save_model( - name=run_id, - model=lm, - sample_input=pd_df, - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_model_factory, version_range=">=1.0.8,<=1.0.11" # type: ignore[misc, arg-type] - ) - def test_deploy_custom_model_compat_v1(self) -> None: - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - model_api.predict( - self.session, 
deployment=deploy_info, X=pd.DataFrame([[1, 2, 3], [4, 2, 5]], columns=["c1", "c2", "c3"]) - ) - - def _log_model_multiple_components_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - import os - import tempfile - - import pandas as pd - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - custom_model, - ) - - class DemoModel(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": input["c1"]}) - - class AsyncComposeModel(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - - @custom_model.inference_api - async def predict(self, input: pd.DataFrame) -> pd.DataFrame: - res1 = await self.context.model_ref("m1").predict.async_run(input) - res_sum = res1["output"] + self.context.model_ref("m2").predict(input)["output"] - return pd.DataFrame({"output": res_sum / 2}) - - class DemoModelWithArtifacts(custom_model.CustomModel): - def __init__(self, context: custom_model.ModelContext) -> None: - super().__init__(context) - with open(context.path("bias"), encoding="utf-8") as f: - v = int(f.read()) - self.bias = v - - @custom_model.inference_api - def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": (input["c1"] + self.bias) > 12}) - - with tempfile.TemporaryDirectory() as tmpdir: - with open(os.path.join(tmpdir, "bias"), "w", encoding="utf-8") as f: - f.write("10") - lm_1 = DemoModelWithArtifacts( - custom_model.ModelContext(models={}, artifacts={"bias": os.path.join(tmpdir, "bias")}) - ) - lm_2 = DemoModel(custom_model.ModelContext()) - model_context = custom_model.ModelContext( - models={ - "m1": lm_1, - "m2": 
lm_2, - } - ) - acm = AsyncComposeModel(model_context) - pd_df = pd.DataFrame([[1, 2, 3], [4, 2, 5]], columns=["c1", "c2", "c3"]) - - model_api.save_model( - name=run_id, - model=acm, - sample_input=pd_df, - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_model_multiple_components_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.8,<=1.0.11", - ) - def test_deploy_custom_model_multiple_components_compat_v1(self) -> None: - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - model_api.predict( - self.session, deployment=deploy_info, X=pd.DataFrame([[1, 2, 3], [4, 2, 5]], columns=["c1", "c2", "c3"]) - ) - - def _log_sklearn_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - from sklearn import datasets, linear_model - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True) - # LogisticRegression is for classfication task, such as iris - regr = linear_model.LogisticRegression() - regr.fit(iris_X, iris_y) - - model_api.save_model( - name=run_id, - model=regr, - sample_input=iris_X, - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - 
prepare_fn_factory=_log_sklearn_model_factory, version_range=">=1.0.6,<=1.0.11" # type: ignore[misc, arg-type] - ) - def test_deploy_sklearn_model_compat_v1(self) -> None: - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - iris_X, _ = datasets.load_iris(return_X_y=True, as_frame=True) - model_api.predict(self.session, deployment=deploy_info, X=iris_X) - - def _log_xgboost_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - import xgboost - from sklearn import datasets, model_selection - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) - regressor.fit(cal_X_train, cal_y_train) - - model_api.save_model( - name=run_id, - model=regressor, - sample_input=cal_X_test, - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_xgboost_model_factory, version_range=">=1.0.6,<=1.0.11" # type: ignore[misc, arg-type] - ) - def test_deploy_xgboost_model_compat_v1(self) -> None: - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - 
platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - model_api.predict(self.session, deployment=deploy_info, X=cal_X) - - def _log_xgboost_booster_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - import xgboost - from sklearn import datasets, model_selection - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") - regressor = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train)) - - model_api.save_model( - name=run_id, - model=regressor, - sample_input=cal_X_test, - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_xgboost_booster_model_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.6,<=1.0.11", - ) - def test_deploy_xgboost_booster_model_compat_v1(self) -> None: - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - 
model_api.predict(self.session, deployment=deploy_info, X=cal_X) - - def _log_snowml_sklearn_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - from sklearn import datasets - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - from snowflake.ml.modeling.linear_model import ( # type: ignore[attr-defined] - LogisticRegression, - ) - - iris_X = datasets.load_iris(as_frame=True).frame - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - - INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] - LABEL_COLUMNS = "TARGET" - OUTPUT_COLUMNS = "PREDICTED_TARGET" - regr = LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X - regr.fit(test_features) - - model_api.save_model( - name=run_id, - model=regr, - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_snowml_sklearn_model_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.8,<=1.0.11", - ) - def test_deploy_snowml_sklearn_model_compat_v1(self) -> None: - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - iris_X = datasets.load_iris(as_frame=True).frame - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - - model_api.predict(self.session, deployment=deploy_info, X=iris_X) - - def 
_log_snowml_xgboost_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - from sklearn import datasets - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - from snowflake.ml.modeling.xgboost import ( # type: ignore[attr-defined] - XGBRegressor, - ) - - iris_X = datasets.load_iris(as_frame=True).frame - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - - INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] - LABEL_COLUMNS = "TARGET" - OUTPUT_COLUMNS = "PREDICTED_TARGET" - regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X - regr.fit(test_features) - - model_api.save_model( - name=run_id, - model=regr, - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_snowml_xgboost_model_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.8,<=1.0.11", - ) - def test_deploy_snowml_xgboost_model_compat_v1(self) -> None: - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - iris_X = datasets.load_iris(as_frame=True).frame - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - - model_api.predict(self.session, deployment=deploy_info, X=iris_X) - - def _log_pytorch_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, 
str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - import numpy as np - import torch - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - class TorchModel(torch.nn.Module): - def __init__(self, n_input: int, n_hidden: int, n_out: int, dtype: torch.dtype = torch.float32) -> None: - super().__init__() - self.model = torch.nn.Sequential( - torch.nn.Linear(n_input, n_hidden, dtype=dtype), - torch.nn.ReLU(), - torch.nn.Linear(n_hidden, n_out, dtype=dtype), - torch.nn.Sigmoid(), - ) - - def forward(self, tensor: torch.Tensor) -> torch.Tensor: - return self.model(tensor) # type: ignore[no-any-return] - - n_input, n_hidden, n_out, batch_size, learning_rate = 10, 15, 1, 100, 0.01 - x = np.random.rand(batch_size, n_input) - dtype = torch.float32 - data_x = torch.from_numpy(x).to(dtype=dtype) - data_y = (torch.rand(size=(batch_size, 1)) < 0.5).to(dtype=dtype) - - model = TorchModel(n_input, n_hidden, n_out, dtype=dtype) - loss_function = torch.nn.MSELoss() - optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) - for _epoch in range(100): - pred_y = model.forward(data_x) - loss = loss_function(pred_y, data_y) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - model_api.save_model( - name=run_id, - model=model, - sample_input=[data_x], - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_pytorch_model_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.6,<=1.0.11", - additional_packages=["pytorch"], - ) - def test_deploy_pytorch_model_compat_v1(self) -> None: - import torch - - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "forward"), - 
platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="forward", - options={}, - ) - assert deploy_info - - n_input, batch_size = 10, 100 - x = np.random.rand(batch_size, n_input) - dtype = torch.float32 - data_x = torch.from_numpy(x).to(dtype=dtype) - - model_api.predict(self.session, deployment=deploy_info, X=[data_x]) - - def _log_torchscript_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - import numpy as np - import torch - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - class TorchModel(torch.nn.Module): - def __init__(self, n_input: int, n_hidden: int, n_out: int, dtype: torch.dtype = torch.float32) -> None: - super().__init__() - self.model = torch.nn.Sequential( - torch.nn.Linear(n_input, n_hidden, dtype=dtype), - torch.nn.ReLU(), - torch.nn.Linear(n_hidden, n_out, dtype=dtype), - torch.nn.Sigmoid(), - ) - - def forward(self, tensor: torch.Tensor) -> torch.Tensor: - return self.model(tensor) # type: ignore[no-any-return] - - n_input, n_hidden, n_out, batch_size, learning_rate = 10, 15, 1, 100, 0.01 - x = np.random.rand(batch_size, n_input) - dtype = torch.float32 - data_x = torch.from_numpy(x).to(dtype=dtype) - data_y = (torch.rand(size=(batch_size, 1)) < 0.5).to(dtype=dtype) - - model = TorchModel(n_input, n_hidden, n_out, dtype=dtype) - loss_function = torch.nn.MSELoss() - optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) - for _epoch in range(100): - pred_y = model.forward(data_x) - loss = loss_function(pred_y, data_y) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - model_script = torch.jit.script(model) # type:ignore[attr-defined] - - model_api.save_model( - name=run_id, - model=model_script, - sample_input=[data_x], - metadata={"author": "halu", "version": "1"}, - session=session, - 
model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_torchscript_model_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.6,<=1.0.11", - additional_packages=["pytorch"], - ) - def test_deploy_torchscript_model_compat_v1(self) -> None: - import torch - - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "forward"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="forward", - options={}, - ) - assert deploy_info - - n_input, batch_size = 10, 100 - x = np.random.rand(batch_size, n_input) - dtype = torch.float32 - data_x = torch.from_numpy(x).to(dtype=dtype) - - model_api.predict(self.session, deployment=deploy_info, X=[data_x]) - - def _log_tensorflow_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - from typing import Optional - - import tensorflow as tf - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - class SimpleModule(tf.Module): - def __init__(self, name: Optional[str] = None) -> None: - super().__init__(name=name) - self.a_variable = tf.Variable(5.0, name="train_me") - self.non_trainable_variable = tf.Variable(5.0, trainable=False, name="do_not_train_me") - - @tf.function(input_signature=[tf.TensorSpec(shape=(None, 1), dtype=tf.float32)]) # type: ignore[misc] - def __call__(self, tensor: tf.Tensor) -> tf.Tensor: - return self.a_variable * tensor + self.non_trainable_variable - - model = SimpleModule(name="simple") - data_x = tf.constant([[5.0], [10.0]]) - - model_api.save_model( - name=run_id, - model=model, - sample_input=[data_x], - metadata={"author": "halu", "version": "1"}, - 
session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_tensorflow_model_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.6,<=1.0.11", - additional_packages=["tensorflow"], - ) - def test_deploy_tensorflow_model_compat_v1(self) -> None: - import tensorflow as tf - - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "__call__"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="__call__", - options={}, - ) - assert deploy_info - - data_x = tf.constant([[5.0], [10.0]]) - - model_api.predict(self.session, deployment=deploy_info, X=[data_x]) - - def _log_keras_model_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - def log_model(session: session.Session, run_id: str, model_stage_file_path: str) -> None: - import numpy as np - import tensorflow as tf - - from snowflake.ml.model import ( # type: ignore[attr-defined] - _model as model_api, - ) - - class KerasModel(tf.keras.Model): - def __init__(self, n_hidden: int, n_out: int) -> None: - super().__init__() - self.fc_1 = tf.keras.layers.Dense(n_hidden, activation="relu") - self.fc_2 = tf.keras.layers.Dense(n_out, activation="sigmoid") - - def call(self, tensor: tf.Tensor) -> tf.Tensor: - input = tensor - x = self.fc_1(input) - x = self.fc_2(x) - return x - - dtype = tf.float32 - n_input, n_hidden, n_out, batch_size, learning_rate = 10, 15, 1, 100, 0.01 - x = np.random.rand(batch_size, n_input) - data_x = tf.convert_to_tensor(x, dtype=dtype) - raw_data_y = tf.random.uniform((batch_size, 1)) - raw_data_y = tf.where(raw_data_y > 0.5, tf.ones_like(raw_data_y), tf.zeros_like(raw_data_y)) - data_y = tf.cast(raw_data_y, dtype=dtype) - - model = KerasModel(n_hidden, n_out) - 
model.compile( - optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate), loss=tf.keras.losses.MeanSquaredError() - ) - model.fit(data_x, data_y, batch_size=batch_size, epochs=100) - - model_api.save_model( - name=run_id, - model=model, - sample_input=[data_x], - metadata={"author": "halu", "version": "1"}, - session=session, - model_stage_file_path=model_stage_file_path, - ) - - return log_model, (self.run_id, self.model_stage_file_path) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_log_keras_model_factory, # type: ignore[misc, arg-type] - version_range=">=1.0.6,<=1.0.11", - additional_packages=["tensorflow"], - ) - def test_deploy_keras_model_compat_v1(self) -> None: - import tensorflow as tf - - deploy_info = model_api.deploy( - self.session, - name=db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - stage_path=self.model_stage_path, - target_method="predict", - options={}, - ) - assert deploy_info - - dtype = tf.float32 - n_input, batch_size = 10, 100 - x = np.random.rand(batch_size, n_input) - data_x = tf.convert_to_tensor(x, dtype=dtype) - - model_api.predict(self.session, deployment=deploy_info, X=[data_x]) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py b/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py deleted file mode 100644 index eab17a27..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import posixpath -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import pandas as pd - -from snowflake.ml.model import ( - _api as model_api, - deploy_platforms, - type_hints as model_types, -) -from snowflake.snowpark import DataFrame as SnowparkDataFrame -from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils - - -def 
base_test_case( - db: db_manager.DBManager, - run_id: str, - full_qual_stage: str, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - additional_dependencies: Optional[List[str]] = None, -) -> None: - tmp_stage = db._session.get_session_stage() - conda_dependencies = [ - test_env_utils.get_latest_package_version_spec_in_server(db._session, "snowflake-snowpark-python!=1.12.0") - ] - if additional_dependencies: - conda_dependencies.extend(additional_dependencies) - - if permanent_deploy: - permanent_deploy_args = {"permanent_udf_stage_location": f"@{full_qual_stage}/"} - perm_model_name = "perm" - else: - permanent_deploy_args = {} - perm_model_name = "temp" - - actual_name = f"{name}_{perm_model_name}" - - model_api.save_model( - name=actual_name, - model=model, - sample_input_data=sample_input_data, - conda_dependencies=conda_dependencies, - metadata={"author": "halu", "version": "1"}, - session=db._session, - stage_path=posixpath.join(tmp_stage, f"{actual_name}_{run_id}"), - options={"relax_version": False}, - ) - - for target_method, (additional_deploy_options, check_func) in deploy_params.items(): - function_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - run_id, f"{actual_name}_{target_method}" - ) - # This is to test the case for omitting target_method when deploying. 
- if target_method == "": - target_method_arg = None - else: - target_method_arg = target_method - deploy_info = model_api.deploy( - name=function_name, - session=db._session, - stage_path=posixpath.join(tmp_stage, f"{actual_name}_{run_id}"), - platform=deploy_platforms.TargetPlatform.WAREHOUSE, - target_method=target_method_arg, - options={ - **permanent_deploy_args, # type: ignore[arg-type] - **additional_deploy_options, - }, # type: ignore[call-overload] - ) - - assert deploy_info is not None - res = model_api.predict(session=db._session, deployment=deploy_info, X=test_input) - - check_func(res) - - if permanent_deploy: - db.drop_function(function_name=function_name, args=["OBJECT"]) diff --git a/tests/integ/snowflake/ml/model/warehouse_pytorch_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_pytorch_model_integ_test.py deleted file mode 100644 index 41ed4cc1..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_pytorch_model_integ_test.py +++ /dev/null @@ -1,234 +0,0 @@ -import uuid -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import pandas as pd -import torch -from absl.testing import absltest, parameterized - -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.model._signatures import pytorch_handler, snowpark_handler -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import ( - dataframe_utils, - db_manager, - model_factory, -) - - -class TestWarehousePytorchModelINteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - 
self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_pytorch_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_pytorch_tensor_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = model_factory.ModelFactory.prepare_torch_model(torch.float32) - x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - y_pred = model.forward(data_x).detach() - - 
self.base_test_case( - name="pytorch_model_tensor_as_sample", - model=model, - sample_input_data=[data_x], - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: torch.testing.assert_close( - pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred, check_dtype=False - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_pytorch_df_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = model_factory.ModelFactory.prepare_torch_model(torch.float64) - x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - y_pred = model.forward(data_x).detach() - - self.base_test_case( - name="pytorch_model_df_as_sample", - model=model, - sample_input_data=x_df, - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: torch.testing.assert_close( - pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_pytorch_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = model_factory.ModelFactory.prepare_torch_model(torch.float64) - x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - x_df.columns = ["col_0"] - y_pred = model.forward(data_x) - x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, x_df) - y_pred_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([y_pred]) - y_pred_df.columns = ["output_feature_0"] - y_df_expected = pd.concat([x_df, y_pred_df], axis=1) - - self.base_test_case( - name="pytorch_model_sp", - model=model, - sample_input_data=x_df, - test_input=x_df_sp, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), - ), - }, 
- permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_torchscript_tensor_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model(torch.float32) - x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - model_script = torch.jit.script(model) # type:ignore[attr-defined] - y_pred = model_script.forward(data_x).detach() - - self.base_test_case( - name="torch_script_model_tensor_as_sample", - model=model_script, - sample_input_data=[data_x], - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: torch.testing.assert_close( - pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred, check_dtype=False - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_torchscript_df_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model(torch.float64) - x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - model_script = torch.jit.script(model) # type:ignore[attr-defined] - y_pred = model_script.forward(data_x).detach() - - self.base_test_case( - name="torch_script_model_df_as_sample", - model=model_script, - sample_input_data=x_df, - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: torch.testing.assert_close( - pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_torchscript_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = 
model_factory.ModelFactory.prepare_jittable_torch_model(torch.float64) - x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - x_df.columns = ["col_0"] - model_script = torch.jit.script(model) # type:ignore[attr-defined] - y_pred = model_script.forward(data_x) - x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, x_df) - y_pred_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([y_pred]) - y_pred_df.columns = ["output_feature_0"] - y_df_expected = pd.concat([x_df, y_pred_df], axis=1) - - self.base_test_case( - name="torch_script_model_sp", - model=model_script, - sample_input_data=x_df, - test_input=x_df_sp, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_sentence_transformers_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_sentence_transformers_model_integ_test.py deleted file mode 100644 index d0564b61..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_sentence_transformers_model_integ_test.py +++ /dev/null @@ -1,178 +0,0 @@ -import os -import random -import tempfile -import uuid -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import pandas as pd -from absl.testing import absltest, parameterized - -from snowflake.ml.model import model_signature, type_hints as model_types -from snowflake.ml.model._packager.model_handlers.sentence_transformers import ( - _sentence_transformer_encode, -) -from snowflake.ml.model._signatures import ( - snowpark_handler, - utils as model_signature_utils, -) -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import 
dataframe_utils, db_manager - -MODEL_NAMES = ["intfloat/e5-base-v2"] # cant load models in parallel -SENTENCE_TRANSFORMERS_CACHE_DIR = "SENTENCE_TRANSFORMERS_HOME" - - -class TestWarehouseSentenceTransformerInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - self.cache_dir = tempfile.TemporaryDirectory() - self._original_cache_dir = os.getenv(SENTENCE_TRANSFORMERS_CACHE_DIR, None) - os.environ[SENTENCE_TRANSFORMERS_CACHE_DIR] = self.cache_dir.name - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_sentence_transformers_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, - schema_name=self._test_schema_name, - sse_encrypted=False, - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - if self._original_cache_dir: - os.environ[SENTENCE_TRANSFORMERS_CACHE_DIR] = self._original_cache_dir - self.cache_dir.cleanup() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedModelType, - test_input: 
model_types.SupportedDataType, - deploy_params: Dict[ - str, - Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]], - ], - permanent_deploy: Optional[bool] = False, - additional_dependencies: Optional[List[str]] = None, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - additional_dependencies=additional_dependencies, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_sentence_transformers( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - # We have to import here due to cache location issue. - # Only by doing so can we make the cache dir setting effective. - import sentence_transformers - - # Sample Data - sentences = pd.DataFrame( - { - "SENTENCES": [ - "Why don’t scientists trust atoms? Because they make up everything.", - "I told my wife she should embrace her mistakes. She gave me a hug.", - "Im reading a book on anti-gravity. Its impossible to put down!", - "Did you hear about the mathematician who’s afraid of negative numbers?", - "Parallel lines have so much in common. 
It’s a shame they’ll never meet.", - ] - } - ) - model = sentence_transformers.SentenceTransformer(random.choice(MODEL_NAMES)) - embeddings = _sentence_transformer_encode(model, sentences) - sig = {"encode": model_signature.infer_signature(sentences, embeddings)} - embeddings = model_signature_utils.rename_pandas_df(embeddings, sig["encode"].outputs) - - self.base_test_case( - name="sentence_transformers_model", - model=model, - sample_input_data=sentences, - test_input=sentences, - deploy_params={ - "encode": ( - {}, - lambda res: res.equals(embeddings), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_sentence_transformers_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - # We have to import here due to cache location issue. - # Only by doing so can we make the cache dir setting effective. - import sentence_transformers - - # Sample Data - sentences = pd.DataFrame( - { - "SENTENCES": [ - "Why don’t scientists trust atoms? Because they make up everything.", - "I told my wife she should embrace her mistakes. She gave me a hug.", - "Im reading a book on anti-gravity. Its impossible to put down!", - "Did you hear about the mathematician who’s afraid of negative numbers?", - "Parallel lines have so much in common. 
It’s a shame they’ll never meet.", - ] - } - ) - sentences_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, sentences) - model = sentence_transformers.SentenceTransformer(random.choice(MODEL_NAMES)) - embeddings = _sentence_transformer_encode(model, sentences) - sig = {"encode": model_signature.infer_signature(sentences, embeddings)} - embeddings = model_signature_utils.rename_pandas_df(embeddings, sig["encode"].outputs) - y_df_expected = pd.concat([sentences_sp.to_pandas(), embeddings], axis=1) - - self.base_test_case( - name="sentence_transformers_model", - model=model, - sample_input_data=sentences, - test_input=sentences_sp, - deploy_params={ - "encode": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, atol=1e-6), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_sklearn_xgboost_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_sklearn_xgboost_model_integ_test.py deleted file mode 100644 index 1190c762..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_sklearn_xgboost_model_integ_test.py +++ /dev/null @@ -1,283 +0,0 @@ -import uuid -from typing import Any, Callable, Dict, Optional, Tuple, Union, cast - -import inflection -import numpy as np -import pandas as pd -import xgboost -from absl.testing import absltest, parameterized -from sklearn import datasets, ensemble, linear_model, model_selection, multioutput - -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import dataframe_utils, db_manager - - -class TestWarehouseSKLearnXGBoostModelInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and 
Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_sklearn_xgboost_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_skl_model_deploy( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - iris_X, 
iris_y = datasets.load_iris(return_X_y=True) - # LogisticRegression is for classfication task, such as iris - regr = linear_model.LogisticRegression() - regr.fit(iris_X, iris_y) - self.base_test_case( - name="skl_model", - model=regr, - sample_input_data=iris_X, - test_input=iris_X, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose(res["output_feature_0"].values, regr.predict(iris_X)), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_skl_model_proba_deploy( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - iris_X, iris_y = datasets.load_iris(return_X_y=True) - model = ensemble.RandomForestClassifier(random_state=42) - model.fit(iris_X[:10], iris_y[:10]) - self.base_test_case( - name="skl_model_proba_deploy", - model=model, - sample_input_data=iris_X, - test_input=iris_X[:10], - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose(res["output_feature_0"].values, model.predict(iris_X[:10])), - ), - "predict_proba": ( - {}, - lambda res: np.testing.assert_allclose(res.values, model.predict_proba(iris_X[:10])), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_skl_multiple_output_model_proba_deploy( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - iris_X, iris_y = datasets.load_iris(return_X_y=True) - target2 = np.random.randint(0, 6, size=iris_y.shape) - dual_target = np.vstack([iris_y, target2]).T - model = multioutput.MultiOutputClassifier(ensemble.RandomForestClassifier(random_state=42)) - model.fit(iris_X[:10], dual_target[:10]) - self.base_test_case( - name="skl_multiple_output_model_proba", - model=model, - sample_input_data=iris_X, - test_input=iris_X[-10:], - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose(res.to_numpy(), model.predict(iris_X[-10:])), - ), - 
"predict_proba": ( - {}, - lambda res: np.testing.assert_allclose( - np.hstack([np.array(res[col].to_list()) for col in cast(pd.DataFrame, res)]), - np.hstack(model.predict_proba(iris_X[-10:])), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_xgb( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) - regressor.fit(cal_X_train, cal_y_train) - self.base_test_case( - name="xgb_model", - model=regressor, - sample_input_data=cal_X_test, - test_input=cal_X_test, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(regressor.predict(cal_X_test), axis=1) - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_xgb_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True).frame - cal_data.columns = [inflection.parameterize(c, "_") for c in cal_data] - cal_data_sp_df = self._session.create_dataframe(cal_data) - cal_data_sp_df_train, cal_data_sp_df_test = tuple(cal_data_sp_df.random_split([0.25, 0.75], seed=2568)) - regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) - cal_data_pd_df_train = cal_data_sp_df_train.to_pandas() - regressor.fit(cal_data_pd_df_train.drop(columns=["target"]), cal_data_pd_df_train["target"]) - cal_data_sp_df_test_X = cal_data_sp_df_test.drop('"target"') - - y_df_expected = pd.concat( - [ - cal_data_sp_df_test_X.to_pandas(), - 
pd.DataFrame(regressor.predict(cal_data_sp_df_test_X.to_pandas()), columns=["output_feature_0"]), - ], - axis=1, - ) - self.base_test_case( - name="xgb_model_sp", - model=regressor, - sample_input_data=cal_data_sp_df_train.drop('"target"'), - test_input=cal_data_sp_df_test_X, - deploy_params={ - "predict": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_xgb_booster( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_X = cal_data.data - cal_y = cal_data.target - cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] - cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) - params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") - regressor = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train)) - y_pred = regressor.predict(xgboost.DMatrix(data=cal_X_test)) - self.base_test_case( - name="xgb_booster", - model=regressor, - sample_input_data=cal_X_test, - test_input=cal_X_test, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose(res.values, np.expand_dims(y_pred, axis=1), rtol=1e-6), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_xgb_booster_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True).frame - cal_data.columns = [inflection.parameterize(c, "_") for c in cal_data] - cal_data_sp_df = self._session.create_dataframe(cal_data) - cal_data_sp_df_train, cal_data_sp_df_test = tuple(cal_data_sp_df.random_split([0.25, 0.75], seed=2568)) - cal_data_pd_df_train = cal_data_sp_df_train.to_pandas() - 
params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") - regressor = xgboost.train( - params, - xgboost.DMatrix(data=cal_data_pd_df_train.drop(columns=["target"]), label=cal_data_pd_df_train["target"]), - ) - cal_data_sp_df_test_X = cal_data_sp_df_test.drop('"target"') - y_df_expected = pd.concat( - [ - cal_data_sp_df_test_X.to_pandas(), - pd.DataFrame( - regressor.predict(xgboost.DMatrix(data=cal_data_sp_df_test_X.to_pandas())), - columns=["output_feature_0"], - ), - ], - axis=1, - ) - self.base_test_case( - name="xgb_booster_sp", - model=regressor, - sample_input_data=cal_data_sp_df_train.drop('"target"'), - test_input=cal_data_sp_df_test_X, - deploy_params={ - "predict": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_snowml_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_snowml_model_integ_test.py deleted file mode 100644 index 869bfb4f..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_snowml_model_integ_test.py +++ /dev/null @@ -1,167 +0,0 @@ -import uuid -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import numpy as np -import pandas as pd -from absl.testing import absltest, parameterized -from sklearn import datasets - -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.modeling.lightgbm import LGBMRegressor -from snowflake.ml.modeling.linear_model import LogisticRegression -from snowflake.ml.modeling.xgboost import XGBRegressor -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import db_manager - - -class TestWarehouseSnowMLModelInteg(parameterized.TestCase): - 
@classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_snowml_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def 
test_snowml_model_deploy_snowml_sklearn( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - iris_X = datasets.load_iris(as_frame=True).frame - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - - INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] - LABEL_COLUMNS = "TARGET" - OUTPUT_COLUMNS = "PREDICTED_TARGET" - regr = LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X - regr.fit(test_features) - - self.base_test_case( - name="snowml_model_sklearn", - model=regr, - sample_input_data=None, - test_input=test_features, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_snowml_model_deploy_xgboost( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - iris_X = datasets.load_iris(as_frame=True).frame - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - - INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] - LABEL_COLUMNS = "TARGET" - OUTPUT_COLUMNS = "PREDICTED_TARGET" - regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X[:10] - regr.fit(test_features) - - self.base_test_case( - name="snowml_model_xgb", - model=regr, - sample_input_data=None, - test_input=test_features, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_snowml_model_deploy_lightgbm( - 
self, - permanent_deploy: Optional[bool] = False, - ) -> None: - iris_X = datasets.load_iris(as_frame=True).frame - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - - INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] - LABEL_COLUMNS = "TARGET" - OUTPUT_COLUMNS = "PREDICTED_TARGET" - regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X[:10] - regr.fit(test_features) - - self.base_test_case( - name="snowml_model_lightgbm", - model=regr, - sample_input_data=None, - test_input=test_features, - deploy_params={ - "predict": ( - {}, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/model/warehouse_tensorflow_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_tensorflow_model_integ_test.py deleted file mode 100644 index 72b5ae59..00000000 --- a/tests/integ/snowflake/ml/model/warehouse_tensorflow_model_integ_test.py +++ /dev/null @@ -1,281 +0,0 @@ -import uuid -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import numpy as np -import pandas as pd -import pytest -import tensorflow as tf -from absl.testing import absltest, parameterized - -from snowflake.ml.model import type_hints as model_types -from snowflake.ml.model._signatures import ( - numpy_handler, - snowpark_handler, - tensorflow_handler, -) -from snowflake.ml.utils import connection_params -from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session -from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import ( - dataframe_utils, - db_manager, - model_factory, -) - - -def prepare_keras_model( - dtype: "tf.dtypes.DType" = tf.float32, -) -> 
Tuple["tf.keras.Model", "tf.Tensor", "tf.Tensor"]: - class KerasModel(tf.keras.Model): - def __init__(self, n_hidden: int, n_out: int) -> None: - super().__init__() - self.fc_1 = tf.keras.layers.Dense(n_hidden, activation="relu") - self.fc_2 = tf.keras.layers.Dense(n_out, activation="sigmoid") - - def call(self, tensor: "tf.Tensor") -> "tf.Tensor": - input = tensor - x = self.fc_1(input) - x = self.fc_2(x) - return x - - n_input, n_hidden, n_out, batch_size, learning_rate = 10, 15, 1, 100, 0.01 - x = np.random.rand(batch_size, n_input) - data_x = tf.convert_to_tensor(x, dtype=dtype) - raw_data_y = tf.random.uniform((batch_size, 1)) - raw_data_y = tf.where(raw_data_y > 0.5, tf.ones_like(raw_data_y), tf.zeros_like(raw_data_y)) - data_y = tf.cast(raw_data_y, dtype=dtype) - - model = KerasModel(n_hidden, n_out) - model.compile( - optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate), loss=tf.keras.losses.MeanSquaredError() - ) - model.fit(data_x, data_y, batch_size=batch_size, epochs=100) - return model, data_x, data_y - - -@pytest.mark.pip_incompatible -class TestWarehouseTensorflowModelInteg(parameterized.TestCase): - @classmethod - def setUpClass(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - - self._db_manager = db_manager.DBManager(self._session) - self._db_manager.cleanup_schemas() - self._db_manager.cleanup_stages() - self._db_manager.cleanup_user_functions() - - # To create different UDF names among different runs - self.run_id = uuid.uuid4().hex - self._test_schema_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "model_deployment_tensorflow_model_test_schema" - ) - self._db_manager.create_schema(self._test_schema_name) - self._db_manager.use_schema(self._test_schema_name) - - self.deploy_stage_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, 
"deployment_stage" - ) - self.full_qual_stage = self._db_manager.create_stage( - self.deploy_stage_name, schema_name=self._test_schema_name, sse_encrypted=False - ) - - @classmethod - def tearDownClass(self) -> None: - self._db_manager.drop_stage(self.deploy_stage_name, schema_name=self._test_schema_name) - self._db_manager.drop_schema(self._test_schema_name) - self._session.close() - - def base_test_case( - self, - name: str, - model: model_types.SupportedModelType, - sample_input_data: model_types.SupportedDataType, - test_input: model_types.SupportedDataType, - deploy_params: Dict[str, Tuple[Dict[str, Any], Callable[[Union[pd.DataFrame, SnowparkDataFrame]], Any]]], - permanent_deploy: Optional[bool] = False, - ) -> None: - warehouse_model_integ_test_utils.base_test_case( - self._db_manager, - run_id=self.run_id, - full_qual_stage=self.full_qual_stage, - name=name, - model=model, - sample_input_data=sample_input_data, - test_input=test_input, - deploy_params=deploy_params, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_tf_tensor_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - - model, data_x = model_factory.ModelFactory.prepare_tf_model() - x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) - y_pred = model(data_x) - - self.base_test_case( - name="tf_model_tensor_as_sample", - model=model, - sample_input_data=[data_x], - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: np.testing.assert_allclose( - tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred.numpy(), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_tf_df_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x = model_factory.ModelFactory.prepare_tf_model() 
- x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) - y_pred = model(data_x) - - self.base_test_case( - name="tf_model_df_as_sample", - model=model, - sample_input_data=x_df, - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: np.testing.assert_allclose( - tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred.numpy(), - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_tf_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x = model_factory.ModelFactory.prepare_tf_model() - x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) - x_df.columns = ["col_0"] - y_pred = model(data_x) - x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( - self._session, - x_df, - ) - y_pred_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([y_pred]) - y_pred_df.columns = ["output_feature_0"] - y_df_expected = pd.concat([x_df, y_pred_df], axis=1) - - self.base_test_case( - name="tf_model_sp", - model=model, - sample_input_data=x_df, - test_input=x_df_sp, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_keras_tensor_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = prepare_keras_model() - x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) - y_pred = model.predict(data_x) - - self.base_test_case( - name="keras_model_tensor_as_sample", - model=model, - sample_input_data=[data_x], - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: np.testing.assert_allclose( - 
tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred, - atol=1e-6, - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_keras_df_as_sample( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = prepare_keras_model() - x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) - y_pred = model.predict(data_x) - - self.base_test_case( - name="keras_model_df_as_sample", - model=model, - sample_input_data=x_df, - test_input=x_df, - deploy_params={ - "": ( - {}, - lambda res: np.testing.assert_allclose( - tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred, - atol=1e-6, - ), - ), - }, - permanent_deploy=permanent_deploy, - ) - - @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] - def test_keras_sp( - self, - permanent_deploy: Optional[bool] = False, - ) -> None: - model, data_x, data_y = prepare_keras_model() - x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) - x_df.columns = ["col_0"] - y_pred = model.predict(data_x) - x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( - self._session, - x_df, - ) - y_pred_df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df([y_pred]) - y_pred_df.columns = ["output_feature_0"] - y_df_expected = pd.concat([x_df, y_pred_df], axis=1) - - self.base_test_case( - name="keras_model_sp", - model=model, - sample_input_data=x_df, - test_input=x_df_sp, - deploy_params={ - "": ( - {}, - lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), - ), - }, - permanent_deploy=permanent_deploy, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/modeling/model_selection/check_output_hpo_integ_test.py 
b/tests/integ/snowflake/ml/modeling/model_selection/check_output_hpo_integ_test.py index d6ea8558..b7a3584c 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/check_output_hpo_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/check_output_hpo_integ_test.py @@ -90,7 +90,7 @@ def _compare_cv_results(self, cv_result_1: Dict[str, Any], cv_result_2: Dict[str elif k == "params": # compare the parameter combination self.assertEqual(v.tolist(), cv_result_2[k]) elif k.endswith("test_score"): # compare the test score - np.testing.assert_allclose(v, cv_result_2[k], rtol=1.0e-7, atol=1.0e-7) + np.testing.assert_allclose(v, cv_result_2[k], rtol=1.0e-5) # Do not compare the fit time def _compare_global_variables(self, sk_obj: SkGridSearchCV, sklearn_reg: SkGridSearchCV) -> None: @@ -100,7 +100,7 @@ def _compare_global_variables(self, sk_obj: SkGridSearchCV, sklearn_reg: SkGridS assert isinstance(sk_obj.refit_time_, float) if hasattr(sk_obj, "best_score_"): # if refit = callable and no best_score specified, then this attribute is empty - np.testing.assert_allclose(sk_obj.best_score_, sklearn_reg.best_score_) + np.testing.assert_allclose(sk_obj.best_score_, sklearn_reg.best_score_, rtol=1.0e-5) self.assertEqual(sk_obj.multimetric_, sklearn_reg.multimetric_) self.assertEqual(sk_obj.best_index_, sklearn_reg.best_index_) if hasattr(sk_obj, "n_splits_"): # n_splits_ is only available in RandomSearchCV @@ -118,15 +118,13 @@ def _compare_global_variables(self, sk_obj: SkGridSearchCV, sklearn_reg: SkGridS np.testing.assert_allclose( getattr(sk_obj.best_estimator_, variable_name), getattr(sklearn_reg.best_estimator_, variable_name), - rtol=1.0e-7, - atol=1.0e-7, + rtol=1.0e-5, ) else: np.testing.assert_allclose( getattr(sk_obj.best_estimator_, variable_name), getattr(sklearn_reg.best_estimator_, variable_name), - rtol=1.0e-7, - atol=1.0e-7, + rtol=1.0e-5, ) self.assertEqual(sk_obj.n_features_in_, sklearn_reg.n_features_in_) if hasattr(sk_obj, 
"feature_names_in_") and hasattr( diff --git a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py index 2c2acc4b..0244f5a8 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py @@ -78,8 +78,7 @@ def _compare_global_variables( np.testing.assert_allclose( sk_obj.best_score_, sklearn_reg.best_score_, - rtol=1.0e-5, - atol=1.0e-4, + rtol=1.0e-4, ) self.assertEqual(sk_obj.multimetric_, sklearn_reg.multimetric_) self.assertEqual(sk_obj.best_index_, sklearn_reg.best_index_) diff --git a/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py index 27babe84..3e379e0f 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py @@ -78,8 +78,7 @@ def _compare_global_variables( np.testing.assert_allclose( sk_obj.best_score_, sklearn_reg.best_score_, - rtol=1.0e-5, - atol=1.0e-4, + rtol=1.0e-4, ) self.assertEqual(sk_obj.multimetric_, sklearn_reg.multimetric_) self.assertEqual(sk_obj.best_index_, sklearn_reg.best_index_) diff --git a/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py b/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py index 4ebb3534..cdb9f93c 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py @@ -83,7 +83,7 @@ def test_not_single_node_grid(self, mock_is_single_node) -> None: "learning_rate": [0.1], } - estimator = XGBClassifier() + estimator = XGBClassifier(n_jobs=3) reg = GridSearchCV(estimator=estimator, param_grid=parameters, cv=2, 
verbose=True) reg.set_input_cols(input_cols) output_cols = ["OUTPUT_" + c for c in label_col] @@ -107,7 +107,7 @@ def test_not_single_node_random(self, mock_is_single_node) -> None: "learning_rate": [0.1], # reduce the parameters into one to accelerate the test process } - estimator = XGBClassifier() + estimator = XGBClassifier(n_jobs=3) reg = RandomizedSearchCV(estimator=estimator, param_distributions=parameters, cv=2, verbose=True) reg.set_input_cols(input_cols) output_cols = ["OUTPUT_" + c for c in label_col] diff --git a/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py b/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py index ffb98a72..1b73a595 100644 --- a/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py +++ b/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py @@ -10,6 +10,7 @@ import inflection import joblib import numpy as np +import pytest from absl.testing.absltest import TestCase, main from sklearn.compose import ColumnTransformer as SkColumnTransformer from sklearn.datasets import load_diabetes, load_iris @@ -62,6 +63,13 @@ def tearDown(self) -> None: if os.path.exists(filepath): os.remove(filepath) + @pytest.mark.skipif( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + reason=( + "Skipping this test on Container Runtimes. " + "See: https://snowflakecomputing.atlassian.net/browse/SNOW-1648870" + ), + ) def test_single_step(self) -> None: """ Test Pipeline with a single step. @@ -247,6 +255,13 @@ def test_pipeline_with_regression_estimators(self) -> None: np.testing.assert_allclose(actual_results, sk_predict_results, rtol=1.0e-1, atol=1.0e-2) + @pytest.mark.skipif( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + reason=( + "Skipping this test on Container Runtimes. 
" + "See: https://snowflakecomputing.atlassian.net/browse/SNOW-1648870" + ), + ) def test_pipeline_with_classifier_estimators(self) -> None: input_df_pandas = load_iris(as_frame=True).frame # Normalize column names @@ -481,6 +496,13 @@ def test_pipeline_signature(self) -> None: } self.assertEqual(model_signatures["predict"].to_dict(), expected_model_signatures["predict"].to_dict()) + @pytest.mark.skipif( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + reason=( + "Skipping this test on Container Runtimes. " + "See: https://snowflakecomputing.atlassian.net/browse/SNOW-1648870" + ), + ) def test_pipeline_with_label_encoder_output_col(self) -> None: input_df_pandas = load_diabetes(as_frame=True).frame # Normalize column names @@ -499,6 +521,13 @@ def test_pipeline_with_label_encoder_output_col(self) -> None: assert "TARGET_OUT" in snow_df_output.columns + @pytest.mark.skipif( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + reason=( + "Skipping this test on Container Runtimes. " + "See: https://snowflakecomputing.atlassian.net/browse/SNOW-1648870" + ), + ) def test_pipeline_score_samples(self) -> None: input_df_pandas = load_iris(as_frame=True).frame # Normalize column names diff --git a/tests/integ/snowflake/ml/modeling/preprocessing/label_encoder_test.py b/tests/integ/snowflake/ml/modeling/preprocessing/label_encoder_test.py index 5c409b71..23261408 100644 --- a/tests/integ/snowflake/ml/modeling/preprocessing/label_encoder_test.py +++ b/tests/integ/snowflake/ml/modeling/preprocessing/label_encoder_test.py @@ -1,13 +1,11 @@ import importlib import os -import pickle import sys import tempfile from typing import List from unittest import TestCase import cloudpickle -import joblib import numpy as np from absl.testing.absltest import main from sklearn.preprocessing import LabelEncoder as SklearnLabelEncoder @@ -197,8 +195,10 @@ def test_serde(self) -> None: with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as file: self._to_be_deleted_files.append(file.name) 
label_encoder_dump_cloudpickle = cloudpickle.dumps(label_encoder) - label_encoder_dump_pickle = pickle.dumps(label_encoder) - joblib.dump(label_encoder, file.name) + # disabling pickle and joblib serde due to the below error + # _pickle.PicklingError: Can't pickle : it's not the same object as snowflake.ml.modeling.preprocessing.label_encoder.LabelEncoder # noqa: E501 + # label_encoder_dump_pickle = pickle.dumps(label_encoder) + # joblib.dump(label_encoder, file.name) self._session.close() @@ -216,14 +216,14 @@ def test_serde(self) -> None: actual_arr_cloudpickle = transformed_df_cloudpickle[output_cols].to_pandas().to_numpy().flatten() # pickle - label_encoder_load_pickle = pickle.loads(label_encoder_dump_pickle) - transformed_df_pickle = label_encoder_load_pickle.transform(df2) - actual_arr_pickle = transformed_df_pickle[output_cols].to_pandas().to_numpy().flatten() + # label_encoder_load_pickle = pickle.loads(label_encoder_dump_pickle) + # transformed_df_pickle = label_encoder_load_pickle.transform(df2) + # actual_arr_pickle = transformed_df_pickle[output_cols].to_pandas().to_numpy().flatten() # joblib - label_encoder_load_joblib = joblib.load(file.name) - transformed_df_joblib = label_encoder_load_joblib.transform(df2) - actual_arr_joblib = transformed_df_joblib[output_cols].to_pandas().to_numpy().flatten() + # label_encoder_load_joblib = joblib.load(file.name) + # transformed_df_joblib = label_encoder_load_joblib.transform(df2) + # actual_arr_joblib = transformed_df_joblib[output_cols].to_pandas().to_numpy().flatten() # sklearn label_encoder_sklearn = SklearnLabelEncoder() @@ -231,8 +231,8 @@ def test_serde(self) -> None: sklearn_arr = label_encoder_sklearn.transform(df_pandas[input_cols]) np.testing.assert_allclose(actual_arr_cloudpickle, sklearn_arr) - np.testing.assert_allclose(actual_arr_pickle, sklearn_arr) - np.testing.assert_allclose(actual_arr_joblib, sklearn_arr) + # np.testing.assert_allclose(actual_arr_pickle, sklearn_arr) + # 
np.testing.assert_allclose(actual_arr_joblib, sklearn_arr) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/modeling/preprocessing/ordinal_encoder_test.py b/tests/integ/snowflake/ml/modeling/preprocessing/ordinal_encoder_test.py index d4959a12..9d15d52d 100644 --- a/tests/integ/snowflake/ml/modeling/preprocessing/ordinal_encoder_test.py +++ b/tests/integ/snowflake/ml/modeling/preprocessing/ordinal_encoder_test.py @@ -1,13 +1,11 @@ #!/usr/bin/env python3 import importlib import os -import pickle import sys import tempfile from typing import Any, Dict, List, Tuple import cloudpickle -import joblib import numpy as np import pandas as pd import pytest @@ -847,8 +845,9 @@ def test_serde(self) -> None: with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as file: self._to_be_deleted_files.append(file.name) encoder_dump_cloudpickle = cloudpickle.dumps(encoder) - encoder_dump_pickle = pickle.dumps(encoder) - joblib.dump(encoder, file.name) + # TODO(SNOW-1704904): Disabling pickle and joblib serde due to the below error + # encoder_dump_pickle = pickle.dumps(encoder) + # joblib.dump(encoder, file.name) self._session.close() @@ -866,14 +865,14 @@ def test_serde(self) -> None: actual_arr_cloudpickle = transformed_df_cloudpickle.sort(id_col)[output_cols].to_pandas().to_numpy() # pickle - encoder_load_pickle = pickle.loads(encoder_dump_pickle) - transformed_df_pickle = encoder_load_pickle.transform(df2[input_cols_extended]) - actual_arr_pickle = transformed_df_pickle.sort(id_col)[output_cols].to_pandas().to_numpy() + # encoder_load_pickle = pickle.loads(encoder_dump_pickle) + # transformed_df_pickle = encoder_load_pickle.transform(df2[input_cols_extended]) + # actual_arr_pickle = transformed_df_pickle.sort(id_col)[output_cols].to_pandas().to_numpy() # joblib - encoder_load_joblib = joblib.load(file.name) - transformed_df_joblib = encoder_load_joblib.transform(df2[input_cols_extended]) - actual_arr_joblib = 
transformed_df_joblib.sort(id_col)[output_cols].to_pandas().to_numpy() + # encoder_load_joblib = joblib.load(file.name) + # transformed_df_joblib = encoder_load_joblib.transform(df2[input_cols_extended]) + # actual_arr_joblib = transformed_df_joblib.sort(id_col)[output_cols].to_pandas().to_numpy() # sklearn encoder_sklearn = SklearnOrdinalEncoder() @@ -881,8 +880,8 @@ def test_serde(self) -> None: sklearn_arr = encoder_sklearn.transform(df_pandas[input_cols]) np.testing.assert_allclose(actual_arr_cloudpickle, sklearn_arr, equal_nan=True) - np.testing.assert_allclose(actual_arr_pickle, sklearn_arr, equal_nan=True) - np.testing.assert_allclose(actual_arr_joblib, sklearn_arr, equal_nan=True) + # np.testing.assert_allclose(actual_arr_pickle, sklearn_arr, equal_nan=True) + # np.testing.assert_allclose(actual_arr_joblib, sklearn_arr, equal_nan=True) def test_same_input_output_cols(self) -> None: """ diff --git a/tests/integ/snowflake/ml/observability/BUILD.bazel b/tests/integ/snowflake/ml/monitoring/BUILD.bazel similarity index 89% rename from tests/integ/snowflake/ml/observability/BUILD.bazel rename to tests/integ/snowflake/ml/monitoring/BUILD.bazel index f72c76ec..79f5d4a3 100644 --- a/tests/integ/snowflake/ml/observability/BUILD.bazel +++ b/tests/integ/snowflake/ml/monitoring/BUILD.bazel @@ -5,8 +5,8 @@ py_test( timeout = "long", srcs = ["model_monitor_integ_test.py"], deps = [ - "//snowflake/ml/beta/observability:observability_lib", "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/monitoring:model_monitor_impl", "//snowflake/ml/registry:registry_impl", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/test_utils:db_manager", diff --git a/tests/integ/snowflake/ml/observability/model_monitor_integ_test.py b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py similarity index 74% rename from tests/integ/snowflake/ml/observability/model_monitor_integ_test.py rename to 
tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py index 17436eae..630f2a25 100644 --- a/tests/integ/snowflake/ml/observability/model_monitor_integ_test.py +++ b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py @@ -4,24 +4,24 @@ from absl.testing import absltest, parameterized from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.beta.observability import ( - model_monitor, - model_monitor_config, - model_monitor_registry, - monitor_sql_client, -) from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.monitoring._client import model_monitor, monitor_sql_client +from snowflake.ml.monitoring.entities import model_monitor_config from snowflake.ml.registry import registry from snowflake.ml.utils import connection_params from snowflake.snowpark import Session from tests.integ.snowflake.ml.test_utils import db_manager, model_factory +INPUT_FEATURE_COLUMNS_NAMES = [f"input_feature_{i}" for i in range(64)] + class ModelMonitorRegistryIntegrationTest(parameterized.TestCase): - def _create_test_table(self, fully_qualified_table_name: str): + def _create_test_table(self, fully_qualified_table_name: str, id_column_type: str = "STRING") -> None: + s = ", ".join([f"{i} FLOAT" for i in INPUT_FEATURE_COLUMNS_NAMES]) self._session.sql( f"""CREATE OR REPLACE TABLE {fully_qualified_table_name} - (label FLOAT, prediction FLOAT, F1 FLOAT, id STRING, timestamp TIMESTAMP)""" + (label FLOAT, prediction FLOAT, + {s}, id {id_column_type}, timestamp TIMESTAMP)""" ).collect() @classmethod @@ -45,9 +45,6 @@ def setUp(self) -> None: db_name=self._db_name, sse_encrypted=True, ) - model_monitor_registry.ModelMonitorRegistry.setup( - session=self._session, database_name=self._db_name, schema_name=self._schema_name - ) self._warehouse_name = "REGTEST_ML_SMALL" self._db_manager.set_warehouse(self._warehouse_name) @@ -64,7 +61,7 @@ def tearDownClass(cls) -> None: def _add_sample_model_version_and_monitor( self, - 
monitor_registry: model_monitor_registry.ModelMonitorRegistry, + monitor_registry: registry.Registry, source_table: str, model_name: str, version_name: str, @@ -80,22 +77,27 @@ def _add_sample_model_version_and_monitor( return monitor_registry.add_monitor( name=monitor_name, - source_table_name=source_table, - model_monitor_config=model_monitor_config.ModelMonitorConfig( - model_version=model_version, - model_function_name="predict", + table_config=model_monitor_config.ModelMonitorTableConfig( + source_table=source_table, prediction_columns=["prediction"], label_columns=["label"], id_columns=["id"], timestamp_column="timestamp", + ), + model_monitor_config=model_monitor_config.ModelMonitorConfig( + model_version=model_version, + model_function_name="predict", background_compute_warehouse_name=self._warehouse_name, ), ) def test_add_model_monitor(self) -> None: - # Create an instance of the ModelMonitorRegistry class - _monitor_registry = model_monitor_registry.ModelMonitorRegistry( - session=self._session, database_name=self._db_name, schema_name=self._schema_name + # Create an instance of the Registry class with Monitoring enabled. 
+ _monitor_registry = registry.Registry( + session=self._session, + database_name=self._db_name, + schema_name=self._schema_name, + options={"enable_monitoring": True}, ) source_table_name = "TEST_TABLE" @@ -133,14 +135,17 @@ def test_add_model_monitor(self) -> None: ).collect() for col in table_columns: - self.assertTrue(col["name"].upper() in ["PREDICTION", "LABEL", "F1", "ID", "TIMESTAMP"]) + self.assertTrue( + col["name"].upper() + in ["PREDICTION", "LABEL", "ID", "TIMESTAMP", *[i.upper() for i in INPUT_FEATURE_COLUMNS_NAMES]] + ) df = self._session.create_dataframe( [ - (1.0, 1.0, 1.0, "1", "2021-01-01 00:00:00"), - (2.0, 2.0, 2.0, "2", "2021-01-01 00:00:00"), + (1.0, 1.0, *[1.0] * 64), + (1.0, 1.0, *[1.0] * 64), ], - ["LABEL", "PREDICTION", "F1", "ID", "TIMESTAMP"], + ["LABEL", "PREDICTION", *[i.upper() for i in INPUT_FEATURE_COLUMNS_NAMES]], ) monitor.set_baseline(df) self.assertEqual( @@ -152,15 +157,14 @@ def test_add_model_monitor(self) -> None: 2, ) - pandas_df = pd.DataFrame( - { - "LABEL": [1.0, 2.0, 3.0], - "PREDICTION": [1.0, 2.0, 3.0], - "F1": [1.0, 2.0, 3.0], - "ID": ["1", "2", "3"], - "TIMESTAMP": ["2021-01-01 00:00:00", "2021-01-01 00:00:00", "2021-01-01 00:00:00"], - } - ) + pandas_cols = { + "LABEL": [1.0, 2.0, 3.0], + "PREDICTION": [1.0, 2.0, 3.0], + } + for i in range(64): + pandas_cols[f"INPUT_FEATURE_{i}"] = [1.0, 2.0, 3.0] + + pandas_df = pd.DataFrame(pandas_cols) monitor.set_baseline(pandas_df) self.assertEqual( self._session.sql( @@ -174,10 +178,10 @@ def test_add_model_monitor(self) -> None: # create a snowpark dataframe that does not conform to the existing schema df = self._session.create_dataframe( [ - (1.0, "bad", 1.0, "1", "2021-01-01 00:00:00"), - (2.0, "very_bad", 2.0, "2", "2021-01-01 00:00:00"), + (1.0, "bad", *[1.0] * 64), + (2.0, "very_bad", *[2.0] * 64), ], - ["LABEL", "PREDICTION", "F1", "ID", "TIMESTAMP"], + ["LABEL", "PREDICTION", *[i.upper() for i in INPUT_FEATURE_COLUMNS_NAMES]], ) with self.assertRaises(ValueError) as 
e: monitor.set_baseline(df) @@ -225,9 +229,39 @@ def test_add_model_monitor(self) -> None: 0, ) + def test_add_model_monitor_varchar(self) -> None: + _monitor_registry = registry.Registry( + session=self._session, + database_name=self._db_name, + schema_name=self._schema_name, + options={"enable_monitoring": True}, + ) + source_table = "TEST_TABLE" + self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table}", id_column_type="VARCHAR(64)") + + model_name = "TEST_MODEL" + version_name = "TEST_VERSION" + monitor_name = f"TEST_MONITOR_{model_name}_{version_name}_{self.run_id}" + self._add_sample_model_version_and_monitor( + _monitor_registry, source_table, model_name, version_name, monitor_name + ) + + self.assertEqual( + self._session.sql( + f"""SELECT * + FROM {self._db_name}.{self._schema_name}.{monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} + WHERE FULLY_QUALIFIED_MODEL_NAME = '{self._db_name}.{self._schema_name}.{model_name}' AND + MODEL_VERSION_NAME = '{version_name}'""" + ).count(), + 1, + ) + def test_show_model_monitors(self) -> None: - _monitor_registry = model_monitor_registry.ModelMonitorRegistry( - session=self._session, database_name=self._db_name, schema_name=self._schema_name + _monitor_registry = registry.Registry( + session=self._session, + database_name=self._db_name, + schema_name=self._schema_name, + options={"enable_monitoring": True}, ) source_table_1 = "TEST_TABLE_1" self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_1}") diff --git a/tests/integ/snowflake/ml/registry/BUILD.bazel b/tests/integ/snowflake/ml/registry/BUILD.bazel index 93451bd4..c38fdce1 100644 --- a/tests/integ/snowflake/ml/registry/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/BUILD.bazel @@ -1,84 +1,4 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") - -py_test( - name = "model_registry_basic_integ_test", - timeout = "long", - srcs = ["model_registry_basic_integ_test.py"], - deps = [ - 
"//snowflake/ml/registry:model_registry", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "model_registry_integ_test", - timeout = "long", - srcs = ["model_registry_integ_test.py"], - shard_count = 3, - deps = [ - "//snowflake/ml/registry:model_registry", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:model_factory", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_test( - name = "model_registry_compat_test", - timeout = "long", - srcs = ["model_registry_compat_test.py"], - shard_count = 4, - deps = [ - "//snowflake/ml/_internal:env", - "//snowflake/ml/registry:model_registry", - "//tests/integ/snowflake/ml/test_utils:common_test_base", - "//tests/integ/snowflake/ml/test_utils:db_manager", - ], -) - -py_test( - name = "model_registry_schema_evolution_integ_test", - timeout = "long", - srcs = ["model_registry_schema_evolution_integ_test.py"], - shard_count = 2, - deps = [ - "//snowflake/ml/registry:model_registry", - "//snowflake/ml/utils:connection_params", - "//tests/integ/snowflake/ml/test_utils:db_manager", - "//tests/integ/snowflake/ml/test_utils:model_factory", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_library( - name = "model_registry_snowservice_integ_test_base", - testonly = True, - srcs = ["model_registry_snowservice_integ_test_base.py"], - deps = [ - "//snowflake/ml/model:deploy_platforms", - "//snowflake/ml/registry:model_registry", - "//tests/integ/snowflake/ml/test_utils:model_factory", - "//tests/integ/snowflake/ml/test_utils:spcs_integ_test_base", - "//tests/integ/snowflake/ml/test_utils:test_env_utils", - ], -) - -py_test( - name = "model_registry_snowservice_integ_test", - timeout = "eternal", # 3600s - srcs = ["model_registry_snowservice_integ_test.py"], - deps = [":model_registry_snowservice_integ_test_base"], -) - -py_test( - 
name = "model_registry_snowservice_merge_gate_integ_test", - timeout = "eternal", # 3600s - srcs = ["model_registry_snowservice_merge_gate_integ_test.py"], - shard_count = 2, - deps = [":model_registry_snowservice_integ_test_base"], -) +load("//bazel:py_rules.bzl", "py_test") py_test( name = "registry_compat_test", diff --git a/tests/integ/snowflake/ml/registry/model/BUILD.bazel b/tests/integ/snowflake/ml/registry/model/BUILD.bazel index 9b21f080..6e767abf 100644 --- a/tests/integ/snowflake/ml/registry/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/model/BUILD.bazel @@ -145,13 +145,15 @@ py_test( name = "registry_modeling_model_test", timeout = "long", srcs = ["registry_modeling_model_test.py"], - shard_count = 2, + shard_count = 4, deps = [ ":registry_model_test_base", + "//snowflake/ml/dataset", "//snowflake/ml/modeling/lightgbm:lgbm_regressor", "//snowflake/ml/modeling/linear_model:logistic_regression", "//snowflake/ml/modeling/pipeline", "//snowflake/ml/modeling/xgboost:xgb_regressor", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", ], ) diff --git a/tests/integ/snowflake/ml/registry/model/multiple_model_test.py b/tests/integ/snowflake/ml/registry/model/multiple_model_test.py index 056bc865..f58d08d5 100644 --- a/tests/integ/snowflake/ml/registry/model/multiple_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/multiple_model_test.py @@ -21,7 +21,7 @@ def predict(self, input: pd.DataFrame) -> pd.DataFrame: return pd.DataFrame({"output": input["c1"] + self.bias}) -class ModelWithAdditionalImportTest(registry_model_test_base.RegistryModelTestBase): +class MultipleModelTest(registry_model_test_base.RegistryModelTestBase): def test_multiple_model(self) -> None: version = "v1" arr = np.array([[1], [4]]) @@ -35,7 +35,7 @@ def test_multiple_model(self) -> None: ) name_1 = f"model_{self._run_id}_1" - self.registry.log_model(lm_1, model_name=name_1, version_name=version, sample_input_data=pd_df) + mv1 = self.registry.log_model(lm_1, 
model_name=name_1, version_name=version, sample_input_data=pd_df) with tempfile.TemporaryDirectory() as tmpdir: with open(os.path.join(tmpdir, "bias"), "w", encoding="utf-8") as f: @@ -45,7 +45,7 @@ def test_multiple_model(self) -> None: ) name_2 = f"model_{self._run_id}_2" - self.registry.log_model(lm_2, model_name=name_2, version_name=version, sample_input_data=pd_df) + mv2 = self.registry.log_model(lm_2, model_name=name_2, version_name=version, sample_input_data=pd_df) res = ( self.session.sql(f"SELECT {name_1}!predict(1):output as A, {name_2}!predict(1):output as B") @@ -55,6 +55,13 @@ def test_multiple_model(self) -> None: self.assertDictEqual(res, {"A": "11", "B": "21"}) + res = ( + mv1.run(mv2.run(self.session.create_dataframe(pd_df)).select('"output"').rename({'"output"': '"c1"'})) + .select('"output"') + .to_pandas() + ) + pd.testing.assert_frame_equal(res, pd.DataFrame({"output": [31, 34]})) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py index 939507f1..aa194e40 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py @@ -6,6 +6,7 @@ from absl.testing import absltest, parameterized from sklearn import datasets, model_selection +from snowflake.ml.model import model_signature from tests.integ.snowflake.ml.registry.model import registry_model_test_base from tests.integ.snowflake.ml.test_utils import dataframe_utils @@ -194,6 +195,89 @@ def test_catboost_classifier_explain_sp( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_catboost_with_signature_and_sample_data( + self, + registry_test_fn: str, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = 
pd.DataFrame(cal_data.data, columns=cal_data.feature_names) + cal_y = pd.Series(cal_data.target) + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + + classifier = catboost.CatBoostClassifier() + classifier.fit(cal_X_train, cal_y_train) + y_pred = classifier.predict(cal_X_test) + y_pred_proba = classifier.predict_proba(cal_X_test) + y_pred_log_proba = classifier.predict_log_proba(cal_X_test) + sig = { + "predict": model_signature.infer_signature(cal_X_test, y_pred), + "predict_proba": model_signature.infer_signature(cal_X_test, y_pred_proba), + "predict_log_proba": model_signature.infer_signature(cal_X_test, y_pred_log_proba), + } + expected_explanations = shap.Explainer(classifier)(cal_X_test).values + + # passing both signature and sample_input_data when enable_explainability is True + getattr(self, registry_test_fn)( + model=classifier, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + ), + ), + "predict_proba": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(cal_X_test)), + ), + "predict_log_proba": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, classifier.predict_log_proba(cal_X_test)), + ), + "explain": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, expected_explanations), + ), + }, + options={"enable_explainability": True}, + signatures=sig, + ) + + with self.assertRaisesRegex( + ValueError, "Signatures and sample_input_data both cannot be specified at the same time." 
+ ): + getattr(self, registry_test_fn)( + model=classifier, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + ), + ), + "predict_proba": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(cal_X_test)), + ), + "predict_log_proba": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, classifier.predict_log_proba(cal_X_test)), + ), + "explain": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, expected_explanations), + ), + }, + signatures=sig, + additional_version_suffix="v2", + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_in_sproc_test.py b/tests/integ/snowflake/ml/registry/model/registry_in_sproc_test.py index a3e73552..a4fd18f3 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_in_sproc_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_in_sproc_test.py @@ -11,7 +11,7 @@ class RegistryInSprocTest(registry_model_test_base.RegistryModelTestBase): - @common_test_base.CommonTestBase.sproc_test(test_owners_rights=False) + @common_test_base.CommonTestBase.sproc_test(test_owners_rights=False, additional_packages=["inflection"]) def test_workflow(self) -> None: model, test_features, _ = model_factory.ModelFactory.prepare_sklearn_model() self.mv_1 = self.registry.log_model( diff --git a/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py index 149722e9..250c70b7 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py @@ -6,6 +6,7 @@ from absl.testing import absltest, parameterized from sklearn import datasets, model_selection +from snowflake.ml.model import 
model_signature from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils from tests.integ.snowflake.ml.registry.model import registry_model_test_base from tests.integ.snowflake.ml.test_utils import dataframe_utils @@ -345,6 +346,73 @@ def test_lightgbm_booster_explain_sp( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_lightgbm_with_signature_and_sample_data( + self, + registry_test_fn: str, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + + classifier = lightgbm.LGBMClassifier() + classifier.fit(cal_X_train, cal_y_train) + y_pred = pd.DataFrame(classifier.predict(cal_X_test), columns=["output_feature_0"]) + sig = { + "predict": model_signature.infer_signature(cal_X_test, y_pred), + } + + expected_explanations = shap.Explainer(classifier)(cal_X_test).values + + getattr(self, registry_test_fn)( + model=classifier, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + ), + ), + "explain": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + dataframe_utils.convert2D_json_to_3D(res.values), expected_explanations, rtol=1e-5 + ), + ), + }, + options={"enable_explainability": True}, + signatures=sig, + ) + + with self.assertRaisesRegex( + ValueError, "Signatures and sample_input_data both cannot be specified at the same time." 
+ ): + getattr(self, registry_test_fn)( + model=classifier, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + ), + ), + "explain": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + dataframe_utils.convert2D_json_to_3D(res.values), expected_explanations, rtol=1e-5 + ), + ), + }, + signatures=sig, + additional_version_suffix="v2", + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py index 09aeef66..de08b835 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py +++ b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py @@ -2,7 +2,7 @@ import uuid from typing import Any, Callable, Dict, List, Optional, Tuple -from snowflake.ml.model import type_hints as model_types +from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.registry import registry from tests.integ.snowflake.ml.test_utils import ( common_test_base, @@ -41,6 +41,8 @@ def _test_registry_model( sample_input_data: Optional[model_types.SupportedDataType] = None, additional_dependencies: Optional[List[str]] = None, options: Optional[model_types.ModelSaveOption] = None, + signatures: Optional[Dict[str, model_signature.ModelSignature]] = None, + additional_version_suffix: Optional[str] = None, ) -> None: conda_dependencies = [ test_env_utils.get_latest_package_version_spec_in_server(self.session, "snowflake-snowpark-python!=1.12.0") @@ -48,9 +50,13 @@ def _test_registry_model( if additional_dependencies: conda_dependencies.extend(additional_dependencies) + version_suffix = self._run_id + if additional_version_suffix: + version_suffix = version_suffix + "_" + additional_version_suffix + # Get the name of the caller as the model name 
name = f"model_{inspect.stack()[1].function}" - version = f"ver_{self._run_id}" + version = f"ver_{version_suffix}" mv = self.registry.log_model( model=model, model_name=name, @@ -58,6 +64,7 @@ def _test_registry_model( sample_input_data=sample_input_data, conda_dependencies=conda_dependencies, options=options, + signatures=signatures, ) for target_method, (test_input, check_func) in prediction_assert_fns.items(): @@ -77,6 +84,8 @@ def _test_registry_model_from_model_version( sample_input_data: Optional[model_types.SupportedDataType] = None, additional_dependencies: Optional[List[str]] = None, options: Optional[model_types.ModelSaveOption] = None, + signatures: Optional[Dict[str, model_signature.ModelSignature]] = None, + additional_version_suffix: Optional[str] = None, ) -> None: conda_dependencies = [ test_env_utils.get_latest_package_version_spec_in_server(self.session, "snowflake-snowpark-python!=1.12.0") @@ -84,11 +93,15 @@ def _test_registry_model_from_model_version( if additional_dependencies: conda_dependencies.extend(additional_dependencies) + version_suffix = self._run_id + if additional_version_suffix: + version_suffix = version_suffix + "_" + additional_version_suffix + # Get the name of the caller as the model name source_name = f"source_model_{inspect.stack()[1].function}" name = f"model_{inspect.stack()[1].function}" - source_version = f"source_ver_{self._run_id}" - version = f"ver_{self._run_id}" + source_version = f"source_ver_{version_suffix}" + version = f"ver_{version_suffix}" source_mv = self.registry.log_model( model=model, model_name=source_name, @@ -96,6 +109,7 @@ def _test_registry_model_from_model_version( sample_input_data=sample_input_data, conda_dependencies=conda_dependencies, options=options, + signatures=signatures, ) # Create a new model when the model doesn't exist @@ -112,7 +126,7 @@ def _test_registry_model_from_model_version( self.registry.show_models() # Add a version when the model exists - version2 = f"ver_{self._run_id}_2" + 
version2 = f"ver_{version_suffix}_2" mv2 = self.registry.log_model( model=source_mv, model_name=name, diff --git a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py index 4da12e1d..c52eda23 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py @@ -8,6 +8,7 @@ from sklearn import datasets from snowflake.ml import dataset +from snowflake.ml._internal.utils import identifier from snowflake.ml.model._model_composer import model_composer from snowflake.ml.modeling.lightgbm import LGBMRegressor from snowflake.ml.modeling.linear_model import LogisticRegression @@ -15,14 +16,14 @@ from snowflake.ml.modeling.xgboost import XGBRegressor from snowflake.snowpark import types as T from tests.integ.snowflake.ml.registry.model import registry_model_test_base -from tests.integ.snowflake.ml.test_utils import test_env_utils +from tests.integ.snowflake.ml.test_utils import dataframe_utils, test_env_utils class TestRegistryModelingModelInteg(registry_model_test_base.RegistryModelTestBase): @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_snowml_sklearn( + def test_snowml_model_deploy_snowml_sklearn_explain_disabled( self, registry_test_fn: str, ) -> None: @@ -46,12 +47,13 @@ def test_snowml_model_deploy_snowml_sklearn( ), ), }, + options={"enable_explainability": False}, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_xgboost( + def test_snowml_model_deploy_snowml_sklearn_explain_default( self, registry_test_fn: str, ) -> None: @@ -61,10 +63,14 @@ def test_snowml_model_deploy_xgboost( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", 
"PETALWIDTH"] LABEL_COLUMNS = "TARGET" OUTPUT_COLUMNS = "PREDICTED_TARGET" - regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X[:10] + EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] + regr = LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X regr.fit(test_features) + test_data = test_features[INPUT_COLUMNS] + expected_explanations = shap.Explainer(regr.to_sklearn(), masker=test_data)(test_data).values + getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ @@ -74,14 +80,22 @@ def test_snowml_model_deploy_xgboost( res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values ), ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + dataframe_utils.convert2D_json_to_3D(res[EXPLAIN_OUTPUT_COLUMNS].values), + expected_explanations, + rtol=1e-4, + ), + ), }, - options={"enable_explainability": False}, + sample_input_data=test_data, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_xgboost_explain_default( + def test_snowml_model_deploy_snowml_sklearn_explain_enabled( self, registry_test_fn: str, ) -> None: @@ -90,37 +104,40 @@ def test_snowml_model_deploy_xgboost_explain_default( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" - PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] - - regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] + regr = 
LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) test_features = iris_X regr.fit(test_features) - expected_explanations = shap.Explainer(regr.to_xgboost())(test_features[INPUT_COLUMNS]).values - + test_data = test_features[INPUT_COLUMNS] + expected_explanations = shap.Explainer(regr.to_sklearn(), masker=test_data)(test_data).values getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ "predict": ( test_features, lambda res: np.testing.assert_allclose( - res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values ), ), "explain": ( test_features, lambda res: np.testing.assert_allclose( - res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 + dataframe_utils.convert2D_json_to_3D(res[EXPLAIN_OUTPUT_COLUMNS].values), + expected_explanations, + rtol=1e-4, ), ), }, + sample_input_data=test_data, + options={"enable_explainability": True}, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_xgboost_explain_enabled( + def test_snowml_model_deploy_xgboost_explain_disabled( self, registry_test_fn: str, ) -> None: @@ -129,38 +146,28 @@ def test_snowml_model_deploy_xgboost_explain_enabled( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" - PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] - - regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X[:10] regr.fit(test_features) - expected_explanations = 
shap.Explainer(regr.to_xgboost())(test_features[INPUT_COLUMNS]).values - getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ "predict": ( test_features, lambda res: np.testing.assert_allclose( - res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values - ), - ), - "explain": ( - test_features, - lambda res: np.testing.assert_allclose( - res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 + res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values ), ), }, - options={"enable_explainability": True}, + options={"enable_explainability": False}, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_xgboost_explain( + def test_snowml_model_deploy_xgboost_explain_default( self, registry_test_fn: str, ) -> None: @@ -170,7 +177,7 @@ def test_snowml_model_deploy_xgboost_explain( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) test_features = iris_X @@ -194,13 +201,12 @@ def test_snowml_model_deploy_xgboost_explain( ), ), }, - options={"enable_explainability": True}, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_lightgbm( + def test_snowml_model_deploy_xgboost_explain_enabled( self, registry_test_fn: str, ) -> None: @@ -209,28 +215,38 @@ def test_snowml_model_deploy_lightgbm( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" - 
OUTPUT_COLUMNS = "PREDICTED_TARGET" - regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X[:10] + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] + + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X regr.fit(test_features) + expected_explanations = shap.Explainer(regr.to_xgboost())(test_features[INPUT_COLUMNS]).values + getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ "predict": ( test_features, lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values + res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + ), + ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 ), ), }, - options={"enable_explainability": False}, + options={"enable_explainability": True}, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_lightgbm_explain_default( + def test_snowml_model_deploy_lightgbm_explain_disabled( self, registry_test_fn: str, ) -> None: @@ -239,38 +255,28 @@ def test_snowml_model_deploy_lightgbm_explain_default( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" - PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] - regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) - test_features = iris_X + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, 
label_cols=LABEL_COLUMNS) + test_features = iris_X[:10] regr.fit(test_features) - expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values - getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ "predict": ( test_features, lambda res: np.testing.assert_allclose( - res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values - ), - ), - "explain": ( - test_features, - lambda res: np.testing.assert_allclose( - res[EXPLAIN_OUTPUT_COLUMNS].values, - expected_explanations, - rtol=1e-5, + res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values ), ), }, + options={"enable_explainability": False}, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_lightgbm_explain_enabled( + def test_snowml_model_deploy_lightgbm_explain_default( self, registry_test_fn: str, ) -> None: @@ -280,7 +286,7 @@ def test_snowml_model_deploy_lightgbm_explain_enabled( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) test_features = iris_X regr.fit(test_features) @@ -305,13 +311,12 @@ def test_snowml_model_deploy_lightgbm_explain_enabled( ), ), }, - options={"enable_explainability": True}, ) @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) - def test_snowml_model_deploy_lightgbm_explain( + def test_snowml_model_deploy_lightgbm_explain_enabled( self, registry_test_fn: str, ) -> None: @@ -321,13 +326,12 @@ def 
test_snowml_model_deploy_lightgbm_explain( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) test_features = iris_X regr.fit(test_features) expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values - print(expected_explanations) getattr(self, registry_test_fn)( model=regr, diff --git a/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py index cf1c435a..c6861ad7 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py @@ -6,6 +6,7 @@ from absl.testing import absltest, parameterized from sklearn import datasets, ensemble, linear_model, multioutput +from snowflake.ml.model import model_signature from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils from snowflake.snowpark import exceptions as snowpark_exceptions from tests.integ.snowflake.ml.registry.model import registry_model_test_base @@ -53,18 +54,6 @@ def test_skl_model_explain( classifier.fit(iris_X_df, iris_y) expected_explanations = shap.Explainer(classifier, iris_X_df)(iris_X_df).values - with self.assertRaisesRegex( - ValueError, - "Sample input data is required to enable explainability. 
Currently we only support this for " - + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`.", - ): - getattr(self, registry_test_fn)( - model=classifier, - sample_input_data=iris_X, - prediction_assert_fns={}, - options={"enable_explainability": True}, - ) - getattr(self, registry_test_fn)( model=classifier, sample_input_data=iris_X_df, @@ -220,6 +209,68 @@ def test_skl_unsupported_explain( self.assertNotIn(mv.model_name, [m.name for m in self.registry.models()]) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_skl_model_with_signature_and_sample_data( + self, + registry_test_fn: str, + ) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + # sample input needs to be pandas dataframe for now + iris_X_df = pd.DataFrame(iris_X, columns=["c1", "c2", "c3", "c4"]) + classifier = linear_model.LogisticRegression() + classifier.fit(iris_X_df, iris_y) + expected_explanations = shap.Explainer(classifier, iris_X_df)(iris_X_df).values + + y_pred = pd.DataFrame(classifier.predict(iris_X_df), columns=["output_feature_0"]) + sig = { + "predict": model_signature.infer_signature(iris_X_df, y_pred), + } + + getattr(self, registry_test_fn)( + model=classifier, + sample_input_data=iris_X_df, + prediction_assert_fns={ + "predict": ( + iris_X_df, + lambda res: np.testing.assert_allclose(res["output_feature_0"].values, classifier.predict(iris_X)), + ), + "explain": ( + iris_X_df, + lambda res: np.testing.assert_allclose( + dataframe_utils.convert2D_json_to_3D(res.values), expected_explanations + ), + ), + }, + options={"enable_explainability": True}, + signatures=sig, + ) + + with self.assertRaisesRegex( + ValueError, "Signatures and sample_input_data both cannot be specified at the same time." 
+ ): + getattr(self, registry_test_fn)( + model=classifier, + sample_input_data=iris_X_df, + prediction_assert_fns={ + "predict": ( + iris_X_df, + lambda res: np.testing.assert_allclose( + res["output_feature_0"].values, classifier.predict(iris_X) + ), + ), + "explain": ( + iris_X_df, + lambda res: np.testing.assert_allclose( + dataframe_utils.convert2D_json_to_3D(res.values), expected_explanations + ), + ), + }, + signatures=sig, + additional_version_suffix="v2", + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py index ae8eb46c..a691998a 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py @@ -6,6 +6,7 @@ from absl.testing import absltest, parameterized from sklearn import datasets, model_selection +from snowflake.ml.model import model_signature from tests.integ.snowflake.ml.registry.model import registry_model_test_base from tests.integ.snowflake.ml.test_utils import dataframe_utils @@ -326,6 +327,55 @@ def test_xgb_booster_explain_sp( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_xgb_booster_with_signature_and_sample_data( + self, + registry_test_fn: str, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") + regressor = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train)) + y_pred = pd.DataFrame( + 
regressor.predict(xgboost.DMatrix(data=cal_X_test)), + columns=["output_feature_0"], + ) + expected_explanations = shap.Explainer(regressor)(cal_X_test).values + sig = {"predict": model_signature.infer_signature(cal_X_test, y_pred)} + getattr(self, registry_test_fn)( + model=regressor, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "explain": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-4), + ), + }, + options={"enable_explainability": True}, + signatures=sig, + ) + + with self.assertRaisesRegex( + ValueError, "Signatures and sample_input_data both cannot be specified at the same time." + ): + getattr(self, registry_test_fn)( + model=regressor, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "explain": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-4), + ), + }, + signatures=sig, + additional_version_suffix="v2", + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_basic_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_basic_integ_test.py deleted file mode 100644 index 3728da6b..00000000 --- a/tests/integ/snowflake/ml/registry/model_registry_basic_integ_test.py +++ /dev/null @@ -1,171 +0,0 @@ -import uuid -from typing import Optional - -from absl.testing import absltest, parameterized - -from snowflake.ml.registry import model_registry -from snowflake.ml.utils import connection_params -from snowflake.snowpark import Session -from tests.integ.snowflake.ml.test_utils import db_manager - -_RUN_ID = uuid.uuid4().hex -_PRE_CREATED_DB_NAME_UPPER = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "REGISTRY_PRE_CREATED_DB_SYSTEM_UPPER" -).upper() -_PRE_CREATED_DB_NAME_LOWER = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "registry_pre_created_db_system_lower" -).lower() -_PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER = 
db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "REGISTRY_PRE_CREATED_DB_AND_SCHEMA_UPPER" -).upper() -_PRE_CREATED_DB_AND_SCHEMA_NAME_LOWER = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "registry_pre_created_db_and_schema_lower" -).lower() -_CUSTOM_NEW_DB_NAME_UPPER = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "REGISTRY_NEW_DB_CUSTOM_UPPER" -).upper() -_CUSTOM_NEW_DB_NAME_LOWER = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "registry_new_db_custom_lower" -).lower() -_CUSTOM_NEW_SCHEMA_NAME_UPPER = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "REGISTRY_NEW_SCHEMA_CUSTOM_UPPER" -).upper() -_CUSTOM_NEW_SCHEMA_NAME_LOWER = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "registry_new_schema_custom_lower" -).lower() - - -class TestModelRegistryBasicInteg(parameterized.TestCase): - @classmethod - def setUpClass(cls) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - cls._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - cls._database = cls._session.get_current_database() - cls._schema = cls._session.get_current_schema() - assert cls._database is not None - assert cls._schema is not None - - cls._db_manager = db_manager.DBManager(cls._session) - - cls._db_manager.cleanup_databases() - - cls._db_manager.create_database(_PRE_CREATED_DB_NAME_UPPER) - cls._db_manager.create_database(_PRE_CREATED_DB_NAME_LOWER) - cls._db_manager.create_schema( - _PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER, - _PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER, - ) - cls._db_manager.create_schema( - _PRE_CREATED_DB_AND_SCHEMA_NAME_LOWER, - _PRE_CREATED_DB_AND_SCHEMA_NAME_LOWER, - ) - - # restore the session to use the original database and schema - cls._session.use_database(cls._database) - cls._session.use_schema(cls._schema) - assert cls._database == 
cls._session.get_current_database() - assert cls._schema == cls._session.get_current_schema() - - @classmethod - def tearDownClass(cls) -> None: - cls._db_manager.drop_database(_PRE_CREATED_DB_NAME_UPPER, if_exists=True) - cls._db_manager.drop_database(_PRE_CREATED_DB_NAME_LOWER, if_exists=True) - cls._db_manager.drop_database(_PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER, if_exists=True) - cls._db_manager.drop_database(_PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER, if_exists=True) - cls._db_manager.drop_database(_CUSTOM_NEW_DB_NAME_UPPER, if_exists=True) - cls._db_manager.drop_database(_CUSTOM_NEW_DB_NAME_LOWER, if_exists=True) - cls._session.close() - - def _validate_restore_db_and_schema(self) -> None: - """Validate that the database and schema are restored after creating registry.""" - self.assertEqual(self._database, self._session.get_current_database()) - self.assertEqual(self._schema, self._session.get_current_schema()) - - @parameterized.parameters( # type: ignore[misc] - {"database_name": _PRE_CREATED_DB_NAME_UPPER, "schema_name": None}, - { - "database_name": db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "REGISTRY_NEW_DB_SYSTEM_UPPER" - ).upper(), - "schema_name": None, - }, - { - "database_name": _CUSTOM_NEW_DB_NAME_UPPER, - "schema_name": _CUSTOM_NEW_SCHEMA_NAME_UPPER, - }, - {"database_name": _PRE_CREATED_DB_NAME_LOWER, "schema_name": None}, - { - "database_name": db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, "registry_new_db_system_lower" - ).lower(), - "schema_name": None, - }, - { - "database_name": _CUSTOM_NEW_DB_NAME_LOWER, - "schema_name": _CUSTOM_NEW_SCHEMA_NAME_LOWER, - }, - { - "database_name": db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - _RUN_ID, 'registry_new_db_system_with""' - ).lower(), - "schema_name": None, - }, - { - "database_name": _PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER, - "schema_name": _CUSTOM_NEW_SCHEMA_NAME_UPPER, - }, - { - "database_name": 
_PRE_CREATED_DB_AND_SCHEMA_NAME_LOWER, - "schema_name": _CUSTOM_NEW_SCHEMA_NAME_LOWER, - }, - { - "database_name": _PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER, - "schema_name": _PRE_CREATED_DB_AND_SCHEMA_NAME_UPPER, - }, - { - "database_name": _PRE_CREATED_DB_AND_SCHEMA_NAME_LOWER, - "schema_name": _PRE_CREATED_DB_AND_SCHEMA_NAME_LOWER, - }, - ) - def test_create_and_drop_model_registry(self, database_name: str, schema_name: Optional[str] = None) -> None: - if schema_name: - create_result = model_registry.create_model_registry( - session=self._session, database_name=database_name, schema_name=schema_name - ) - self.assertTrue(create_result) - self._validate_restore_db_and_schema() - - # Test create again, should be non-op - create_result = model_registry.create_model_registry( - session=self._session, database_name=database_name, schema_name=schema_name - ) - - self.assertTrue(create_result) - self._validate_restore_db_and_schema() - - _ = model_registry.ModelRegistry( - session=self._session, database_name=database_name, schema_name=schema_name - ) - - self._db_manager.drop_schema(schema_name, database_name) - self.assertTrue(self._db_manager.assert_schema_existence(schema_name, database_name, exists=False)) - self._validate_restore_db_and_schema() - else: - create_result = model_registry.create_model_registry(session=self._session, database_name=database_name) - self.assertTrue(create_result) - self._validate_restore_db_and_schema() - - # Test create again, should be non-op - create_result = model_registry.create_model_registry(session=self._session, database_name=database_name) - self.assertTrue(create_result) - self._validate_restore_db_and_schema() - - _ = model_registry.ModelRegistry(session=self._session, database_name=database_name) - - self._db_manager.drop_database(database_name) - self.assertTrue(self._db_manager.assert_database_existence(database_name, exists=False)) - self._validate_restore_db_and_schema() - - -if __name__ == "__main__": - absltest.main() 
diff --git a/tests/integ/snowflake/ml/registry/model_registry_compat_test.py b/tests/integ/snowflake/ml/registry/model_registry_compat_test.py deleted file mode 100644 index 1ee212c5..00000000 --- a/tests/integ/snowflake/ml/registry/model_registry_compat_test.py +++ /dev/null @@ -1,117 +0,0 @@ -import uuid -from typing import Callable, Tuple - -from absl.testing import absltest, parameterized -from packaging import version -from sklearn import datasets - -from snowflake.ml._internal import env -from snowflake.ml.registry import model_registry -from snowflake.snowpark import session -from tests.integ.snowflake.ml.test_utils import common_test_base, db_manager - - -@absltest.skipIf( - version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), - "Skip compat test for Python higher than 3.11 since we previously does not support it.", -) -class ModelRegistryCompatTest(common_test_base.CommonTestBase): - def setUp(self) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - super().setUp() - self.run_id = uuid.uuid4().hex - self._db_manager = db_manager.DBManager(self.session) - self.current_db = self.session.get_current_database() - self.current_schema = self.session.get_current_schema() - self.registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "registry_db") - - def tearDown(self) -> None: - self._db_manager.drop_database(self.registry_name, if_exists=True) - self.session.use_database(self.current_db) - self.session.use_schema(self.current_schema) - super().tearDown() - - def _prepare_registry_fn_factory( - self, - ) -> Tuple[Callable[[session.Session, str], None], Tuple[str]]: - def prepare_registry(session: session.Session, registry_name: str) -> None: - from snowflake.connector.errors import ProgrammingError - from snowflake.ml.registry import model_registry - - try: - model_registry.create_model_registry(session=session, database_name=registry_name) - except ProgrammingError: - # Previous versions 
of library will call use even in the sproc env, which is not allowed. - # This is to suppress the error - pass - - return prepare_registry, (self.registry_name,) - - # Starting from 1.0.1 as we had a breaking change at that time. - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_prepare_registry_fn_factory, version_range=">=1.0.1" # type: ignore[misc] - ) - def test_open_registry_compat(self) -> None: - model_registry.ModelRegistry(session=self.session, database_name=self.registry_name, create_if_not_exists=True) - - def _prepare_registry_and_log_model_fn_factory( - self, - ) -> Tuple[Callable[[session.Session, str, str], None], Tuple[str, str]]: - self.registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "registry_db") - - def prepare_registry_and_log_model(session: session.Session, registry_name: str, run_id: str) -> None: - from sklearn import datasets, linear_model - - from snowflake.connector.errors import ProgrammingError - from snowflake.ml.registry import model_registry - - try: - model_registry.create_model_registry(session=session, database_name=registry_name) - except ProgrammingError: - # Previous versions of library will call use even in the sproc env, which is not allowed. 
- # This is to suppress the error - pass - - registry = model_registry.ModelRegistry(session=session, database_name=registry_name) - - iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True) - # Normalize the column name to avoid set it as case_sensitive where there was a BCR in 1.1.2 - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - # LogisticRegression is for classfication task, such as iris - regr = linear_model.LogisticRegression() - regr.fit(iris_X, iris_y) - - registry.log_model( - model_name="model", - model_version=run_id, - model=regr, - sample_input_data=iris_X, - ) - - return prepare_registry_and_log_model, (self.registry_name, self.run_id) - - @common_test_base.CommonTestBase.compatibility_test( - prepare_fn_factory=_prepare_registry_and_log_model_fn_factory, # type: ignore[arg-type] - version_range=">=1.0.6", - ) - @parameterized.parameters({"permanent": True}) - def test_log_model_compat(self, permanent: bool) -> None: - registry = model_registry.ModelRegistry( - session=self.session, database_name=self.registry_name, create_if_not_exists=True - ) - model_ref = model_registry.ModelReference( - registry=registry, - model_name="model", - model_version=self.run_id, - ) - deployment_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self.run_id, "predict") - model_ref.deploy( # type: ignore[attr-defined] - deployment_name=deployment_name, target_method="predict", permanent=permanent - ) - iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True) - iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] - model_ref.predict(deployment_name, iris_X) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_integ_test.py deleted file mode 100644 index 2e7c6dad..00000000 --- 
a/tests/integ/snowflake/ml/registry/model_registry_integ_test.py +++ /dev/null @@ -1,366 +0,0 @@ -import uuid -from typing import Dict - -import numpy as np -import pandas as pd -from absl.testing import absltest, parameterized -from sklearn import metrics - -from snowflake import connector -from snowflake.ml.registry import model_registry -from snowflake.ml.utils import connection_params -from snowflake.snowpark import Session -from tests.integ.snowflake.ml.test_utils import ( - db_manager, - model_factory, - test_env_utils, -) - - -class TestModelRegistryInteg(parameterized.TestCase): - @classmethod - def setUpClass(cls) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - cls._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - cls.run_id = uuid.uuid4().hex - cls._db_manager = db_manager.DBManager(cls._session) - cls.registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(cls.run_id, "registry_db") - model_registry.create_model_registry(session=cls._session, database_name=cls.registry_name) - cls.perm_stage = "@" + cls._db_manager.create_stage( - "model_registry_test_stage", "PUBLIC", cls.registry_name, sse_encrypted=True - ) - - @classmethod - def tearDownClass(cls) -> None: - cls._db_manager.drop_database(cls.registry_name) - cls._session.close() - - def test_basic_workflow(self) -> None: - registry = model_registry.ModelRegistry(session=self._session, database_name=self.registry_name) - - # Prepare the model - model_name = "basic_model" - model_version = self.run_id - model, test_features, test_labels = model_factory.ModelFactory.prepare_sklearn_model() - - local_prediction = model.predict(test_features) - local_prediction_proba = model.predict_proba(test_features) - - model_tags: Dict[str, str] = {"stage": "testing", "classifier_type": "svm.SVC", "num_training_examples": "10"} - - # Test model logging - with self.assertRaisesRegex( - KeyError, f"The model 
{model_name}/{model_version} does not exist in the current registry." - ): - model_ref = model_registry.ModelReference( - registry=registry, model_name=model_name, model_version=model_version - ) - - model_ref = registry.log_model( - model_name=model_name, - model_version=model_version, - model=model, - tags=model_tags, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server( - self._session, "snowflake-snowpark-python!=1.12.0" - ) - ], - sample_input_data=test_features, - options={"embed_local_ml_library": True}, - ) - - with self.assertRaisesRegex( - connector.DataError, f"Model {model_name}/{model_version} already exists. Unable to log the model." - ): - registry.log_model( - model_name=model_name, - model_version=model_version, - model=model, - tags={"stage": "testing", "classifier_type": "svm.SVC"}, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server( - self._session, "snowflake-snowpark-python!=1.12.0" - ) - ], - sample_input_data=test_features, - options={"embed_local_ml_library": True}, - ) - - model_ref = model_registry.ModelReference(registry=registry, model_name=model_name, model_version=model_version) - - # Test getting model name and model version - self.assertEqual(model_ref.get_name(), model_name) - self.assertEqual(model_ref.get_version(), model_version) - - # Test metrics - test_accuracy = metrics.accuracy_score(test_labels, local_prediction) - - model_ref.set_metric(metric_name="test_accuracy", metric_value=test_accuracy) # type: ignore[attr-defined] - - model_ref.set_metric(metric_name="num_training_examples", metric_value=10) # type: ignore[attr-defined] - - model_ref.set_metric( # type: ignore[attr-defined] - metric_name="dataset_test", - metric_value={"accuracy": test_accuracy}, - ) - - test_confusion_matrix = metrics.confusion_matrix(test_labels, local_prediction) - - model_ref.set_metric( # type: ignore[attr-defined] - metric_name="confusion_matrix", - 
metric_value=test_confusion_matrix, - ) - - stored_metrics = model_ref.get_metrics() # type: ignore[attr-defined] - - np.testing.assert_almost_equal(stored_metrics.pop("confusion_matrix"), test_confusion_matrix) - - self.assertDictEqual( - stored_metrics, - { - "test_accuracy": test_accuracy, - "num_training_examples": 10, - "dataset_test": {"accuracy": test_accuracy}, - }, - ) - - model_ref.remove_metric("confusion_matrix") # type: ignore[attr-defined] - self.assertDictEqual( - model_ref.get_metrics(), # type: ignore[attr-defined] - { - "test_accuracy": test_accuracy, - "num_training_examples": 10, - "dataset_test": {"accuracy": test_accuracy}, - }, - ) - - with self.assertRaisesRegex( - connector.DataError, f"Model {model_name}/{model_version} has no metric named confusion_matrix." - ): - model_ref.remove_metric(metric_name="confusion_matrix") # type: ignore[attr-defined] - - model_ref.set_metric(metric_name="num_training_examples", metric_value=20) # type: ignore[attr-defined] - self.assertDictEqual( - model_ref.get_metrics(), # type: ignore[attr-defined] - { - "test_accuracy": test_accuracy, - "num_training_examples": 20, - "dataset_test": {"accuracy": test_accuracy}, - }, - ) - - # Test list models - model_list = registry.list_models().to_pandas() - - filtered_model_list = model_list.loc[model_list["ID"] == model_ref._id].reset_index(drop=True) - - self.assertEqual(filtered_model_list.shape[0], 1) - self.assertEqual(filtered_model_list["NAME"][0], second=model_name) - self.assertEqual(filtered_model_list["VERSION"][0], second=model_version) - - # Test tags - self.assertDictEqual(model_ref.get_tags(), model_tags) # type: ignore[attr-defined] - - model_ref.set_tag(tag_name="minor_version", tag_value="23") # type: ignore[attr-defined] - self.assertDictEqual(model_ref.get_tags(), {**model_tags, "minor_version": "23"}) # type: ignore[attr-defined] - - model_ref.remove_tag(tag_name="minor_version") # type: ignore[attr-defined] - 
self.assertDictEqual(model_ref.get_tags(), model_tags) # type: ignore[attr-defined] - - with self.assertRaisesRegex( - connector.DataError, f"Model {model_name}/{model_version} has no tag named minor_version." - ): - model_ref.remove_tag(tag_name="minor_version") # type: ignore[attr-defined] - - model_ref.set_tag("stage", "production") # type: ignore[attr-defined] - model_tags.update({"stage": "production"}) - self.assertDictEqual(model_ref.get_tags(), model_tags) # type: ignore[attr-defined] - - # Test model description - model_ref.set_model_description( # type: ignore[attr-defined] - description="My model is better than talkgpt-5!", - ) - self.assertEqual( - model_ref.get_model_description(), "My model is better than talkgpt-5!" # type: ignore[attr-defined] - ) - - # Test loading model - restored_model = model_ref.load_model() # type: ignore[attr-defined] - restored_prediction = restored_model.predict(test_features) - np.testing.assert_allclose(local_prediction, restored_prediction) - - # Test permanent deployment - permanent_deployment_name = f"{model_name}_{model_version}_perm_deploy" - deploy_info = model_ref.deploy( # type: ignore[attr-defined] - deployment_name=permanent_deployment_name, - target_method="predict", - permanent=True, - ) - self.assertEqual(deploy_info["details"], {}) - remote_prediction_perm = model_ref.predict(permanent_deployment_name, test_features) - np.testing.assert_allclose(remote_prediction_perm.to_numpy(), np.expand_dims(local_prediction, axis=1)) - - custom_permanent_deployment_name = f"{model_name}_{model_version}_custom_perm_deploy" - deploy_info = model_ref.deploy( # type: ignore[attr-defined] - deployment_name=custom_permanent_deployment_name, - target_method="predict_proba", - permanent=True, - options={"permanent_udf_stage_location": self.perm_stage}, - ) - self.assertEqual(deploy_info["details"], {}) - remote_prediction_proba_perm = model_ref.predict(custom_permanent_deployment_name, test_features) - 
np.testing.assert_allclose(remote_prediction_proba_perm.to_numpy(), local_prediction_proba) - - # Test deployment information - model_deployment_list = model_ref.list_deployments().to_pandas() # type: ignore[attr-defined] - self.assertEqual(model_deployment_list.shape[0], 2) - - filtered_model_deployment_list = model_deployment_list.loc[ - model_deployment_list["DEPLOYMENT_NAME"] == custom_permanent_deployment_name - ].reset_index(drop=True) - - self.assertEqual(filtered_model_deployment_list.shape[0], 1) - self.assertEqual(filtered_model_deployment_list["MODEL_NAME"][0], second=model_name) - self.assertEqual(filtered_model_deployment_list["MODEL_VERSION"][0], second=model_version) - self.assertEqual(filtered_model_deployment_list["STAGE_PATH"][0], second=self.perm_stage) - - self.assertEqual( - self._session.sql( - f"SHOW USER FUNCTIONS LIKE '%{custom_permanent_deployment_name}' IN DATABASE \"{self.registry_name}\";" - ).count(), - 1, - ) - - model_ref.delete_deployment(deployment_name=custom_permanent_deployment_name) # type: ignore[attr-defined] - - model_deployment_list = model_ref.list_deployments().to_pandas() # type: ignore[attr-defined] - self.assertEqual(model_deployment_list.shape[0], 1) - self.assertEqual(model_deployment_list["MODEL_NAME"][0], second=model_name) - self.assertEqual(model_deployment_list["MODEL_VERSION"][0], second=model_version) - self.assertEqual(model_deployment_list["DEPLOYMENT_NAME"][0], second=permanent_deployment_name) - - self.assertEqual( - self._session.sql( - f"SHOW USER FUNCTIONS LIKE '%{custom_permanent_deployment_name}' IN DATABASE \"{self.registry_name}\";" - ).count(), - 0, - ) - - # Test temp deployment - temp_deployment_name = f"{model_name}_{model_version}_temp_deploy" - model_ref.deploy( # type: ignore[attr-defined] - deployment_name=temp_deployment_name, - target_method="predict", - permanent=False, - ) - remote_prediction_temp = model_ref.predict(temp_deployment_name, test_features) - 
np.testing.assert_allclose(remote_prediction_temp.to_numpy(), np.expand_dims(local_prediction, axis=1)) - - model_history = model_ref.get_model_history().to_pandas() # type: ignore[attr-defined] - self.assertEqual(model_history.shape[0], 16) - - registry.delete_model(model_name=model_name, model_version=model_version, delete_artifact=True) - model_list = registry.list_models().to_pandas() - filtered_model_list = model_list.loc[model_list["ID"] == model_ref._id].reset_index(drop=True) - self.assertEqual(filtered_model_list.shape[0], 0) - - @parameterized.parameters( - model_factory.ModelFactory.prepare_snowml_model_gmm, model_factory.ModelFactory.prepare_snowml_model_xgb - ) - def test_snowml_model(self, model_prepare_callable: callable) -> None: - registry = model_registry.ModelRegistry(session=self._session, database_name=self.registry_name) - - model_name = "snowml_xgb_classifier" - model_version = self.run_id - model, test_features, _ = model_prepare_callable() - - local_prediction = model.predict(test_features) - local_prediction_proba = model.predict_proba(test_features) - - registry.log_model( - model_name=model_name, - model_version=model_version, - model=model, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server( - self._session, "snowflake-snowpark-python!=1.12.0" - ) - ], - options={"embed_local_ml_library": True}, - ) - - model_ref = model_registry.ModelReference(registry=registry, model_name=model_name, model_version=model_version) - - restored_model = model_ref.load_model() # type: ignore[attr-defined] - restored_prediction = restored_model.predict(test_features) - pd.testing.assert_frame_equal(local_prediction, restored_prediction) - - temp_predict_deployment_name = f"{model_name}_{model_version}_predict_temp_deploy" - deploy_info = model_ref.deploy( # type: ignore[attr-defined] - deployment_name=temp_predict_deployment_name, - target_method="predict", - permanent=False, - ) - self.assertEqual(deploy_info["details"], {}) 
- remote_prediction_temp = model_ref.predict(temp_predict_deployment_name, test_features) - - # TODO: Remove check_dtype=False after SNOW-853634 gets fixed. - pd.testing.assert_frame_equal(remote_prediction_temp, local_prediction, check_dtype=False) - - temp_predict_proba_deployment_name = f"{model_name}_{model_version}_predict_proba_temp_deploy" - model_ref.deploy( # type: ignore[attr-defined] - deployment_name=temp_predict_proba_deployment_name, - target_method="predict_proba", - permanent=False, - ) - remote_prediction_proba_temp = model_ref.predict(temp_predict_proba_deployment_name, test_features) - # TODO: Remove check_dtype=False after SNOW-853634 gets fixed. - pd.testing.assert_frame_equal(remote_prediction_proba_temp, local_prediction_proba, check_dtype=False) - - registry.delete_model(model_name=model_name, model_version=model_version, delete_artifact=True) - - def test_snowml_pipeline(self) -> None: - registry = model_registry.ModelRegistry(session=self._session, database_name=self.registry_name) - - model_name = "snowml_pipeline" - model_version = self.run_id - model, test_features = model_factory.ModelFactory.prepare_snowml_pipeline(self._session) - - local_prediction = model.predict(test_features) - - registry.log_model( - model_name=model_name, - model_version=model_version, - model=model, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server( - self._session, "snowflake-snowpark-python!=1.12.0" - ) - ], - options={"embed_local_ml_library": True}, - ) - - model_ref = model_registry.ModelReference(registry=registry, model_name=model_name, model_version=model_version) - - restored_model = model_ref.load_model() # type: ignore[attr-defined] - restored_prediction = restored_model.predict(test_features) - pd.testing.assert_frame_equal(local_prediction.to_pandas(), restored_prediction.to_pandas()) - - temp_predict_deployment_name = f"{model_name}_{model_version}_predict_temp_deploy" - deploy_info = model_ref.deploy( # type: 
ignore[attr-defined] - deployment_name=temp_predict_deployment_name, - target_method="predict", - permanent=False, - ) - self.assertEqual(deploy_info["details"], {}) - remote_prediction_temp = model_ref.predict(temp_predict_deployment_name, test_features.to_pandas()) - # TODO: Remove .astype(dtype={"OUTPUT_TARGET": np.float64} after SNOW-853638 gets fixed. - pd.testing.assert_frame_equal( - remote_prediction_temp, - local_prediction.to_pandas().astype(dtype={"OUTPUT_TARGET": np.float64}), - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_schema_evolution_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_schema_evolution_integ_test.py deleted file mode 100644 index c4f316c6..00000000 --- a/tests/integ/snowflake/ml/registry/model_registry_schema_evolution_integ_test.py +++ /dev/null @@ -1,335 +0,0 @@ -import uuid -from typing import Any, Dict - -from absl.testing import absltest - -from snowflake.ml._internal.utils import identifier -from snowflake.ml.registry import ( - _initial_schema, - _schema, - _schema_upgrade_plans, - _schema_version_manager, - model_registry, -) -from snowflake.ml.utils import connection_params -from snowflake.snowpark import Session -from tests.integ.snowflake.ml.test_utils import ( - db_manager, - model_factory, - test_env_utils, -) - - -class UpgradePlan_0(_schema_upgrade_plans.BaseSchemaUpgradePlans): - def __init__( - self, - session: Session, - database_name: str, - schema_name: str, - statement_params: Dict[str, Any], - ) -> None: - super().__init__(session, database_name, schema_name, statement_params) - - def upgrade(self) -> None: - self._session.sql( - f"""ALTER TABLE {self._database}.{self._schema}._SYSTEM_REGISTRY_MODELS - RENAME COLUMN CREATION_CONTEXT TO CREATION_CONTEXT_ABC - """ - ).collect() - - -class UpgradePlan_1(_schema_upgrade_plans.BaseSchemaUpgradePlans): - def __init__( - self, - session: Session, - database_name: str, - schema_name: 
str, - statement_params: Dict[str, Any], - ) -> None: - super().__init__(session, database_name, schema_name, statement_params) - - def upgrade(self) -> None: - self._session.sql( - f"""ALTER TABLE {self._database}.{self._schema}._SYSTEM_REGISTRY_MODELS - RENAME COLUMN CREATION_CONTEXT_ABC TO CREATION_CONTEXT - """ - ).collect() - - -class ModelRegistrySchemaEvolutionIntegTest(absltest.TestCase): - @classmethod - def setUpClass(cls) -> None: - """Creates Snowpark and Snowflake environments for testing.""" - cls.session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - cls.run_id = uuid.uuid4().hex - cls.active_dbs = [] - cls.db_manager = db_manager.DBManager(cls.session) - - @classmethod - def tearDownClass(cls) -> None: - for db in cls.active_dbs: - cls.db_manager.drop_database(db) - cls.session.close() - - def setUp(self) -> None: - self.original_registry_schema = _schema._REGISTRY_TABLE_SCHEMA.copy() - self.original_schema_version = _schema._CURRENT_SCHEMA_VERSION - self.original_schema_upgrade_plans = _schema._SCHEMA_UPGRADE_PLANS.copy() - - def tearDown(self) -> None: - _schema._CURRENT_SCHEMA_VERSION = self.original_schema_version - _schema._SCHEMA_UPGRADE_PLANS = self.original_schema_upgrade_plans - _schema._REGISTRY_TABLE_SCHEMA = self.original_registry_schema - _schema._CURRENT_TABLE_SCHEMAS[_initial_schema._MODELS_TABLE_NAME] = _schema._REGISTRY_TABLE_SCHEMA - - def _check_version_table_exist(self, registry_name: str, schema_name: str) -> bool: - result = self.session.sql( - f"""SHOW TABLES LIKE '{_schema_version_manager._SCHEMA_VERSION_TABLE_NAME}' - IN "{registry_name}"."{schema_name}" - """ - ).collect() - return len(result) == 1 - - def _get_schema_version(self, registry_name: str, schema_name: str) -> int: - infer_db_name = identifier.get_inferred_name(registry_name) - infer_schema_name = identifier.get_inferred_name(schema_name) - full_table_name = 
f"{infer_db_name}.{infer_schema_name}.{_schema_version_manager._SCHEMA_VERSION_TABLE_NAME}" - result = self.session.sql(f"SELECT MAX(VERSION) AS MAX_VERSION FROM {full_table_name}").collect() - return result[0]["MAX_VERSION"] - - def _update_package_schema( - self, new_version: int, plan: _schema_upgrade_plans.BaseSchemaUpgradePlans, from_col: str, to_col: str - ): - _schema._CURRENT_SCHEMA_VERSION = new_version - _schema._SCHEMA_UPGRADE_PLANS[new_version] = plan # type: ignore[assignment] - for i, (col_name, _) in enumerate(_schema._REGISTRY_TABLE_SCHEMA): - if col_name == from_col: - _schema._REGISTRY_TABLE_SCHEMA[i] = (to_col, "VARCHAR") - _schema._CURRENT_TABLE_SCHEMAS[_initial_schema._MODELS_TABLE_NAME] = _schema._REGISTRY_TABLE_SCHEMA - - def test_svm_upgrade_deployed_schema(self) -> None: - registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "test_svm_upgrade_deployed_schema" - ) - schema_name = "SVM_TEST_SCHEMA" - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - self.active_dbs.append(registry_name) - - # No schema upgrade. svm.try_upgrade() is no-op - svm = _schema_version_manager.SchemaVersionManager( - self.session, identifier.get_inferred_name(registry_name), identifier.get_inferred_name(schema_name) - ) - - cur_version = _schema._CURRENT_SCHEMA_VERSION - self.assertEqual(svm.get_deployed_version(), cur_version) - svm.validate_schema_version() - svm.try_upgrade() - - # first upgrade. 
rename column from "CREATION_CONTEXT" to "CREATION_CONTEXT_ABC" - self._update_package_schema(cur_version + 1, UpgradePlan_0, "CREATION_CONTEXT", "CREATION_CONTEXT_ABC") - - self.assertEqual(svm.get_deployed_version(), cur_version) - with self.assertRaisesRegex(RuntimeError, "Registry schema version .* is ahead of deployed"): - svm.validate_schema_version() - svm.try_upgrade() - self.assertEqual(svm.get_deployed_version(), cur_version + 1) - svm.validate_schema_version() - df = self.session.sql(f"""SELECT * FROM "{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS""") - self.assertTrue("CREATION_CONTEXT_ABC" in df.columns) - - # Second upgrade schema: rename column "CREATION_CONTEXT_ABC" back to "CREATION_CONTEXT" - self._update_package_schema(cur_version + 2, UpgradePlan_1, "CREATION_CONTEXT_ABC", "CREATION_CONTEXT") - - self.assertEqual(svm.get_deployed_version(), cur_version + 1) - with self.assertRaisesRegex(RuntimeError, "Registry schema version .* is ahead of deployed"): - svm.validate_schema_version() - svm.try_upgrade() - self.assertEqual(svm.get_deployed_version(), cur_version + 2) - svm.validate_schema_version() - df = self.session.sql(f"""SELECT * FROM "{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS""") - self.assertTrue("CREATION_CONTEXT" in df.columns) - - def test_svm_upgrade_package_schema(self) -> None: - registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "test_svm_upgrade_package_schema" - ) - schema_name = "SVM_TEST_SCHEMA" - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - self.active_dbs.append(registry_name) - - # Upgrade deployed schema to a newer version - cur_version = _schema._CURRENT_SCHEMA_VERSION - self._update_package_schema(cur_version + 1, UpgradePlan_0, "CREATION_CONTEXT", "CREATION_CONTEXT_ABC") - - svm = _schema_version_manager.SchemaVersionManager( - self.session, identifier.get_inferred_name(registry_name), 
identifier.get_inferred_name(schema_name) - ) - svm.try_upgrade() - self.assertEqual(svm.get_deployed_version(), cur_version + 1) - svm.validate_schema_version() - - # Then downgrade package schema and check version should fail. - _schema._CURRENT_SCHEMA_VERSION = cur_version - with self.assertRaisesRegex(RuntimeError, "Deployed registry schema version .* is ahead of current package"): - svm.validate_schema_version() - - def test_model_registry_upgrade_deployed_schema(self) -> None: - registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "test_model_registry_upgrade_deployed_schema" - ) - schema_name = "SVM_TEST_SCHEMA" - - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - model_registry.ModelRegistry(session=self.session, database_name=registry_name, schema_name=schema_name) - self.active_dbs.append(registry_name) - - # Upgrade schema: rename column "CREATION_CONTEXT" to "CREATION_CONTEXT_ABC" - cur_version = _schema._CURRENT_SCHEMA_VERSION - self._update_package_schema(cur_version + 1, UpgradePlan_0, "CREATION_CONTEXT", "CREATION_CONTEXT_ABC") - - # model registry will create version table, and update deployed schema to version 1. 
- with self.assertRaisesRegex(RuntimeError, "Registry schema version .* is ahead of deployed"): - model_registry.ModelRegistry(session=self.session, database_name=registry_name, schema_name=schema_name) - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - - self.assertTrue(self._check_version_table_exist(registry_name, schema_name)) - self.assertEqual(self._get_schema_version(registry_name, schema_name), cur_version + 1) - df = self.session.sql(f"""SELECT * FROM "{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS""") - self.assertTrue("CREATION_CONTEXT_ABC" in df.columns) - self.assertFalse("CREATION_CONTEXT" in df.columns) - - # second upgrade: rename column "CREATION_CONTEXT_ABC" to "CREATION_CONTEXT" - self._update_package_schema(cur_version + 2, UpgradePlan_1, "CREATION_CONTEXT_ABC", "CREATION_CONTEXT") - - # model registry will update deployed schema to version 2. - with self.assertRaisesRegex(RuntimeError, "Registry schema version .* is ahead of deployed"): - model_registry.ModelRegistry(session=self.session, database_name=registry_name, schema_name=schema_name) - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - - self.assertTrue(self._check_version_table_exist(registry_name, schema_name)) - self.assertEqual(self._get_schema_version(registry_name, schema_name), cur_version + 2) - df = self.session.sql(f"""SELECT * FROM "{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS""") - self.assertFalse("CREATION_CONTEXT_ABC" in df.columns) - self.assertTrue("CREATION_CONTEXT" in df.columns) - - def test_model_registry_upgrade_package_schema(self) -> None: - registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "test_model_registry_upgrade_package_schema" - ) - schema_name = "SVM_TEST_SCHEMA" - model_registry.create_model_registry(session=self.session, database_name=registry_name, 
schema_name=schema_name) - self.active_dbs.append(registry_name) - - # Upgrade deployed schema to a newer version - cur_version = _schema._CURRENT_SCHEMA_VERSION - self._update_package_schema(cur_version + 1, UpgradePlan_0, "CREATION_CONTEXT", "CREATION_CONTEXT_ABC") - - svm = _schema_version_manager.SchemaVersionManager( - self.session, identifier.get_inferred_name(registry_name), identifier.get_inferred_name(schema_name) - ) - - svm.try_upgrade() - self.assertEqual(svm.get_deployed_version(), cur_version + 1) - svm.validate_schema_version() - - # Then downgrade package schema, and ModelRegistry will panic - _schema._CURRENT_SCHEMA_VERSION = cur_version - with self.assertRaisesRegex(RuntimeError, "Deployed registry schema version .* is ahead of current package"): - model_registry.ModelRegistry(session=self.session, database_name=registry_name, schema_name=schema_name) - - with self.assertRaisesRegex(RuntimeError, "Deployed registry schema version .* is ahead of current package"): - model_registry.create_model_registry( - session=self.session, database_name=registry_name, schema_name=schema_name - ) - - def test_model_registry_creation_has_artifact_ids(self) -> None: - registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "test_model_registry_creation_has_artifact_ids" - ) - schema_name = "SVM_TEST_SCHEMA" - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - self.active_dbs.append(registry_name) - - model_registry.ModelRegistry(session=self.session, database_name=registry_name, schema_name=schema_name) - - df = self.session.sql(f"""SELECT * FROM "{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS""") - self.assertTrue("ARTIFACT_IDS" in df.columns) - self.assertTrue(self._get_schema_version(registry_name, schema_name) > 0) - - # downgrade deployed schema and delete training dataset id - self.session.sql( - f"""ALTER TABLE 
"{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS - DROP COLUMN ARTIFACT_IDS - """ - ).collect() - self.session.sql( - f"""DROP TABLE "{registry_name}"."{schema_name}".{_schema_version_manager._SCHEMA_VERSION_TABLE_NAME} - """ - ).collect() - df = self.session.sql(f"""SELECT * FROM "{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS""") - self.assertFalse("ARTIFACT_IDS" in df.columns) - self.assertFalse(self._check_version_table_exist(registry_name, schema_name)) - - # opening model registry will raise, re-create model registry will upgrade deployed schema - with self.assertRaisesRegex(RuntimeError, "Registry schema version .* is ahead of deployed"): - model_registry.ModelRegistry(session=self.session, database_name=registry_name, schema_name=schema_name) - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - - df = self.session.sql(f"""SELECT * FROM "{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_MODELS""") - self.assertTrue("ARTIFACT_IDS" in df.columns) - self.assertTrue(self._get_schema_version(registry_name, schema_name) > 0) - - def test_api_schema_validation(self) -> None: - registry_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "test_api_schema_validation" - ) - schema_name = "SVM_TEST_SCHEMA" - model_registry.create_model_registry(session=self.session, database_name=registry_name, schema_name=schema_name) - self.active_dbs.append(registry_name) - - registry = model_registry.ModelRegistry( - session=self.session, database_name=registry_name, schema_name=schema_name - ) - model, test_features, _ = model_factory.ModelFactory.prepare_snowml_model_xgb() - registry.log_model( - model_name="m", - model_version="v1", - model=model, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self.session, "snowflake-snowpark-python") - ], - ) - - # upgrade deployed schema - version_table_path = 
f'"{registry_name}"."{schema_name}"._SYSTEM_REGISTRY_SCHEMA_VERSION' - self.session.sql( - f"""INSERT INTO {version_table_path} (VERSION, CREATION_TIME) - VALUES ({_schema._CURRENT_SCHEMA_VERSION + 1}, CURRENT_TIMESTAMP()) - """ - ).collect() - - with self.assertRaisesRegex(RuntimeError, "Deployed registry schema version .* is ahead of current package"): - registry.log_model( - model_name="m", - model_version="v2", - model=model, - conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self.session, "snowflake-snowpark-python") - ], - ) - - model_ref = model_registry.ModelReference(registry=registry, model_name="m", model_version="v1") - - with self.assertRaisesRegex(RuntimeError, "Deployed registry schema version .* is ahead of current package"): - model_ref.deploy( # type: ignore[attr-defined] - deployment_name="test_api_schema_validation", - target_method="predict", - permanent=False, - ) - - with self.assertRaisesRegex(RuntimeError, "Deployed registry schema version .* is ahead of current package"): - model_ref.predict("test_api_schema_validation", test_features) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_snowservice_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_snowservice_integ_test.py deleted file mode 100644 index fa710a8a..00000000 --- a/tests/integ/snowflake/ml/registry/model_registry_snowservice_integ_test.py +++ /dev/null @@ -1,146 +0,0 @@ -# TODO[shchen], SNOW-889081, re-enable once server-side image build is supported. 
-# -# import functools -# import tempfile -# import uuid -# -# import numpy as np -# import pandas as pd -# import torch -from absl.testing import absltest - -# -# from snowflake.ml.model import deploy_platforms -# from snowflake.ml.model._signatures import pytorch_handler, tensorflow_handler -# from tests.integ.snowflake.ml.registry.model_registry_integ_test_snowservice_base import ( -# TestModelRegistryIntegSnowServiceBase, -# ) -# from tests.integ.snowflake.ml.test_utils import model_factory -# -# -# class TestModelRegistryIntegWithSnowServiceDeployment(TestModelRegistryIntegSnowServiceBase): -# -# def test_sklearn_deployment_with_snowml_conda(self) -> None: -# self._test_snowservice_deployment( -# model_name="test_sklearn_model_with_snowml_conda", -# model_version=uuid.uuid4().hex, -# prepare_model_and_feature_fn=model_factory.ModelFactory.prepare_sklearn_model, -# embed_local_ml_library=False, -# conda_dependencies=["snowflake-ml-python==1.0.2"], -# prediction_assert_fn=lambda local_prediction, remote_prediction: np.testing.assert_allclose( -# remote_prediction.to_numpy(), np.expand_dims(local_prediction, axis=1) -# ), -# deployment_options={ -# "platform": deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, -# "target_method": "predict", -# "options": { -# "compute_pool": self._TEST_CPU_COMPUTE_POOL, -# "image_repo": self._db_manager.get_snowservice_image_repo(repo=self._TEST_IMAGE_REPO), -# "num_workers": 1, -# "external_access_integrations": self._SPCS_EAIS, -# }, -# }, -# ) -# -# -# def test_sklearn_deployment_with_local_source_code(self) -> None: -# self._test_snowservice_deployment( -# model_name="test_sklearn_model_with_local_source_code", -# model_version=uuid.uuid4().hex, -# prepare_model_and_feature_fn=model_factory.ModelFactory.prepare_sklearn_model, -# prediction_assert_fn=lambda local_prediction, remote_prediction: np.testing.assert_allclose( -# remote_prediction.to_numpy(), np.expand_dims(local_prediction, axis=1) -# ), -# 
deployment_options={ -# "platform": deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, -# "target_method": "predict", -# "options": { -# "compute_pool": self._TEST_CPU_COMPUTE_POOL, -# "image_repo": self._db_manager.get_snowservice_image_repo(repo=self._TEST_IMAGE_REPO), -# "external_access_integrations": self._SPCS_EAIS, -# }, -# }, -# ) -# -# -# def test_huggingface_custom_model_deployment(self) -> None: -# with tempfile.TemporaryDirectory() as tmpdir: -# self._test_snowservice_deployment( -# model_name="gpt2_model_gpu", -# model_version=uuid.uuid4().hex, -# conda_dependencies=["pytorch", "transformers"], -# prepare_model_and_feature_fn=functools.partial( -# model_factory.ModelFactory.prepare_gpt2_model, -# local_cache_dir=tmpdir, -# ), -# prediction_assert_fn=lambda local_prediction, remote_prediction: pd.testing.assert_frame_equal( -# remote_prediction, local_prediction, check_dtype=False -# ), -# deployment_options={ -# "platform": deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, -# "target_method": "predict", -# "options": { -# "compute_pool": self._TEST_CPU_COMPUTE_POOL, -# "image_repo": self._db_manager.get_snowservice_image_repo(repo=self._TEST_IMAGE_REPO), -# "num_workers": 1, -# "external_access_integrations": self._SPCS_EAIS, -# }, -# }, -# ) -# -# -# def test_torch_model_deployment_with_gpu(self) -> None: -# self._test_snowservice_deployment( -# model_name="torch_model", -# model_version=uuid.uuid4().hex, -# prepare_model_and_feature_fn=functools.partial( -# model_factory.ModelFactory.prepare_torch_model, force_remote_gpu_inference=True -# ), -# conda_dependencies=[ -# "pytorch-nightly::pytorch", -# "pytorch-nightly::pytorch-cuda==12.1", -# "nvidia::cuda==12.1.*", -# ], -# prediction_assert_fn=lambda local_prediction, remote_prediction: torch.testing.assert_close( -# pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(remote_prediction)[0], -# local_prediction[0], -# check_dtype=False, -# ), -# deployment_options={ -# 
"platform": deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, -# "target_method": "forward", -# "options": { -# "compute_pool": self._TEST_GPU_COMPUTE_POOL, -# "image_repo": self._db_manager.get_snowservice_image_repo(repo=self._TEST_IMAGE_REPO), -# "num_workers": 1, -# "use_gpu": True, -# "external_access_integrations": self._SPCS_EAIS, -# }, -# }, -# ) -# -# -# def test_keras_model_deployment(self) -> None: -# self._test_snowservice_deployment( -# model_name="keras_model", -# model_version=uuid.uuid4().hex, -# prepare_model_and_feature_fn=model_factory.ModelFactory.prepare_keras_model, -# prediction_assert_fn=lambda local_prediction, remote_prediction: np.testing.assert_allclose( -# tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(remote_prediction)[0].numpy(), -# local_prediction[0], -# atol=1e-6, -# ), -# deployment_options={ -# "platform": deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, -# "target_method": "predict", -# "options": { -# "compute_pool": self._TEST_CPU_COMPUTE_POOL, -# "image_repo": self._db_manager.get_snowservice_image_repo(repo=self._TEST_IMAGE_REPO), -# "external_access_integrations": self._SPCS_EAIS, -# }, -# }, -# ) -# -# -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_snowservice_integ_test_base.py b/tests/integ/snowflake/ml/registry/model_registry_snowservice_integ_test_base.py deleted file mode 100644 index e1812eb8..00000000 --- a/tests/integ/snowflake/ml/registry/model_registry_snowservice_integ_test_base.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import pandas as pd -import yaml -from absl.testing import absltest - -from snowflake.ml.model import model_signature -from snowflake.ml.registry import model_registry -from snowflake.snowpark import DataFrame as SnowparkDataFrame -from tests.integ.snowflake.ml.test_utils import ( - model_factory, - spcs_integ_test_base, - 
test_env_utils, -) - - -def is_valid_yaml(yaml_string) -> bool: - try: - yaml.safe_load(yaml_string) - return True - except yaml.YAMLError: - return False - - -class TestModelRegistryIntegSnowServiceBase(spcs_integ_test_base.SpcsIntegTestBase): - def setUp(self) -> None: - super().setUp() - model_registry.create_model_registry( - session=self._session, database_name=self._test_db, schema_name=self._test_schema - ) - self.registry = model_registry.ModelRegistry( - session=self._session, database_name=self._test_db, schema_name=self._test_schema - ) - - def tearDown(self) -> None: - super().tearDown() - - def _test_snowservice_deployment( - self, - model_name: str, - model_version: str, - prepare_model_and_feature_fn: Callable[[], Tuple[Any, Any, Any]], - deployment_options: Dict[str, Any], - prediction_assert_fn: Callable[[Any, Union[pd.DataFrame, SnowparkDataFrame]], Any], - pip_requirements: Optional[List[str]] = None, - conda_dependencies: Optional[List[str]] = None, - embed_local_ml_library: Optional[bool] = True, - omit_target_method_when_deploy: bool = False, - ) -> None: - model, test_features, *_ = prepare_model_and_feature_fn() - if omit_target_method_when_deploy: - target_method = deployment_options.pop("target_method") - else: - target_method = deployment_options["target_method"] - - if hasattr(model, "predict_with_device"): - local_prediction = model.predict_with_device(test_features, model_factory.DEVICE.CPU) - else: - local_prediction = getattr(model, target_method)(test_features) - - # In test, latest snowpark version might not be in conda channel yet, which can cause image build to fail. - # Instead we rely on snowpark version on information.schema table. Note that this will not affect end user - # as by the time they use it, the latest snowpark should be available in conda already. 
- conda_dependencies = conda_dependencies or [] - conda_dependencies.append(test_env_utils.get_latest_package_version_spec_in_conda("snowflake-snowpark-python")) - - self.registry.log_model( - model_name=model_name, - model_version=model_version, - model=model, - conda_dependencies=conda_dependencies, - pip_requirements=pip_requirements, - signatures={target_method: model_signature.infer_signature(test_features, local_prediction)}, - options={"embed_local_ml_library": embed_local_ml_library}, - ) - - model_ref = model_registry.ModelReference( - registry=self.registry, model_name=model_name, model_version=model_version - ) - - deployment_name = f"{model_name}_{model_version}_deployment" - deployment_options["deployment_name"] = deployment_name - deploy_info = model_ref.deploy(**deployment_options) # type: ignore[attr-defined] - deploy_details = deploy_info["details"] - self.assertNotEmpty(deploy_details) - self.assertTrue(deploy_details["service_info"]) - self.assertTrue(deploy_details["service_function_sql"]) - - remote_prediction = model_ref.predict(deployment_name, test_features) - prediction_assert_fn(local_prediction, remote_prediction) - - model_deployment_list = model_ref.list_deployments().to_pandas() # type: ignore[attr-defined] - self.assertEqual(model_deployment_list.shape[0], 1) - self.assertEqual(model_deployment_list["MODEL_NAME"][0], model_name) - self.assertEqual(model_deployment_list["MODEL_VERSION"][0], model_version) - self.assertEqual(model_deployment_list["DEPLOYMENT_NAME"][0], deployment_name) - - deployment = self.registry._get_deployment( - model_name=model_name, model_version=model_version, deployment_name=deployment_name - ) - service_name = f"service_{deployment['MODEL_ID']}" - model_ref.delete_deployment(deployment_name=deployment_name) # type: ignore[attr-defined] - self.assertEqual(model_ref.list_deployments().to_pandas().shape[0], 0) # type: ignore[attr-defined] - - service_lst = self._session.sql(f"SHOW SERVICES LIKE '{service_name}' 
in account;").collect() - self.assertEqual(len(service_lst), 0, "Service was not deleted successfully") - self.assertEqual(self.registry.list_models().to_pandas().shape[0], 1) - self.registry.delete_model(model_name=model_name, model_version=model_version, delete_artifact=True) - self.assertEqual(self.registry.list_models().to_pandas().shape[0], 0) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_snowservice_merge_gate_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_snowservice_merge_gate_integ_test.py deleted file mode 100644 index 21d68f05..00000000 --- a/tests/integ/snowflake/ml/registry/model_registry_snowservice_merge_gate_integ_test.py +++ /dev/null @@ -1,70 +0,0 @@ -import re -import uuid - -import pandas as pd -from absl.testing import absltest - -from snowflake.ml.model import deploy_platforms -from tests.integ.snowflake.ml.registry.model_registry_snowservice_integ_test_base import ( - TestModelRegistryIntegSnowServiceBase, -) -from tests.integ.snowflake.ml.test_utils import model_factory - - -class TestModelRegistryIntegWithSnowServiceDeployment(TestModelRegistryIntegSnowServiceBase): - def test_snowml_model_deployment_xgboost(self) -> None: - def _run_deployment() -> None: - self._test_snowservice_deployment( - model_name="xgboost_model", - model_version=uuid.uuid4().hex, - prepare_model_and_feature_fn=model_factory.ModelFactory.prepare_snowml_model_xgb, - prediction_assert_fn=lambda local_prediction, remote_prediction: pd.testing.assert_frame_equal( - remote_prediction, local_prediction, check_dtype=False - ), - deployment_options={ - "platform": deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, - "target_method": "predict", - "options": { - "compute_pool": self._TEST_CPU_COMPUTE_POOL, - "enable_remote_image_build": True, - "external_access_integrations": self._SPCS_EAIS, - }, - }, - ) - - # First deployment - _run_deployment() - - # Second deployment. 
Ensure image building is skipped due to similar environment. - with self.assertLogs(level="WARNING") as cm: - _run_deployment() - image_pattern = r"Using existing image .* to skip image build" - image_pattern_found = any(re.search(image_pattern, s, re.MULTILINE | re.DOTALL) for s in cm.output) - self.assertTrue(image_pattern_found, "Should skip image build on second deployment") - - def test_snowml_model_deployment_xgboost_with_model_in_image(self) -> None: - def _run_deployment() -> None: - self._test_snowservice_deployment( - model_name="xgboost_model", - model_version=uuid.uuid4().hex, - prepare_model_and_feature_fn=model_factory.ModelFactory.prepare_snowml_model_xgb, - prediction_assert_fn=lambda local_prediction, remote_prediction: pd.testing.assert_frame_equal( - remote_prediction, local_prediction, check_dtype=False - ), - deployment_options={ - "platform": deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, - "target_method": "predict", - "options": { - "compute_pool": self._TEST_CPU_COMPUTE_POOL, - "enable_remote_image_build": True, - "model_in_image": True, - "external_access_integrations": self._SPCS_EAIS, - }, - }, - ) - - _run_deployment() - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/integ/snowflake/ml/registry/services/BUILD.bazel b/tests/integ/snowflake/ml/registry/services/BUILD.bazel new file mode 100644 index 00000000..c724520f --- /dev/null +++ b/tests/integ/snowflake/ml/registry/services/BUILD.bazel @@ -0,0 +1,62 @@ +load("@rules_python//python:defs.bzl", "py_library") +load("//bazel:py_rules.bzl", "py_test") + +package(default_visibility = ["//tests/integ/snowflake/ml:__subpackages__"]) + +py_library( + name = "registry_model_deployment_test_base", + testonly = True, + srcs = ["registry_model_deployment_test_base.py"], + deps = [ + "//snowflake/ml/_internal:file_utils", + "//snowflake/ml/_internal/utils:snowflake_env", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:type_hints", + 
"//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/registry", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:common_test_base", + "//tests/integ/snowflake/ml/test_utils:db_manager", + "//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) + +py_test( + name = "registry_xgboost_model_deployment_test", + timeout = "eternal", + srcs = ["registry_xgboost_model_deployment_test.py"], + shard_count = 4, + deps = [ + ":registry_model_deployment_test_base", + ], +) + +py_test( + name = "registry_sentence_transformers_model_deployment_test", + timeout = "eternal", + srcs = ["registry_sentence_transformers_model_deployment_test.py"], + shard_count = 4, + deps = [ + ":registry_model_deployment_test_base", + ], +) + +py_test( + name = "registry_huggingface_pipeline_model_deployment_test", + timeout = "eternal", + srcs = ["registry_huggingface_pipeline_model_deployment_test.py"], + shard_count = 4, + deps = [ + ":registry_model_deployment_test_base", + ], +) + +py_test( + name = "registry_sklearn_model_deployment_test", + timeout = "long", + srcs = ["registry_sklearn_model_deployment_test.py"], + shard_count = 2, + deps = [ + ":registry_model_deployment_test_base", + ], +) diff --git a/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_test.py new file mode 100644 index 00000000..bfe24acf --- /dev/null +++ b/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_test.py @@ -0,0 +1,73 @@ +import json +import os +import tempfile +from typing import List, Optional + +import pandas as pd +from absl.testing import absltest, parameterized + +from tests.integ.snowflake.ml.registry.services import ( + registry_model_deployment_test_base, +) + + +class TestRegistryHuggingFacePipelineDeploymentModelInteg( + 
registry_model_deployment_test_base.RegistryModelDeploymentTestBase +): + @classmethod + def setUpClass(self) -> None: + self.cache_dir = tempfile.TemporaryDirectory() + self._original_cache_dir = os.getenv("TRANSFORMERS_CACHE", None) + os.environ["TRANSFORMERS_CACHE"] = self.cache_dir.name + + @classmethod + def tearDownClass(self) -> None: + if self._original_cache_dir: + os.environ["TRANSFORMERS_CACHE"] = self._original_cache_dir + self.cache_dir.cleanup() + + @parameterized.product( # type: ignore[misc] + gpu_requests=[None, "1"], + pip_requirements=[None, ["transformers"]], + ) + def test_text_generation( + self, + gpu_requests: str, + pip_requirements: Optional[List[str]], + ) -> None: + import transformers + + model = transformers.pipeline( + task="text-generation", + model="openai-community/gpt2", + ) + + x_df = pd.DataFrame( + [['A descendant of the Lost City of Atlantis, who swam to Earth while saying, "']], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) + + for row in res["outputs"]: + self.assertIsInstance(row, str) + resp = json.loads(row) + self.assertIsInstance(resp, list) + self.assertIn("generated_text", resp[0]) + + self._test_registry_model_deployment( + model=model, + prediction_assert_fns={ + "__call__": ( + x_df, + check_res, + ), + }, + options={"cuda_version": "11.8"} if gpu_requests else {}, + gpu_requests=gpu_requests, + pip_requirements=pip_requirements, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py new file mode 100644 index 00000000..c766a939 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py @@ -0,0 +1,195 @@ +import inspect +import os +import pathlib +import uuid +from typing import Any, Callable, Dict, List, Optional, Tuple + 
+import pytest +import yaml +from absl.testing import absltest + +from snowflake.ml._internal import file_utils +from snowflake.ml._internal.utils import snowflake_env, sql_identifier +from snowflake.ml.model import ModelVersion, type_hints as model_types +from snowflake.ml.model._client.service import model_deployment_spec +from snowflake.ml.registry import registry +from snowflake.snowpark._internal import utils as snowpark_utils +from tests.integ.snowflake.ml.test_utils import ( + common_test_base, + db_manager, + test_env_utils, +) + + +@pytest.mark.spcs_deployment_image +@absltest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "SPCS only available in AWS", +) +class RegistryModelDeploymentTestBase(common_test_base.CommonTestBase): + _TEST_CPU_COMPUTE_POOL = "REGTEST_INFERENCE_CPU_POOL" + _TEST_GPU_COMPUTE_POOL = "REGTEST_INFERENCE_GPU_POOL" + _SPCS_EAI = "SPCS_EGRESS_ACCESS_INTEGRATION" + _TEST_SPCS_WH = "REGTEST_ML_SMALL" + + BUILDER_IMAGE_PATH = os.getenv("BUILDER_IMAGE_PATH", None) + BASE_CPU_IMAGE_PATH = os.getenv("BASE_CPU_IMAGE_PATH", None) + BASE_GPU_IMAGE_PATH = os.getenv("BASE_GPU_IMAGE_PATH", None) + + def setUp(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + super().setUp() + + self._run_id = uuid.uuid4().hex[:2] + self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() + self._test_schema = "PUBLIC" + self._test_image_repo = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( + self._run_id, "image_repo" + ).upper() + + self.session.sql(f"USE WAREHOUSE {self._TEST_SPCS_WH}").collect() + + self._db_manager = db_manager.DBManager(self.session) + self._db_manager.create_database(self._test_db) + self._db_manager.create_image_repo(self._test_image_repo) + self._db_manager.cleanup_databases(expire_hours=6) + self.registry = registry.Registry(self.session) + + def tearDown(self) -> None: + 
self._db_manager.drop_database(self._test_db) + super().tearDown() + + def _deploy_model_with_image_override( + self, + mv: ModelVersion, + service_name: str, + gpu_requests: Optional[str] = None, + ) -> None: + """Deploy model with image override.""" + is_gpu = gpu_requests is not None + image_path = self.BASE_GPU_IMAGE_PATH if is_gpu else self.BASE_CPU_IMAGE_PATH + build_compute_pool = sql_identifier.SqlIdentifier(self._TEST_CPU_COMPUTE_POOL) + service_compute_pool = sql_identifier.SqlIdentifier( + self._TEST_GPU_COMPUTE_POOL if is_gpu else self._TEST_CPU_COMPUTE_POOL + ) + + # create a temp stage + database_name = sql_identifier.SqlIdentifier(self._test_db) + schema_name = sql_identifier.SqlIdentifier(self._test_schema) + stage_name = sql_identifier.SqlIdentifier( + snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE) + ) + image_repo_name = sql_identifier.SqlIdentifier(self._test_image_repo) + + mv._service_ops._stage_client.create_tmp_stage( + database_name=database_name, schema_name=schema_name, stage_name=stage_name + ) + stage_path = mv._service_ops._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name) + + deploy_spec_file_rel_path = model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH + + mv._service_ops._model_deployment_spec.save( + database_name=database_name, + schema_name=schema_name, + model_name=mv._model_name, + version_name=mv._version_name, + service_database_name=database_name, + service_schema_name=schema_name, + service_name=sql_identifier.SqlIdentifier(service_name), + image_build_compute_pool_name=build_compute_pool, + service_compute_pool_name=service_compute_pool, + image_repo_database_name=database_name, + image_repo_schema_name=schema_name, + image_repo_name=image_repo_name, + ingress_enabled=False, + max_instances=1, + num_workers=None, + max_batch_rows=None, + gpu=gpu_requests, + force_rebuild=True, + 
external_access_integration=sql_identifier.SqlIdentifier(self._SPCS_EAI), + ) + + with (mv._service_ops.workspace_path / deploy_spec_file_rel_path).open("r", encoding="utf-8") as f: + deploy_spec_dict = yaml.safe_load(f) + + deploy_spec_dict["image_build"]["builder_image"] = self.BUILDER_IMAGE_PATH + deploy_spec_dict["image_build"]["base_image"] = image_path + + with (mv._service_ops.workspace_path / deploy_spec_file_rel_path).open("w", encoding="utf-8") as f: + yaml.dump(deploy_spec_dict, f) + + file_utils.upload_directory_to_stage( + self.session, + local_path=mv._service_ops.workspace_path, + stage_path=pathlib.PurePosixPath(stage_path), + ) + + # deploy the model service + mv._service_ops._service_client.deploy_model( + stage_path=stage_path, model_deployment_spec_file_rel_path=deploy_spec_file_rel_path + ) + + def _test_registry_model_deployment( + self, + model: model_types.SupportedModelType, + prediction_assert_fns: Dict[str, Tuple[Any, Callable[[Any], Any]]], + sample_input_data: Optional[model_types.SupportedDataType] = None, + additional_dependencies: Optional[List[str]] = None, + pip_requirements: Optional[List[str]] = None, + options: Optional[model_types.ModelSaveOption] = None, + gpu_requests: Optional[str] = None, + ) -> None: + if self.BUILDER_IMAGE_PATH and self.BASE_CPU_IMAGE_PATH and self.BASE_GPU_IMAGE_PATH: + with_image_override = True + elif not self.BUILDER_IMAGE_PATH and not self.BASE_CPU_IMAGE_PATH and not self.BASE_GPU_IMAGE_PATH: + with_image_override = False + else: + raise ValueError( + "Please set or unset BUILDER_IMAGE_PATH, BASE_CPU_IMAGE_PATH, and BASE_GPU_IMAGE_PATH at the same time." 
+ ) + + conda_dependencies = [ + test_env_utils.get_latest_package_version_spec_in_server( + self.session, "snowflake-snowpark-python!=1.12.0, <1.21.1" + ) + ] + if additional_dependencies: + conda_dependencies.extend(additional_dependencies) + + # Get the name of the caller as the model name + name = f"model_{inspect.stack()[1].function}" + version = f"ver_{self._run_id}" + mv = self.registry.log_model( + model=model, + model_name=name, + version_name=version, + sample_input_data=sample_input_data, + conda_dependencies=conda_dependencies, + pip_requirements=pip_requirements, + options=options, + ) + + service = f"service_{inspect.stack()[1].function}_{self._run_id}" + if with_image_override: + self._deploy_model_with_image_override( + mv, + service_name=service, + gpu_requests=gpu_requests, + ) + else: + mv.create_service( + service_name=service, + image_build_compute_pool=self._TEST_CPU_COMPUTE_POOL, + service_compute_pool=( + self._TEST_CPU_COMPUTE_POOL if gpu_requests is None else self._TEST_GPU_COMPUTE_POOL + ), + image_repo=self._test_image_repo, + gpu_requests=gpu_requests, + force_rebuild=True, + build_external_access_integration=self._SPCS_EAI, + ) + + for target_method, (test_input, check_func) in prediction_assert_fns.items(): + res = mv.run(test_input, function_name=target_method, service_name=service) + check_func(res) diff --git a/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_test.py new file mode 100644 index 00000000..a39c5a5e --- /dev/null +++ b/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_test.py @@ -0,0 +1,80 @@ +import os +import random +import tempfile +from typing import List, Optional + +import pandas as pd +from absl.testing import absltest, parameterized + +from tests.integ.snowflake.ml.registry.services import ( + registry_model_deployment_test_base, 
+) + +MODEL_NAMES = ["intfloat/e5-base-v2"] # cant load models in parallel +SENTENCE_TRANSFORMERS_CACHE_DIR = "SENTENCE_TRANSFORMERS_HOME" + + +class TestRegistrySentenceTransformerDeploymentModelInteg( + registry_model_deployment_test_base.RegistryModelDeploymentTestBase +): + @classmethod + def setUpClass(self) -> None: + self.cache_dir = tempfile.TemporaryDirectory() + self._original_cache_dir = os.getenv(SENTENCE_TRANSFORMERS_CACHE_DIR, None) + os.environ[SENTENCE_TRANSFORMERS_CACHE_DIR] = self.cache_dir.name + + @classmethod + def tearDownClass(self) -> None: + if self._original_cache_dir: + os.environ[SENTENCE_TRANSFORMERS_CACHE_DIR] = self._original_cache_dir + self.cache_dir.cleanup() + + @parameterized.product( # type: ignore[misc] + gpu_requests=[None, "1"], + pip_requirements=[None, ["sentence-transformers"]], + ) + def test_sentence_transformers( + self, + gpu_requests: str, + pip_requirements: Optional[List[str]], + ) -> None: + import sentence_transformers + + # Sample Data + sentences = pd.DataFrame( + { + "SENTENCES": [ + "Why don’t scientists trust atoms? Because they make up everything.", + "I told my wife she should embrace her mistakes. She gave me a hug.", + "Im reading a book on anti-gravity. Its impossible to put down!", + "Did you hear about the mathematician who’s afraid of negative numbers?", + "Parallel lines have so much in common. 
It’s a shame they’ll never meet.", + ] + } + ) + model = sentence_transformers.SentenceTransformer(random.choice(MODEL_NAMES)) + embeddings = pd.DataFrame(model.encode(sentences["SENTENCES"].to_list(), batch_size=sentences.shape[0])) + + self._test_registry_model_deployment( + model=model, + sample_input_data=sentences, + prediction_assert_fns={ + "encode": ( + sentences, + lambda res: pd.testing.assert_frame_equal( + pd.DataFrame(res["output_feature_0"].to_list()), + embeddings, + rtol=1e-2, + atol=1e-3, + check_dtype=False, + ), + ), + }, + options={"cuda_version": "11.8"} if gpu_requests else {}, + gpu_requests=gpu_requests, + pip_requirements=pip_requirements, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/services/registry_sklearn_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_sklearn_model_deployment_test.py new file mode 100644 index 00000000..9a258a5e --- /dev/null +++ b/tests/integ/snowflake/ml/registry/services/registry_sklearn_model_deployment_test.py @@ -0,0 +1,35 @@ +from typing import List, Optional + +import numpy as np +from absl.testing import absltest, parameterized +from sklearn import datasets, svm + +from tests.integ.snowflake.ml.registry.services import ( + registry_model_deployment_test_base, +) + + +class TestRegistrySklearnModelDeploymentInteg(registry_model_deployment_test_base.RegistryModelDeploymentTestBase): + @parameterized.parameters({"pip_requirements": None}, {"pip_requirements": ["scikit-learn"]}) # type: ignore[misc] + def test_sklearn(self, pip_requirements: Optional[List[str]]) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + svc = svm.LinearSVC() + svc.fit(iris_X, iris_y) + + self._test_registry_model_deployment( + model=svc, + sample_input_data=iris_X, + prediction_assert_fns={ + "predict": ( + iris_X, + lambda res: np.testing.assert_allclose( + res.values, np.expand_dims(svc.predict(iris_X), axis=1), rtol=1e-3 + ), + ), + }, 
+ pip_requirements=pip_requirements, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_test.py new file mode 100644 index 00000000..9519cb38 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_test.py @@ -0,0 +1,49 @@ +from typing import List + +import inflection +import numpy as np +import xgboost +from absl.testing import absltest, parameterized +from sklearn import datasets, model_selection + +from tests.integ.snowflake.ml.registry.services import ( + registry_model_deployment_test_base, +) + + +class TestRegistryXGBoostModelDeploymentInteg(registry_model_deployment_test_base.RegistryModelDeploymentTestBase): + @parameterized.product( # type: ignore[misc] + gpu_requests=[None, "1"], + pip_requirements=[None, ["xgboost"]], + ) + def test_xgb( + self, + gpu_requests: str, + pip_requirements: List[str], + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) + regressor.fit(cal_X_train, cal_y_train) + self._test_registry_model_deployment( + model=regressor, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + res.values, np.expand_dims(regressor.predict(cal_X_test), axis=1), rtol=1e-3 + ), + ), + }, + options={"cuda_version": "11.8"} if gpu_requests else {}, + gpu_requests=gpu_requests, + pip_requirements=pip_requirements, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/test_utils/BUILD.bazel 
b/tests/integ/snowflake/ml/test_utils/BUILD.bazel index 144b3e86..098eb2c6 100644 --- a/tests/integ/snowflake/ml/test_utils/BUILD.bazel +++ b/tests/integ/snowflake/ml/test_utils/BUILD.bazel @@ -26,7 +26,6 @@ py_library( srcs = ["db_manager.py"], deps = [ "//snowflake/ml/_internal/utils:identifier", - "//snowflake/ml/model/_deploy_client/utils:constants", "//snowflake/ml/utils:sql_client", ], ) diff --git a/tests/integ/snowflake/ml/test_utils/db_manager.py b/tests/integ/snowflake/ml/test_utils/db_manager.py index fa26dcd5..db72d289 100644 --- a/tests/integ/snowflake/ml/test_utils/db_manager.py +++ b/tests/integ/snowflake/ml/test_utils/db_manager.py @@ -4,7 +4,6 @@ from snowflake import snowpark from snowflake.ml._internal.utils import identifier -from snowflake.ml.model._deploy_client.utils import constants from snowflake.ml.utils import sql_client _COMMON_PREFIX = "snowml_test_" @@ -158,12 +157,11 @@ def create_stage( ).collect() return full_qual_stage_name - def show_stages( - self, - stage_name: str, + @staticmethod + def get_show_location_url( schema_name: Optional[str] = None, db_name: Optional[str] = None, - ) -> snowpark.DataFrame: + ) -> str: if schema_name: actual_schema_name = identifier.get_inferred_name(schema_name) if db_name: @@ -174,6 +172,15 @@ def show_stages( location_sql = f" IN SCHEMA {full_qual_schema_name}" else: location_sql = "" + return location_sql + + def show_stages( + self, + stage_name: str, + schema_name: Optional[str] = None, + db_name: Optional[str] = None, + ) -> snowpark.DataFrame: + location_sql = DBManager.get_show_location_url(schema_name, db_name) sql = f"SHOW STAGES LIKE '{stage_name}'{location_sql}" return self._session.sql(sql) @@ -216,16 +223,7 @@ def show_user_functions( schema_name: Optional[str] = None, db_name: Optional[str] = None, ) -> snowpark.DataFrame: - if schema_name: - actual_schema_name = identifier.get_inferred_name(schema_name) - if db_name: - actual_db_name = identifier.get_inferred_name(db_name) -
full_qual_schema_name = f"{actual_db_name}.{actual_schema_name}" - else: - full_qual_schema_name = actual_schema_name - location_sql = f" IN SCHEMA {full_qual_schema_name}" - else: - location_sql = "" + location_sql = DBManager.get_show_location_url(schema_name, db_name) sql = f"SHOW USER FUNCTIONS LIKE '{function_name}'{location_sql}" return self._session.sql(sql) @@ -271,18 +269,6 @@ def cleanup_user_functions( func_def = func_arguments.partition("RETURN")[0].strip() self.drop_function(function_def=func_def, schema_name=schema_name, db_name=db_name, if_exists=True) - def get_snowservice_image_repo( - self, - repo: str, - subdomain: str = constants.DEV_IMAGE_REGISTRY_SUBDOMAIN, - ) -> str: - conn = self._session._conn._conn - org = conn.host.split(".")[1] - account = conn.account - db = conn._database - schema = conn._schema - return f"{org}-{account}.{subdomain}.{constants.PROD_IMAGE_REGISTRY_DOMAIN}/{db}/{schema}/{repo}".lower() - def create_compute_pool( self, compute_pool_name: str,
diff --git a/tests/integ/snowflake/ml/test_utils/model_factory.py b/tests/integ/snowflake/ml/test_utils/model_factory.py
index 6763b01b..f059ee04 100644
--- a/tests/integ/snowflake/ml/test_utils/model_factory.py
+++ b/tests/integ/snowflake/ml/test_utils/model_factory.py
@@ -1,10 +1,12 @@
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
+import inflection
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
-from sklearn import datasets, svm
+import xgboost
+from sklearn import datasets, model_selection, svm
 
 from snowflake.ml.model import custom_model
 from snowflake.ml.modeling.linear_model import (  # type: ignore[attr-defined]
@@ -53,6 +55,24 @@ def one_vs_all(dataset: npt.NDArray[np.float64], digit: int) -> List[bool]:
 
         return clf, test_features, test_labels
 
+    @staticmethod
+    def prepare_xgboost_model() -> Tuple[xgboost.XGBRegressor, pd.DataFrame, pd.Series]:
+        """Prepare an XGBRegressor fitted on the scikit-learn breast cancer dataset.
+
+        Feature column names are rewritten with `inflection.parameterize(c, "_")` before splitting.
+
+        Returns:
+            The fitted regressor, the held-out features, and the held-out target (a pandas Series).
+        """
+        cal_data = datasets.load_breast_cancer(as_frame=True)
+        cal_X = cal_data.data
+        cal_y = cal_data.target
+        cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns]
+        cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y)
+        regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3)
+        regressor.fit(cal_X_train, cal_y_train)
+        return regressor, cal_X_test, cal_y_test
+
     @staticmethod
     def prepare_snowml_model_xgb() -> Tuple[XGBClassifier, pd.DataFrame, pd.DataFrame]:
         """Prepare SnowML XGBClassifier model.
diff --git a/tests/pytest.ini b/tests/pytest.ini index a8522207..364a19e1 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -20,3 +20,4 @@ markers = ; the SnowML Build & Test pipeline. They will still be tested in conda environment. pip_incompatible: mark a test as incompatible with pip environment. conda_incompatible: mark a test as incompatible with conda environment. + spcs_deployment_image: mark a test as requiring the SPCS deployment image.