From f50d041522ec34bdb85aef793cab17c4b41f30f4 Mon Sep 17 00:00:00 2001 From: Angel Antonio Avalos Cisneros Date: Thu, 12 Sep 2024 10:23:52 -0700 Subject: [PATCH] Project import generated by Copybara. (#116) GitOrigin-RevId: 6fc3ce416ce5843fc01936fc61bf5480ae9f791f Co-authored-by: Snowflake Authors --- BUILD.bazel | 2 + CHANGELOG.md | 23 +- CONTRIBUTING.md | 6 +- bazel/environments/conda-env-snowflake.yml | 5 +- bazel/environments/conda-env.yml | 12 +- bazel/environments/conda-gpu-env.yml | 12 +- bazel/requirements/requirements.schema.json | 5 - ci/conda_recipe/meta.yaml | 4 +- codegen/codegen_rules.bzl | 2 - codegen/sklearn_wrapper_generator.py | 9 +- codegen/sklearn_wrapper_template.py_template | 81 +++-- ...nsformer_autogen_test_template.py_template | 10 +- requirements.txt | 8 +- requirements.yml | 23 +- snowflake/cortex/BUILD.bazel | 40 --- snowflake/cortex/__init__.py | 4 - snowflake/cortex/_embed_text_1024.py | 42 --- snowflake/cortex/_embed_text_768.py | 43 --- snowflake/cortex/embed_text_1024_test.py | 65 ---- snowflake/cortex/embed_text_768_test.py | 65 ---- snowflake/cortex/package_visibility_test.py | 6 - .../ml/_internal/snowpark_pandas/patch.py | 2 +- snowflake/ml/_internal/telemetry.py | 162 +++++++-- snowflake/ml/_internal/telemetry_test.py | 188 +++++++++- snowflake/ml/_internal/utils/BUILD.bazel | 20 ++ snowflake/ml/_internal/utils/db_utils.py | 50 +++ snowflake/ml/_internal/utils/db_utils_test.py | 78 +++++ snowflake/ml/_internal/utils/identifier.py | 59 +++- .../ml/_internal/utils/identifier_test.py | 18 +- snowflake/ml/_internal/utils/snowflake_env.py | 36 +- .../ml/_internal/utils/snowflake_env_test.py | 12 + .../ml/_internal/utils/sql_identifier.py | 2 +- .../ml/_internal/utils/sql_identifier_test.py | 6 + snowflake/ml/_internal/utils/table_manager.py | 20 +- snowflake/ml/_internal/utils/uri.py | 4 +- snowflake/ml/data/BUILD.bazel | 18 +- snowflake/ml/data/data_connector.py | 40 ++- snowflake/ml/data/data_connector_test.py | 67 +++- snowflake/ml/data/torch_dataset.py | 33 -- snowflake/ml/data/torch_utils.py | 68 ++++ snowflake/ml/dataset/dataset.py | 4 +- snowflake/ml/feature_store/feature_store.py | 58 +++- snowflake/ml/feature_store/feature_view.py | 4 +- snowflake/ml/fileset/embedded_stage_fs.py | 2 +- snowflake/ml/fileset/fileset.py | 2 +- snowflake/ml/fileset/sfcfs.py | 12 +- snowflake/ml/fileset/sfcfs_test.py | 6 +- .../model/_client/model/model_version_impl.py | 29 +- .../_client/model/model_version_impl_test.py | 57 ++- snowflake/ml/model/_client/ops/BUILD.bazel | 14 + snowflake/ml/model/_client/ops/model_ops.py | 42 ++- .../ml/model/_client/ops/model_ops_test.py | 73 +++- snowflake/ml/model/_client/ops/service_ops.py | 205 ++++++++++- .../ml/model/_client/ops/service_ops_test.py | 124 +++++++ .../ml/model/_client/service/BUILD.bazel | 11 +- .../_client/service/model_deployment_spec.py | 9 +- .../service/model_deployment_spec_schema.py | 3 +- .../service/model_deployment_spec_test.py | 158 +++++++++ snowflake/ml/model/_client/sql/BUILD.bazel | 11 + snowflake/ml/model/_client/sql/service.py | 103 +++++- .../ml/model/_client/sql/service_test.py | 327 ++++++++++++++++++ .../image_builds/inference_server/BUILD.bazel | 2 - .../image_builds/server_image_builder.py | 2 +- .../_deploy_client/snowservice/deploy.py | 6 +- .../model/_model_composer/model_composer.py | 2 + .../_model_composer/model_composer_test.py | 14 +- .../model_manifest/model_manifest.py | 11 +- .../model_manifest/model_manifest_test.py | 3 +- .../_packager/model_handlers/BUILD.bazel | 11 + .../model/_packager/model_handlers/_utils.py | 60 +++- .../_packager/model_handlers/catboost.py | 32 +- .../model_handlers/huggingface_pipeline.py | 38 +- .../_packager/model_handlers/lightgbm.py | 72 +--- .../ml/model/_packager/model_handlers/llm.py | 6 +- .../model_handlers/model_objective_utils.py | 116 +++++++ .../model/_packager/model_handlers/sklearn.py | 60 ++-- .../_packager/model_handlers/snowmlmodel.py | 125 ++++++- .../_packager/model_handlers/torchscript.py | 4 +- .../model/_packager/model_handlers/xgboost.py | 96 ++--- .../_packager/model_handlers_test/BUILD.bazel | 12 + .../model_handlers_test/_utils_test.py | 35 +- .../model_handlers_test/catboost_test.py | 35 ++ .../huggingface_pipeline_test.py | 4 + .../model_objective_utils_test.py | 127 +++++++ .../model_handlers_test/pytorch_test.py | 2 +- .../model_handlers_test/sklearn_test.py | 47 ++- .../model_handlers_test/snowmlmodel_test.py | 47 ++- .../model_handlers_test/torchscript_test.py | 2 +- .../model_handlers_test/xgboost_test.py | 51 +++ .../model/_packager/model_meta/model_meta.py | 17 +- .../_packager/model_meta/model_meta_schema.py | 8 - .../_packager/model_meta/model_meta_test.py | 12 +- .../ml/model/_packager/model_packager.py | 2 + .../ml/model/_packager/model_packager_test.py | 4 +- .../ml/model/_signatures/pytorch_handler.py | 2 +- snowflake/ml/model/_signatures/utils.py | 9 + snowflake/ml/model/models/llm.py | 4 +- snowflake/ml/model/type_hints.py | 10 +- snowflake/ml/modeling/_internal/constants.py | 1 + .../local_implementations/pandas_handlers.py | 10 +- .../local_implementations/pandas_trainer.py | 15 +- .../_internal/model_specifications.py | 2 + .../ml/modeling/_internal/model_trainer.py | 1 + .../distributed_hpo_trainer.py | 4 +- .../snowpark_handlers.py | 10 +- .../snowpark_trainer.py | 273 ++++++--------- snowflake/ml/modeling/parameters/BUILD.bazel | 34 +- .../parameters/disable_model_tracer.py | 5 + .../parameters/disable_model_tracer_test.py | 24 ++ snowflake/ml/modeling/pipeline/BUILD.bazel | 3 + snowflake/ml/modeling/pipeline/pipeline.py | 13 +- .../ml/modeling/pipeline/pipeline_test.py | 112 +++++- .../ml/registry/_manager/model_manager.py | 4 + .../registry/_manager/model_manager_test.py | 7 + snowflake/ml/registry/model_registry.py | 2 +- snowflake/ml/registry/registry.py | 3 +- snowflake/ml/registry/registry_test.py | 1 + snowflake/ml/test_utils/mock_session.py | 16 + snowflake/ml/version.bzl | 2 +- .../snowflake/ml/extra_tests/BUILD.bazel | 16 +- .../batch_inference_with_nan_data_test.py | 13 +- .../extra_tests/column_name_inference_test.py | 2 +- .../ml/extra_tests/decimal_type_test.py | 2 +- .../multi_label_column_name_test.py | 10 + .../pipeline_with_ohe_and_xgbr_test.py | 253 ++++---------- .../ml/extra_tests/sample_weight_col_test.py | 104 ++++++ .../xgboost_external_memory_training_test.py | 27 +- .../ml/feature_store/common_utils.py | 10 +- .../ml/feature_store/feature_store_test.py | 2 +- ...e_huggingface_pipeline_model_integ_test.py | 5 +- .../snowflake/ml/modeling/metrics/BUILD.bazel | 4 + .../metrics/mean_absolute_error_test.py | 149 ++++---- .../mean_absolute_percentage_error_test.py | 149 ++++---- .../metrics/mean_squared_error_test.py | 207 +++++------ .../metrics/precision_recall_curve_test.py | 102 +++--- .../snowflake/ml/observability/BUILD.bazel | 16 + .../observability/model_monitor_integ_test.py | 263 ++++++++++++++ .../snowflake/ml/registry/model/BUILD.bazel | 2 +- ...egistry_huggingface_pipeline_model_test.py | 6 +- .../model/registry_modeling_model_test.py | 245 +++++++++++++ .../model/registry_xgboost_model_test.py | 25 ++ .../snowflake/ml/snowpark_pandas/BUILD.bazel | 1 - 142 files changed, 4356 insertions(+), 1493 deletions(-) delete mode 100644 snowflake/cortex/_embed_text_1024.py delete mode 100644 snowflake/cortex/_embed_text_768.py delete mode 100644 snowflake/cortex/embed_text_1024_test.py delete mode 100644 snowflake/cortex/embed_text_768_test.py create mode 100644 snowflake/ml/_internal/utils/db_utils.py create mode 100644 snowflake/ml/_internal/utils/db_utils_test.py delete mode 100644 snowflake/ml/data/torch_dataset.py create mode 100644 snowflake/ml/data/torch_utils.py create mode 100644 snowflake/ml/model/_client/ops/service_ops_test.py create mode 100644 snowflake/ml/model/_client/service/model_deployment_spec_test.py create mode 100644 snowflake/ml/model/_client/sql/service_test.py create mode 100644 snowflake/ml/model/_packager/model_handlers/model_objective_utils.py create mode 100644 snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py create mode 100644 snowflake/ml/modeling/parameters/disable_model_tracer.py create mode 100644 snowflake/ml/modeling/parameters/disable_model_tracer_test.py create mode 100644 tests/integ/snowflake/ml/extra_tests/sample_weight_col_test.py create mode 100644 tests/integ/snowflake/ml/observability/BUILD.bazel create mode 100644 tests/integ/snowflake/ml/observability/model_monitor_integ_test.py diff --git a/BUILD.bazel b/BUILD.bazel index 3fb7cb2f..6009c8fc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -3,6 +3,8 @@ load("//:packages.bzl", "PACKAGES") load("//bazel:py_rules.bzl", "py_wheel") load("//bazel/requirements:rules.bzl", "generate_pyproject_file") +package(default_visibility = ["//visibility:public"]) + exports_files([ "CHANGELOG.md", "README.md", diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bd71e29..6afd884c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,26 @@ # Release History -## 1.6.1 (TBD) +## 1.6.2 (TBD) + +### Bug Fixes + +- Modeling: Support XGBoost version that is larger than 2. + +- Data: Fix multiple epoch iteration over `DataConnector.to_torch_datapipe()` DataPipes. +- Generic: Fix a bug that when an invalid name is provided to argument where fully qualified name is expected, it will + be parsed wrongly. Now it raises an exception correctly. +- Model Explainability: Handle explanations for multiclass XGBoost classification models +- Model Explainability: Workarounds and better error handling for XGB>2.1.0 not working with SHAP==0.42.1 + +### New Features + +- Data: Add top-level exports for `DataConnector` and `DataSource` to `snowflake.ml.data`. +- Data: Add native batching support via `batch_size` and `drop_last_batch` arguments to `DataConnector.to_torch_dataset()` +- Feature Store: update_feature_view() supports taking feature view object as argument. + +### Behavior Changes + +## 1.6.1 (2024-08-12) ### Bug Fixes @@ -17,7 +37,6 @@ ### New Features - Enable `set_params` to set the parameters of the underlying sklearn estimator, if the snowflake-ml model has been fit. -- Data: Add top-level exports for `DataConnector` and `DataSource` to `snowflake.ml.data`. - Data: Add `snowflake.ml.data.ingestor_utils` module with utility functions helpful for `DataIngestor` implementations. - Data: Add new `to_torch_dataset()` connector to `DataConnector` to replace deprecated DataPipe. - Registry: Option to `enable_explainability` set to True by default for XGBoost, LightGBM and CatBoost as PuPr feature. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 990f4705..e706ba00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -304,7 +304,7 @@ Example: ## Unit Testing -Write `pytest` or Python `unittest` style unit tests. +Write Python `unittest` style unit tests. Pytest is allowed, but not recommended. ### `unittest` @@ -320,6 +320,10 @@ from absl.testing import absltest # instead of # from unittest import TestCase, main from absl.testing.absltest import TestCase, main + +# Call main. +if __name__ == '__main__': + absltest.main() ``` `absltest` provides better `bazel` integration which produces a more detailed XML diff --git a/bazel/environments/conda-env-snowflake.yml b/bazel/environments/conda-env-snowflake.yml index 4652b6eb..7109cd8d 100644 --- a/bazel/environments/conda-env-snowflake.yml +++ b/bazel/environments/conda-env-snowflake.yml @@ -28,6 +28,7 @@ dependencies: - lightgbm==3.3.5 - mlflow==2.3.1 - moto==4.0.11 + - mypy==1.10.0 - networkx==2.8.4 - numpy==1.23.5 - packaging==23.0 @@ -54,14 +55,16 @@ dependencies: - snowflake-snowpark-python==1.17.0 - sphinx==5.0.2 - sqlparse==0.4.4 + - starlette==0.27.0 - tensorflow==2.12.0 - tokenizers==0.13.2 - toml==0.10.2 - torchdata==0.6.1 - transformers==4.32.1 + - types-PyYAML==6.0.12.12 - types-protobuf==4.23.0.1 - types-requests==2.30.0.0 - types-toml==0.10.8.6 - - typing-extensions==4.5.0 + - typing-extensions==4.6.3 - werkzeug==2.2.2 - xgboost==1.7.3 diff --git a/bazel/environments/conda-env.yml b/bazel/environments/conda-env.yml index dfc99a9b..35f631d3 100644 --- a/bazel/environments/conda-env.yml +++ b/bazel/environments/conda-env.yml @@ -14,11 +14,6 @@ dependencies: - cachetools==4.2.2 - catboost==1.2.0 - cloudpickle==2.2.1 - - conda-forge::accelerate==0.22.0 - - conda-forge::mypy==1.5.1 - - conda-forge::starlette==0.27.0 - - conda-forge::types-PyYAML==6.0.12 - - conda-forge::types-cachetools==4.2.2 - conda-libmamba-solver==23.7.0 - coverage==6.3.2 - cryptography==39.0.1 @@ -33,6 +28,7 @@ dependencies: - lightgbm==3.3.5 - mlflow==2.3.1 - moto==4.0.11 + - mypy==1.10.0 - networkx==2.8.4 - numpy==1.23.5 - packaging==23.0 @@ -59,18 +55,22 @@ dependencies: - snowflake-snowpark-python==1.17.0 - sphinx==5.0.2 - sqlparse==0.4.4 + - starlette==0.27.0 - tensorflow==2.12.0 - tokenizers==0.13.2 - toml==0.10.2 - torchdata==0.6.1 - transformers==4.32.1 + - types-PyYAML==6.0.12.12 - types-protobuf==4.23.0.1 - types-requests==2.30.0.0 - types-toml==0.10.8.6 - - typing-extensions==4.5.0 + - typing-extensions==4.6.3 - werkzeug==2.2.2 - xgboost==1.7.3 - pip - pip: - --extra-index-url https://pypi.org/simple + - accelerate==0.22.0 + - types-cachetools==4.2.2 - peft==0.5.0 diff --git a/bazel/environments/conda-gpu-env.yml b/bazel/environments/conda-gpu-env.yml index 3233386c..20d82ad3 100755 --- a/bazel/environments/conda-gpu-env.yml +++ b/bazel/environments/conda-gpu-env.yml @@ -14,11 +14,6 @@ dependencies: - cachetools==4.2.2 - catboost==1.2.0 - cloudpickle==2.2.1 - - conda-forge::accelerate==0.22.0 - - conda-forge::mypy==1.5.1 - - conda-forge::starlette==0.27.0 - - conda-forge::types-PyYAML==6.0.12 - - conda-forge::types-cachetools==4.2.2 - conda-libmamba-solver==23.7.0 - coverage==6.3.2 - cryptography==39.0.1 @@ -33,6 +28,7 @@ dependencies: - lightgbm==3.3.5 - mlflow==2.3.1 - moto==4.0.11 + - mypy==1.10.0 - networkx==2.8.4 - numpy==1.23.5 - nvidia::cuda==11.7.* @@ -61,19 +57,23 @@ dependencies: - snowflake-snowpark-python==1.17.0 - sphinx==5.0.2 - sqlparse==0.4.4 + - starlette==0.27.0 - tensorflow==2.12.0 - tokenizers==0.13.2 - toml==0.10.2 - torchdata==0.6.1 - transformers==4.32.1 + - types-PyYAML==6.0.12.12 - types-protobuf==4.23.0.1 - types-requests==2.30.0.0 - types-toml==0.10.8.6 - - typing-extensions==4.5.0 + - typing-extensions==4.6.3 - werkzeug==2.2.2 - xgboost==1.7.3 - pip - pip: - --extra-index-url https://pypi.org/simple + - accelerate==0.22.0 + - types-cachetools==4.2.2 - peft==0.5.0 - vllm==0.2.1.post1 diff --git a/bazel/requirements/requirements.schema.json b/bazel/requirements/requirements.schema.json index 32e9c29e..dbc20e69 100644 --- a/bazel/requirements/requirements.schema.json +++ b/bazel/requirements/requirements.schema.json @@ -59,11 +59,6 @@ "pattern": "^$|^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\\.(0|[1-9][0-9]*))*((a|b|rc|alpha|beta)(0|[1-9][0-9]*))?(\\.post(0|[1-9][0-9]*))?(\\.dev(0|[1-9][0-9]*))?$", "type": "string" }, - "from_channel": { - "default": "https://repo.anaconda.com/pkgs/snowflake", - "description": "The channel where the package come from, set if not from Snowflake Anaconda Channel.", - "type": "string" - }, "gpu_only": { "default": false, "description": "The package is required when running in an environment where GPU is available.", diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 7d2f84c3..ff824bda 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.6.1 + version: 1.6.2 requirements: build: - python @@ -45,7 +45,7 @@ requirements: - snowflake-snowpark-python>=1.17.0,<2 - sqlparse>=0.4,<1 - typing-extensions>=4.1.0,<5 - - xgboost>=1.7.3,<2 + - xgboost>=1.7.3,<2.1 - python>=3.8,<3.12 run_constrained: - catboost>=1.2.0, <2 diff --git a/codegen/codegen_rules.bzl b/codegen/codegen_rules.bzl index a3184740..a46a0642 100644 --- a/codegen/codegen_rules.bzl +++ b/codegen/codegen_rules.bzl @@ -90,7 +90,6 @@ def autogen_estimators(module, estimator_info_list): "//snowflake/ml/_internal/exceptions:exceptions", "//snowflake/ml/_internal/utils:temp_file_utils", "//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/_internal/utils:pkg_version_utils", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_signatures:utils", @@ -181,7 +180,6 @@ def autogen_snowpark_pandas_tests(module, module_root_dir, snowpark_pandas_estim "//snowflake/ml/_internal/snowpark_pandas:snowpark_pandas_lib", "//snowflake/ml/utils:connection_params", ], - compatible_with_snowpark = False, timeout = "long", legacy_create_init = 0, shard_count = 5, diff --git a/codegen/sklearn_wrapper_generator.py b/codegen/sklearn_wrapper_generator.py index ad94594e..6528a2de 100644 --- a/codegen/sklearn_wrapper_generator.py +++ b/codegen/sklearn_wrapper_generator.py @@ -1153,15 +1153,18 @@ def generate(self) -> "XGBoostWrapperGenerator": super().generate() # Populate XGBoost specific values - self.estimator_imports_list.append("import xgboost") + self.estimator_imports_list.extend(["import sklearn", "import xgboost"]) self.test_estimator_input_args_list.extend( ["random_state=0", "subsample=1.0", "colsample_bynode=1.0", "n_jobs=1"] ) - self.score_sproc_imports = ["xgboost"] + self.score_sproc_imports = ["xgboost", "sklearn"] # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda. self.supported_export_method = "to_xgboost" self.unsupported_export_methods = ["to_sklearn", "to_lightgbm"] - self.deps = "f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'" + self.deps = ( + "f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', " + + "f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'" + ) self._construct_string_from_lists() return self diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index acbd80eb..67725e7e 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -1,13 +1,11 @@ import inspect import os -import posixpath -from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set -from typing_extensions import TypeGuard +from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple from uuid import uuid4 import cloudpickle as cp -import pandas as pd import numpy as np +import pandas as pd from numpy import typing as npt @@ -18,12 +16,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols from snowflake.ml._internal import telemetry from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV -from snowflake.ml._internal.utils import pkg_version_utils, identifier +from snowflake.ml._internal.utils import identifier from snowflake.snowpark import DataFrame, Session from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder from snowflake.ml.modeling._internal.transformer_protocols import ( - ModelTransformHandlers, BatchInferenceKwargsTypedDict, ScoreKwargsTypedDict ) @@ -361,12 +358,23 @@ class {transform.original_class_name}(BaseTransformer): autogenerated=self._autogenerated, subproject=_SUBPROJECT, ) - output_result, fitted_estimator = model_trainer.train_fit_predict( - drop_input_cols=self._drop_input_cols, - expected_output_cols_list=( - self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix) - ), + expected_output_cols = ( + self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix) ) + if isinstance(dataset, DataFrame): + expected_output_cols, example_output_pd_df = self._align_expected_output( + "fit_predict", dataset, expected_output_cols, output_cols_prefix + ) + output_result, fitted_estimator = model_trainer.train_fit_predict( + drop_input_cols=self._drop_input_cols, + expected_output_cols_list=expected_output_cols, + example_output_pd_df=example_output_pd_df, + ) + else: + output_result, fitted_estimator = model_trainer.train_fit_predict( + drop_input_cols=self._drop_input_cols, + expected_output_cols_list=expected_output_cols, + ) self._sklearn_object = fitted_estimator self._is_fitted = True return output_result @@ -437,12 +445,41 @@ class {transform.original_class_name}(BaseTransformer): return rv - def _align_expected_output_names( - self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str - ) -> List[str]: + def _align_expected_output( + self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str, + ) -> Tuple[List[str], pd.DataFrame]: + """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names + and output dataframe with 1 line. + If the method is fit_predict, run 2 lines of data. + """ # in case the inferred output column names dimension is different # we use one line of snowpark dataframe and put it into sklearn estimator using pandas - sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas() + + # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture + # so change the minimum of number of rows to 2 + num_examples = 2 + statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=_SUBPROJECT, + function_name=telemetry.get_statement_params_full_func_name( + inspect.currentframe(), {transform.original_class_name}.__class__.__name__ + ), + api_calls=[Session.call], + custom_tags={{"autogen": True}} if self._autogenerated else None, + ) + if output_cols_prefix == "fit_predict_": + if hasattr(self._sklearn_object, "n_clusters"): + # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters + num_examples = self._sklearn_object.n_clusters + elif hasattr(self._sklearn_object, "min_samples"): + # OPTICS default min_samples 5, which requires at least 5 lines of data + num_examples = self._sklearn_object.min_samples + elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"): + # LocalOutlierFactor expects n_neighbors <= n_samples + num_examples = self._sklearn_object.n_neighbors + sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params) + else: + sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params) # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order # seen during the fit. @@ -454,12 +491,14 @@ class {transform.original_class_name}(BaseTransformer): output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns) if self.sample_weight_col: output_df_columns_set -= set(self.sample_weight_col) + # if the dimension of inferred output column names is correct; use it if len(expected_output_cols_list) == len(output_df_columns_set): - return expected_output_cols_list + return expected_output_cols_list, output_df_pd # otherwise, use the sklearn estimator's output else: - return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x)) + expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x)) + return expected_output_cols_list, output_df_pd[expected_output_cols_list] @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -497,7 +536,7 @@ class {transform.original_class_name}(BaseTransformer): drop_input_cols=self._drop_input_cols, expected_output_cols_type="float", ) - expected_output_cols = self._align_expected_output_names( + expected_output_cols, _ = self._align_expected_output( inference_method, dataset, expected_output_cols, output_cols_prefix ) @@ -555,7 +594,7 @@ class {transform.original_class_name}(BaseTransformer): drop_input_cols=self._drop_input_cols, expected_output_cols_type="float", ) - expected_output_cols = self._align_expected_output_names( + expected_output_cols, _ = self._align_expected_output( inference_method, dataset, expected_output_cols, output_cols_prefix ) elif isinstance(dataset, pd.DataFrame): @@ -610,7 +649,7 @@ class {transform.original_class_name}(BaseTransformer): drop_input_cols=self._drop_input_cols, expected_output_cols_type="float", ) - expected_output_cols = self._align_expected_output_names( + expected_output_cols, _ = self._align_expected_output( inference_method, dataset, expected_output_cols, output_cols_prefix ) @@ -667,7 +706,7 @@ class {transform.original_class_name}(BaseTransformer): drop_input_cols = self._drop_input_cols, expected_output_cols_type="float", ) - expected_output_cols = self._align_expected_output_names( + expected_output_cols, _ = self._align_expected_output( inference_method, dataset, expected_output_cols, output_cols_prefix ) diff --git a/codegen/transformer_autogen_test_template.py_template b/codegen/transformer_autogen_test_template.py_template index a35cdbbb..39620dac 100644 --- a/codegen/transformer_autogen_test_template.py_template +++ b/codegen/transformer_autogen_test_template.py_template @@ -256,14 +256,14 @@ class {transform.test_class_name}(TestCase): ) if callable(getattr(sklearn_reg, "score", None)) and callable(getattr(reg, "score", None)): - score_argspec = inspect.getfullargspec(sklearn_reg.score) + score_params = inspect.signature(sklearn_reg.score).parameters # Some classes that has sample_weight argument in fit() but not in score(). - if use_weighted_dataset is True and 'sample_weight' not in score_argspec.args: + if use_weighted_dataset is True and 'sample_weight' not in score_params: del args['sample_weight'] input_df_pandas = input_df_pandas.drop(['sample_weight', 'SAMPLE_WEIGHT'], axis=1, errors='ignore') # Some classes have different arg name in score: X -> X_test - if "X_test" in score_argspec.args: + if "X_test" in score_params: args['X_test'] = args.pop('X') if inference_with_udf: @@ -300,8 +300,8 @@ class {transform.test_class_name}(TestCase): is_weighted_dataset_supported = False for m in inspect.getmembers(klass): if inspect.isfunction(m[1]) and m[0] == "fit": - argspec = inspect.getfullargspec(m[1]) - is_weighted_dataset_supported = True if "sample_weight" in argspec.args else False + params = inspect.signature(m[1]).parameters + is_weighted_dataset_supported = True if "sample_weight" in params else False return is_weighted_dataset_supported def test_fit_with_sproc_infer_with_udf_weighted_datasets(self) -> None: diff --git a/requirements.txt b/requirements.txt index a5b30eac..085e93c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ jsonschema==3.2.0 lightgbm==3.3.5 mlflow==2.3.1 moto==4.0.11 -mypy==1.5.1 +mypy==1.10.0 networkx==2.8.4 numpy==1.23.5 packaging==23.0 @@ -50,17 +50,17 @@ snowflake-snowpark-python==1.17.0 sphinx==5.0.2 sqlparse==0.4.4 starlette==0.27.0 -tensorflow==2.13.0 +tensorflow==2.12.0 tokenizers==0.13.2 toml==0.10.2 torch==2.0.1 torchdata==0.6.1 transformers==4.32.1 -types-PyYAML==6.0.12 +types-PyYAML==6.0.12.12 types-cachetools==4.2.2 types-protobuf==4.23.0.1 types-requests==2.30.0.0 types-toml==0.10.8.6 -typing-extensions==4.5.0 +typing-extensions==4.6.3 werkzeug==2.2.2 xgboost==1.7.3 diff --git a/requirements.yml b/requirements.yml index 2bca0a02..f2b2b820 100644 --- a/requirements.yml +++ b/requirements.yml @@ -77,9 +77,6 @@ - build_essential - deployment_core - snowml_inference_alternative -- name: accelerate - dev_version: 0.22.0 - from_channel: conda-forge # For fsspec[http] in conda - name_conda: aiohttp dev_version_conda: 3.8.3 @@ -161,8 +158,7 @@ - name: moto dev_version: 4.0.11 - name: mypy - dev_version: 1.5.1 - from_channel: conda-forge + dev_version: 1.10.0 - name: networkx dev_version: 2.8.4 - name: numpy @@ -266,13 +262,12 @@ - build_essential - name: starlette dev_version: 0.27.0 - from_channel: conda-forge - name: sqlparse dev_version: 0.4.4 version_requirements: '>=0.4,<1' - name: tensorflow dev_version_conda: 2.12.0 - dev_version_pypi: 2.13.0 + dev_version_pypi: 2.12.0 version_requirements: '>=2.10,<3' requirements_extra_tags: - tensorflow @@ -300,26 +295,22 @@ - name: types-protobuf dev_version: 4.23.0.1 - name: types-PyYAML - dev_version: 6.0.12 - from_channel: conda-forge + dev_version: 6.0.12.12 - name: types-toml dev_version: 0.10.8.6 tags: - build_essential - name: typing-extensions - dev_version: 4.5.0 + dev_version: 4.6.3 version_requirements: '>=4.1.0,<5' tags: - deployment_core - snowml_inference_alternative - name: xgboost dev_version: 1.7.3 - version_requirements: '>=1.7.3,<2' + version_requirements: '>=1.7.3,<2.1' tags: - build_essential -- name: types-cachetools - dev_version: 4.2.2 - from_channel: conda-forge - name: werkzeug dev_version: 2.2.2 - name: cachetools @@ -332,6 +323,10 @@ # Below are pip only external packages - name_pypi: --extra-index-url https://pypi.org/simple dev_version_pypi: '' +- name_pypi: accelerate + dev_version_pypi: 0.22.0 +- name_pypi: types-cachetools + dev_version_pypi: 4.2.2 - name_pypi: peft dev_version_pypi: 0.5.0 version_requirements_pypi: '>=0.5.0,<1' diff --git a/snowflake/cortex/BUILD.bazel b/snowflake/cortex/BUILD.bazel index 1c65b740..9a0c6c9d 100644 --- a/snowflake/cortex/BUILD.bazel +++ b/snowflake/cortex/BUILD.bazel @@ -153,44 +153,6 @@ py_test( ], ) -py_library( - name = "embed_text_768", - srcs = ["_embed_text_768.py"], - deps = [ - ":util", - "//snowflake/ml/_internal:telemetry", - ], -) - -py_test( - name = "embed_text_768_test", - srcs = ["embed_text_768_test.py"], - deps = [ - ":embed_text_768", - ":test_util", - "//snowflake/ml/utils:connection_params", - ], -) - -py_library( - name = "embed_text_1024", - srcs = ["_embed_text_1024.py"], - deps = [ - ":util", - "//snowflake/ml/_internal:telemetry", - ], -) - -py_test( - name = "embed_text_1024_test", - srcs = ["embed_text_1024_test.py"], - deps = [ - ":embed_text_1024", - ":test_util", - "//snowflake/ml/utils:connection_params", - ], -) - py_library( name = "init", srcs = [ @@ -199,8 +161,6 @@ py_library( deps = [ ":classify_text", ":complete", - ":embed_text_768", - ":embed_text_1024", ":extract_answer", ":sentiment", ":summarize", diff --git a/snowflake/cortex/__init__.py b/snowflake/cortex/__init__.py index 947b2d77..1ee01368 100644 --- a/snowflake/cortex/__init__.py +++ b/snowflake/cortex/__init__.py @@ -3,16 +3,12 @@ from snowflake.cortex._extract_answer import ExtractAnswer from snowflake.cortex._sentiment import Sentiment from snowflake.cortex._summarize import Summarize -from snowflake.cortex._embed_text_768 import EmbedText768 -from snowflake.cortex._embed_text_1024 import EmbedText1024 from snowflake.cortex._translate import Translate __all__ = [ "ClassifyText", "Complete", "CompleteOptions", - "EmbedText768", - "EmbedText1024", "ExtractAnswer", "Sentiment", "Summarize", diff --git a/snowflake/cortex/_embed_text_1024.py b/snowflake/cortex/_embed_text_1024.py deleted file mode 100644 index 9462a801..00000000 --- a/snowflake/cortex/_embed_text_1024.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Optional, Union - -from snowflake import snowpark -from snowflake.cortex._util import ( - CORTEX_FUNCTIONS_TELEMETRY_PROJECT, - call_sql_function, -) -from snowflake.ml._internal import telemetry - - -@telemetry.send_api_usage_telemetry( - project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, -) -def EmbedText1024( - model: Union[str, snowpark.Column], - text: Union[str, snowpark.Column], - session: Optional[snowpark.Session] = None, -) -> Union[list[float], snowpark.Column]: - """TextEmbed calls into the LLM inference service to embed the text. - - Args: - model: A Column of strings representing the model to use for embedding. The value - of the strings must be within the SUPPORTED_MODELS list. - text: A Column of strings representing input text. - session: The snowpark session to use. Will be inferred by context if not specified. - - Returns: - A column of vectors containing embeddings. - """ - - return _embed_text_1024_impl( - "snowflake.cortex.embed_text_1024", model, text, session=session - ) - - -def _embed_text_1024_impl( - function: str, - model: Union[str, snowpark.Column], - text: Union[str, snowpark.Column], - session: Optional[snowpark.Session] = None, -) -> Union[list[float], snowpark.Column]: - return call_sql_function(function, session, model, text) diff --git a/snowflake/cortex/_embed_text_768.py b/snowflake/cortex/_embed_text_768.py deleted file mode 100644 index 78838ff0..00000000 --- a/snowflake/cortex/_embed_text_768.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Optional, Union, List - -from snowflake import snowpark -from snowflake.cortex._util import ( - CORTEX_FUNCTIONS_TELEMETRY_PROJECT, - SnowflakeConfigurationException, - call_sql_function, -) -from snowflake.ml._internal import telemetry - - -@telemetry.send_api_usage_telemetry( - project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, -) -def EmbedText768( - model: Union[str, snowpark.Column], - text: Union[str, snowpark.Column], - session: Optional[snowpark.Session] = None, -) -> Union[list[float], snowpark.Column]: - """TextEmbed calls into the LLM inference service to embed the text. - - Args: - model: A Column of strings representing the model to use for embedding. The value - of the strings must be within the SUPPORTED_MODELS list. - text: A Column of strings representing input text. - session: The snowpark session to use. Will be inferred by context if not specified. - - Returns: - A column of vectors containing embeddings. - """ - - return _embed_text_768_impl( - "snowflake.cortex.embed_text_768", model, text, session=session - ) - - -def _embed_text_768_impl( - function: str, - model: Union[str, snowpark.Column], - text: Union[str, snowpark.Column], - session: Optional[snowpark.Session] = None, -) -> Union[list[float], snowpark.Column]: - return call_sql_function(function, session, model, text) diff --git a/snowflake/cortex/embed_text_1024_test.py b/snowflake/cortex/embed_text_1024_test.py deleted file mode 100644 index d2724d09..00000000 --- a/snowflake/cortex/embed_text_1024_test.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import List - -import _test_util -from absl.testing import absltest - -from snowflake import snowpark -from snowflake.cortex import _embed_text_1024 -from snowflake.snowpark import functions, types - - -class EmbedTest1024Test(absltest.TestCase): - model = "snowflake-arctic-embed-m" - text = "|text|" - - @staticmethod - def embed_text_1024_for_test(model: str, text: str) -> List[float]: - return [0.0] * 1024 - - def setUp(self) -> None: - self._session = _test_util.create_test_session() - functions.udf( - self.embed_text_1024_for_test, - name="embed_text_1024", - session=self._session, - return_type=types.VectorType(float, 1024), - input_types=[types.StringType(), types.StringType()], - is_permanent=False, - ) - - def tearDown(self) -> None: - self._session.sql("drop function embed_text_1024(string,string)").collect() - self._session.close() - - def test_embed_text_1024_str(self) -> None: - res = _embed_text_1024._embed_text_1024_impl( - "embed_text_1024", - self.model, - self.text, - session=self._session, - ) - out = self.embed_text_1024_for_test(self.model, self.text) - self.assertEqual( - out, res - ), f"Expected ({type(out)}) {out}, got ({type(res)}) {res}" - - def test_embed_text_1024_column(self) -> None: - df_in = self._session.create_dataframe( - [snowpark.Row(model=self.model, text=self.text)] - ) - df_out = df_in.select( - _embed_text_1024._embed_text_1024_impl( - "embed_text_1024", - functions.col("model"), - functions.col("text"), - session=self._session, - ) - ) - res = df_out.collect()[0][0] - out = self.embed_text_1024_for_test(self.model, self.text) - - self.assertEqual(out, res) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/cortex/embed_text_768_test.py b/snowflake/cortex/embed_text_768_test.py deleted file mode 100644 index c07249ab..00000000 --- a/snowflake/cortex/embed_text_768_test.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import List - -import _test_util -from absl.testing import absltest - -from snowflake import snowpark -from snowflake.cortex import _embed_text_768 -from snowflake.snowpark import functions, types - - -class EmbedTest768Test(absltest.TestCase): - model = "snowflake-arctic-embed-m" - text = "|text|" - - @staticmethod - def embed_text_768_for_test(model: str, text: str) -> List[float]: - return [0.0] * 768 - - def setUp(self) -> None: - self._session = _test_util.create_test_session() - functions.udf( - self.embed_text_768_for_test, - name="embed_text_768", - session=self._session, - return_type=types.VectorType(float, 768), - input_types=[types.StringType(), types.StringType()], - is_permanent=False, - ) - - def tearDown(self) -> None: - self._session.sql("drop function embed_text_768(string,string)").collect() - self._session.close() - - def test_embed_text_768_str(self) -> None: - res = _embed_text_768._embed_text_768_impl( - "embed_text_768", - self.model, - self.text, - session=self._session, - ) - out = self.embed_text_768_for_test(self.model, self.text) - self.assertEqual( - out, res - ), f"Expected ({type(out)}) {out}, got ({type(res)}) {res}" - - def test_embed_text_768_column(self) -> None: - df_in = self._session.create_dataframe( - [snowpark.Row(model=self.model, text=self.text)] - ) - df_out = df_in.select( - _embed_text_768._embed_text_768_impl( - "embed_text_768", - functions.col("model"), - functions.col("text"), - session=self._session, - ) - ) - res = df_out.collect()[0][0] - out = self.embed_text_768_for_test(self.model, self.text) - - self.assertEqual(out, res) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/cortex/package_visibility_test.py b/snowflake/cortex/package_visibility_test.py index 98655da8..1addaa09 100644 --- a/snowflake/cortex/package_visibility_test.py +++ b/snowflake/cortex/package_visibility_test.py @@ -16,12 +16,6 @@ def test_complete_visible(self) -> None: def test_extract_answer_visible(self) -> None: self.assertTrue(callable(cortex.ExtractAnswer)) - def test_embed_text_768_visible(self) -> None: - self.assertTrue(callable(cortex.EmbedText768)) - - def test_embed_text_1024_visible(self) -> None: - self.assertTrue(callable(cortex.EmbedText1024)) - def test_sentiment_visible(self) -> None: self.assertTrue(callable(cortex.Sentiment)) diff --git a/snowflake/ml/_internal/snowpark_pandas/patch.py b/snowflake/ml/_internal/snowpark_pandas/patch.py index 8704f27b..7f824898 100644 --- a/snowflake/ml/_internal/snowpark_pandas/patch.py +++ b/snowflake/ml/_internal/snowpark_pandas/patch.py @@ -78,7 +78,7 @@ def patch(*args: Any, **kwargs: Any) -> Any: else: - def patch(self: Any, *args: Any, **kwargs: Any) -> Any: + def patch(self: Any, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc] stage = session.get_session_stage() has_snowpark_pandas = _has_snowpark_pandas(*args, **kwargs) if has_snowpark_pandas: diff --git a/snowflake/ml/_internal/telemetry.py b/snowflake/ml/_internal/telemetry.py index 93e951a8..2e5a0f72 100644 --- a/snowflake/ml/_internal/telemetry.py +++ b/snowflake/ml/_internal/telemetry.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import contextvars import enum import functools import inspect @@ -12,6 +13,7 @@ List, Mapping, Optional, + Set, Tuple, TypeVar, Union, @@ -28,7 +30,7 @@ exceptions as snowml_exceptions, ) from snowflake.snowpark import dataframe, exceptions as snowpark_exceptions, session -from snowflake.snowpark._internal import utils +from snowflake.snowpark._internal import server_connection, utils _log_counter = 0 _FLUSH_SIZE = 10 @@ -85,6 +87,122 @@ class TelemetryField(enum.Enum): FUNC_CAT_USAGE = "usage" +class _TelemetrySourceType(enum.Enum): + # Automatically inferred telemetry/statement parameters + AUTO_TELEMETRY = "SNOWML_AUTO_TELEMETRY" + # Mixture of manual and automatic telemetry/statement parameters + AUGMENT_TELEMETRY = "SNOWML_AUGMENT_TELEMETRY" + + +_statement_params_context_var: contextvars.ContextVar[Dict[str, str]] = contextvars.ContextVar("statement_params") + + +class _StatementParamsPatchManager: + def __init__(self) -> None: + self._patch_cache: Set[server_connection.ServerConnection] = set() + self._context_var: contextvars.ContextVar[Dict[str, str]] = _statement_params_context_var + + def apply_patches(self) -> None: + try: + # Apply patching to all active sessions in case of multiple + for sess in session._get_active_sessions(): + # Check patch cache here to avoid unnecessary context switches + if self._get_target(sess) not in self._patch_cache: + self._patch_session(sess) + except snowpark_exceptions.SnowparkSessionException: + pass + + def set_statement_params(self, statement_params: Dict[str, str]) -> None: + # Only set value if not already set in context + if not self._context_var.get({}): + self._context_var.set(statement_params) + + def _get_target(self, session: session.Session) -> server_connection.ServerConnection: + return cast(server_connection.ServerConnection, session._conn) + + def _patch_session(self, session: session.Session, throw_on_patch_fail: bool = False) -> None: + # Extract target + try: + target = self._get_target(session) + except AttributeError: + if throw_on_patch_fail: + raise + # TODO: Log a warning, this probably means there was a breaking change in Snowpark/SnowflakeConnection + return + + # Check if session has already been patched + if target in self._patch_cache: + return + self._patch_cache.add(target) + + functions = [ + ("execute_and_notify_query_listener", "_statement_params"), + ("execute_async_and_notify_query_listener", "_statement_params"), + ] + + for func, param_name in functions: + try: + self._patch_with_statement_params(target, func, param_name=param_name) + except AttributeError: + if throw_on_patch_fail: # primarily used for testing + raise + # TODO: Log a warning, this probably means there was a breaking change in Snowpark/SnowflakeConnection + pass + + def _patch_with_statement_params( + self, target: object, function_name: str, param_name: str = "statement_params" + ) -> None: + func = getattr(target, function_name) + assert callable(func) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # Retrieve context level statement parameters + context_params = self._context_var.get(dict()) + if not context_params: + # Exit early if not in SnowML (decorator) context + return func(*args, **kwargs) + + # Extract any explicitly provided statement parameters + orig_kwargs = dict(kwargs) + in_params = kwargs.pop(param_name, None) or {} + + # Inject a special flag to statement parameters so we can filter out these patched logs if necessary + # Calls that include SnowML telemetry are tagged with "SNOWML_AUGMENT_TELEMETRY" + # and calls without SnowML telemetry are tagged with "SNOWML_AUTO_TELEMETRY" + if TelemetryField.KEY_PROJECT.value in in_params: + context_params["snowml_telemetry_type"] = _TelemetrySourceType.AUGMENT_TELEMETRY.value + else: + context_params["snowml_telemetry_type"] = _TelemetrySourceType.AUTO_TELEMETRY.value + + # Apply any explicitly provided statement parameters and result into function call + context_params.update(in_params) + kwargs[param_name] = context_params + + try: + return func(*args, **kwargs) + except TypeError as e: + if str(e).endswith(f"unexpected keyword argument '{param_name}'"): + # TODO: Log warning that this patch is invalid + # Unwrap function for future invocations + setattr(target, function_name, func) + return func(*args, **orig_kwargs) + else: + raise + + setattr(target, function_name, wrapper) + + def __getstate__(self) -> Dict[str, Any]: + return {} + + def __setstate__(self, state: Dict[str, Any]) -> None: + # unpickling does not call __init__ by default, do it manually here + self.__init__() # type: ignore[misc] + + +_patch_manager = _StatementParamsPatchManager() + + def get_statement_params( project: str, subproject: Optional[str] = None, class_name: Optional[str] = None ) -> Dict[str, Any]: @@ -375,7 +493,18 @@ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[ obj._statement_params = statement_params # type: ignore[assignment] return obj + # Set up framework-level credit usage instrumentation + ctx = contextvars.copy_context() + _patch_manager.apply_patches() + + # This function should be executed with ctx.run() + def execute_func_with_statement_params() -> _ReturnValue: + _patch_manager.set_statement_params(statement_params) + result = func(*args, **kwargs) + return update_stmt_params_if_snowpark_df(result, statement_params) + # prioritize `conn_attr_name` over the active session + telemetry_enabled = True if conn_attr_name: # raise AttributeError if conn attribute does not exist in `self` conn = operator.attrgetter(conn_attr_name)(args[0]) @@ -387,22 +516,17 @@ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[ else: try: active_session = next(iter(session._get_active_sessions())) - # server no default session + conn = active_session._conn._conn + telemetry_enabled = active_session.telemetry_enabled except snowpark_exceptions.SnowparkSessionException: - try: - return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params) - except Exception as e: - if isinstance(e, snowml_exceptions.SnowflakeMLException): - raise e.original_exception.with_traceback(e.__traceback__) from None - # suppress SnowparkSessionException from telemetry in the stack trace - raise e from None - - conn = active_session._conn._conn - if (not active_session.telemetry_enabled) or (conn is None): - try: - return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params) - except snowml_exceptions.SnowflakeMLException as e: - raise e.original_exception from e + conn = None + + if conn is None or not telemetry_enabled: + # Telemetry not enabled, just execute without our additional telemetry logic + try: + return ctx.run(execute_func_with_statement_params) + except snowml_exceptions.SnowflakeMLException as e: + raise e.original_exception from e # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton. telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name) @@ -415,11 +539,11 @@ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[ custom_tags=custom_tags, ) try: - res = func(*args, **kwargs) + return ctx.run(execute_func_with_statement_params) except Exception as e: if not isinstance(e, snowml_exceptions.SnowflakeMLException): # already handled via a nested decorated function - if hasattr(e, "_snowflake_ml_handled") and e._snowflake_ml_handled: + if getattr(e, "_snowflake_ml_handled", False): raise e if isinstance(e, snowpark_exceptions.SnowparkClientException): me = snowml_exceptions.SnowflakeMLException( @@ -438,8 +562,6 @@ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[ raise me.original_exception from None else: raise me.original_exception from e - else: - return update_stmt_params_if_snowpark_df(res, statement_params) finally: telemetry.send_function_usage_telemetry(**telemetry_args) global _log_counter diff --git a/snowflake/ml/_internal/telemetry_test.py b/snowflake/ml/_internal/telemetry_test.py index 922300d7..15d6985a 100644 --- a/snowflake/ml/_internal/telemetry_test.py +++ b/snowflake/ml/_internal/telemetry_test.py @@ -1,13 +1,16 @@ import inspect +import pickle +import threading import time import traceback -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional from unittest import mock +import cloudpickle from absl.testing import absltest, parameterized from snowflake import connector -from snowflake.connector import telemetry as connector_telemetry +from snowflake.connector import cursor, telemetry as connector_telemetry from snowflake.ml._internal import env, telemetry as utils_telemetry from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.snowpark import dataframe, session @@ -32,6 +35,8 @@ def setUp(self) -> None: self.mock_session._conn = self.mock_server_conn self.mock_server_conn._conn = self.mock_snowflake_conn self.mock_snowflake_conn._telemetry = self.mock_telemetry + self.mock_snowflake_conn._session_parameters = {} + self.mock_snowflake_conn.is_closed.return_value = False self.telemetry_type = f"{_SOURCE.lower()}_{utils_telemetry.TelemetryField.TYPE_FUNCTION_USAGE.value}" @mock.patch("snowflake.snowpark.session._get_active_sessions") @@ -548,6 +553,185 @@ def test_add_statement_params_custom_tags(self) -> None: self.assertIn(utils_telemetry.TelemetryField.KEY_CUSTOM_TAGS.value, result) self.assertEqual(result.get(utils_telemetry.TelemetryField.KEY_CUSTOM_TAGS.value), custom_tags) + def test_apply_statement_params_patch(self) -> None: + patch_manager = utils_telemetry._StatementParamsPatchManager() + + mock_cursor = absltest.mock.MagicMock(spec=cursor.SnowflakeCursor) + with mock.patch.object(self.mock_snowflake_conn, "cursor", return_value=mock_cursor): + server_conn = server_connection.ServerConnection({}, self.mock_snowflake_conn) + sess = session.Session(server_conn) + try: + patch_manager._patch_session(sess, throw_on_patch_fail=True) + except Exception as e: + self.fail(f"Patching failed with unexpected exception: {e}") + + def test_pickle_instrumented_function(self) -> None: + @utils_telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject="PICKLE", + ) + def _picklable_test_function(session: session.Session) -> None: + """Used for test_pickle_instrumented_function""" + session.sql("SELECT 1").collect() + + with self.assertRaises(pickle.PicklingError): + _ = cloudpickle.dumps(self.mock_session) + + self._do_internal_statement_params_test(_picklable_test_function) + try: + function_pickled = cloudpickle.dumps(_picklable_test_function) + except Exception as e: + self.fail(f"Pickling failed with unexpected exception: {e}") + + function_unpickled = cloudpickle.loads(function_pickled) + self._do_internal_statement_params_test(function_unpickled) + + def test_statement_params_internal_query(self) -> None: + # Create and decorate a test function that calls some SQL query + @utils_telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + ) + def dummy_function(session: session.Session) -> None: + session.sql("SELECT 1").collect() # Intentionally omit statement_params arg + + self._do_internal_statement_params_test(dummy_function) + + def test_statement_params_nested_internal_query(self) -> None: + @utils_telemetry.send_api_usage_telemetry( + project="INNER_PROJECT", + subproject=_SUBPROJECT, + ) + def inner_function(session: session.Session) -> None: + session.sql("SELECT 1").collect() # Intentionally omit statement_params arg + + @utils_telemetry.send_api_usage_telemetry( + project="OUTER_PROJECT", + subproject=_SUBPROJECT, + ) + def outer_function(session: session.Session) -> None: + inner_function(session) + + self._do_internal_statement_params_test(outer_function, expected_params={"project": "OUTER_PROJECT"}) + + def test_statement_params_internal_params_precedence(self) -> None: + @utils_telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + ) + def project_override(session: session.Session) -> None: + session.sql("SELECT 1").collect(statement_params={"project": "MY_OVERRIDE"}) + + self._do_internal_statement_params_test( + project_override, + expected_params={"project": "MY_OVERRIDE", "snowml_telemetry_type": "SNOWML_AUGMENT_TELEMETRY"}, + ) + + @utils_telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + ) + def telemetry_type_override(session: session.Session) -> None: + session.sql("SELECT 1").collect(statement_params={"snowml_telemetry_type": "user override"}) + + self._do_internal_statement_params_test( + telemetry_type_override, + expected_params={"snowml_telemetry_type": "user override"}, + ) + + def test_statement_params_multithreading(self) -> None: + query1 = "select 1" + query2 = "select 2" + + @utils_telemetry.send_api_usage_telemetry(project="PROJECT_1") + def test_function1(session: session.Session) -> None: + time.sleep(0.1) + session.sql(query1).collect() + + @utils_telemetry.send_api_usage_telemetry(project="PROJECT_2") + def test_function2(session: session.Session) -> None: + session.sql(query2).collect() + + # Set up a real Session with mocking starting at SnowflakeConnection + # Do this manually instead of using _do_internal_statement_params_test + # to make sure we're sharing a single cursor so that we don't erroneously pass + # the test just because each thread is using their own cursor. + mock_cursor = absltest.mock.MagicMock(spec=cursor.SnowflakeCursor) + with mock.patch.object(self.mock_snowflake_conn, "cursor", return_value=mock_cursor): + server_conn = server_connection.ServerConnection({}, self.mock_snowflake_conn) + sess = session.Session(server_conn) + with mock.patch.object(session, "_get_active_sessions", return_value={sess}): + thread1 = threading.Thread(target=test_function1, args=(sess,)) + thread2 = threading.Thread(target=test_function2, args=(sess,)) + + thread1.start() + thread2.start() + + thread1.join() + thread2.join() + + self.assertEqual(2, len(mock_cursor.execute.call_args_list)) + statement_params_by_query = { + call[0][0]: call.kwargs.get("_statement_params", {}) for call in mock_cursor.execute.call_args_list + } + + default_params = {"source": "SnowML", "snowml_telemetry_type": "SNOWML_AUTO_TELEMETRY"} + self.assertDictContainsSubset({**default_params, "project": "PROJECT_1"}, statement_params_by_query[query1]) + self.assertDictContainsSubset({**default_params, "project": "PROJECT_2"}, statement_params_by_query[query2]) + + def test_statement_params_external_function(self) -> None: + # Create and decorate a test function that calls some SQL query + @utils_telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + ) + def dummy_function(session: session.Session) -> None: + session.sql("SELECT 1").collect() + + def external_function(session: session.Session) -> None: + session.sql("SELECT 2").collect() + + # Set up a real Session with mocking starting at SnowflakeConnection + # Do this manually instead of using _do_internal_statement_params_test + # to make sure we're sharing a single cursor so that we don't erroneously pass + # the test just because we're using a fresh session + mock_cursor = absltest.mock.MagicMock(spec=cursor.SnowflakeCursor) + with mock.patch.object(self.mock_snowflake_conn, "cursor", return_value=mock_cursor): + server_conn = server_connection.ServerConnection({}, self.mock_snowflake_conn) + sess = session.Session(server_conn) + with mock.patch.object(session, "_get_active_sessions", return_value={sess}): + dummy_function(sess) + external_function(sess) + + call_statement_params = [ + call.kwargs.get("_statement_params", {}) for call in mock_cursor.execute.call_args_list + ] + self.assertEqual(2, len(call_statement_params)) + self.assertIn("source", call_statement_params[0].keys()) + self.assertIn("snowml_telemetry_type", call_statement_params[0].keys()) + self.assertNotIn("source", call_statement_params[1].keys()) + self.assertNotIn("snowml_telemetry_type", call_statement_params[1].keys()) + + def _do_internal_statement_params_test( + self, func: Callable[[session.Session], None], expected_params: Optional[Dict[str, str]] = None + ) -> None: + # Set up a real Session with mocking starting at SnowflakeConnection + mock_cursor = absltest.mock.MagicMock(spec=cursor.SnowflakeCursor) + with mock.patch.object(self.mock_snowflake_conn, "cursor", return_value=mock_cursor): + server_conn = server_connection.ServerConnection({}, self.mock_snowflake_conn) + sess = session.Session(server_conn) + with mock.patch.object(session, "_get_active_sessions", return_value={sess}): + func(sess) + + # Validate that the mock cursor received statement params + mock_cursor.execute.assert_called_once() + statement_params = mock_cursor.execute.call_args.kwargs.get("_statement_params", None) + self.assertIsNotNone(statement_params, "statement params not found in execute call") + + expected_dict = {"source": "SnowML", "snowml_telemetry_type": "SNOWML_AUTO_TELEMETRY"} + expected_dict.update(expected_params or {}) + self.assertDictContainsSubset(expected_dict, statement_params) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/_internal/utils/BUILD.bazel b/snowflake/ml/_internal/utils/BUILD.bazel index b918ba59..e3cf9367 100644 --- a/snowflake/ml/_internal/utils/BUILD.bazel +++ b/snowflake/ml/_internal/utils/BUILD.bazel @@ -17,6 +17,25 @@ py_library( srcs = ["import_utils.py"], ) +py_library( + name = "db_utils", + srcs = ["db_utils.py"], + deps = [ + ":query_result_checker", + ":sql_identifier", + ], +) + +py_test( + name = "db_utils_test", + srcs = ["db_utils_test.py"], + deps = [ + ":db_utils", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) + py_test( name = "import_utils_test", srcs = ["import_utils_test.py"], @@ -161,6 +180,7 @@ py_library( srcs = [ "table_manager.py", "//snowflake/ml/_internal/utils:formatting", + "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:query_result_checker", ], ) diff --git a/snowflake/ml/_internal/utils/db_utils.py b/snowflake/ml/_internal/utils/db_utils.py new file mode 100644 index 00000000..0de2f4ac --- /dev/null +++ b/snowflake/ml/_internal/utils/db_utils.py @@ -0,0 +1,50 @@ +from enum import Enum +from typing import Any, Dict, Optional + +from snowflake.ml._internal.utils import query_result_checker, sql_identifier +from snowflake.snowpark import session + +MAX_IDENTIFIER_LENGTH = 255 + + +class SnowflakeDbObjectType(Enum): + TABLE = "TABLE" + WAREHOUSE = "WAREHOUSE" + + +def db_object_exists( + session: session.Session, + object_type: SnowflakeDbObjectType, + object_name: sql_identifier.SqlIdentifier, + *, + database_name: Optional[sql_identifier.SqlIdentifier] = None, + schema_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, +) -> bool: + """Check if object exists in database. + + Args: + session: Active Snowpark Session. + object_type: Type of object to search for. + object_name: Name of object to search for. + database_name: Optional database name to search in. Only used if both schema is also provided. + schema_name: Optional schema to search in. + statement_params: Optional set of statement_params to include with queries. + + Returns: + boolean indicating whether object exists. + """ + optional_in_clause = "" + if database_name and schema_name: + optional_in_clause = f" IN {database_name}.{schema_name}" + + result = ( + query_result_checker.SqlResultValidator( + session, + f"""SHOW {object_type.value}S LIKE '{object_name}'{optional_in_clause}""", + statement_params=statement_params, + ) + .has_column("name", allow_empty=True) # TODO: Check this is actually what is returned from server + .validate() + ) + return len(result) == 1 diff --git a/snowflake/ml/_internal/utils/db_utils_test.py b/snowflake/ml/_internal/utils/db_utils_test.py new file mode 100644 index 00000000..569c57b7 --- /dev/null +++ b/snowflake/ml/_internal/utils/db_utils_test.py @@ -0,0 +1,78 @@ +from typing import cast + +from absl.testing import absltest + +from snowflake.ml._internal.utils import db_utils, sql_identifier +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Row, Session + + +class DbUtilsTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.test_db_name = sql_identifier.SqlIdentifier("SNOWML_OBSERVABILITY") + self.test_schema_name = sql_identifier.SqlIdentifier("METADATA") + + self.session = cast(Session, self.m_session) + + def test_warehouse_exists(self) -> None: + test_wh_name = sql_identifier.SqlIdentifier("test_wh") + self.m_session.add_mock_sql( + query=f"""SHOW WAREHOUSES LIKE '{test_wh_name}'""", + result=mock_data_frame.MockDataFrame([Row(name=test_wh_name)]), + ) + self.assertTrue(db_utils.db_object_exists(self.session, db_utils.SnowflakeDbObjectType.WAREHOUSE, test_wh_name)) + self.m_session.finalize() + + def test_warehouse_not_exists(self) -> None: + test_wh_name = sql_identifier.SqlIdentifier("test_wh") + self.m_session.add_mock_sql( + query=f"""SHOW WAREHOUSES LIKE '{test_wh_name}'""", + result=mock_data_frame.MockDataFrame([]), + ) + self.assertFalse( + db_utils.db_object_exists(self.session, db_utils.SnowflakeDbObjectType.WAREHOUSE, test_wh_name) + ) + self.m_session.finalize() + + def test_table_exists(self) -> None: + test_tbl_name = sql_identifier.SqlIdentifier("test_tbl") + test_db = sql_identifier.SqlIdentifier("test_db") + test_schema = sql_identifier.SqlIdentifier("test_schema") + self.m_session.add_mock_sql( + query=f"""SHOW TABLES LIKE '{test_tbl_name}' IN {test_db}.{test_schema}""", + result=mock_data_frame.MockDataFrame([Row(name=test_tbl_name)]), + ) + self.assertTrue( + db_utils.db_object_exists( + self.session, + db_utils.SnowflakeDbObjectType.TABLE, + test_tbl_name, + database_name=test_db, + schema_name=test_schema, + ) + ) + self.m_session.finalize() + + def test_table_not_exists(self) -> None: + test_tbl_name = sql_identifier.SqlIdentifier("test_tbl") + test_db = sql_identifier.SqlIdentifier("test_db") + test_schema = sql_identifier.SqlIdentifier("test_schema") + self.m_session.add_mock_sql( + query=f"""SHOW TABLES LIKE '{test_tbl_name}' IN {test_db}.{test_schema}""", + result=mock_data_frame.MockDataFrame([]), + ) + self.assertFalse( + db_utils.db_object_exists( + self.session, + db_utils.SnowflakeDbObjectType.TABLE, + test_tbl_name, + database_name=test_db, + schema_name=test_schema, + ) + ) + self.m_session.finalize() + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/_internal/utils/identifier.py b/snowflake/ml/_internal/utils/identifier.py index a3b0ce3d..26b85ee3 100644 --- a/snowflake/ml/_internal/utils/identifier.py +++ b/snowflake/ml/_internal/utils/identifier.py @@ -10,9 +10,11 @@ _SF_IDENTIFIER = f"({_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER}|{SF_QUOTED_IDENTIFIER})" SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER) _SF_SCHEMA_LEVEL_OBJECT = ( - rf"(?:(?:(?P{_SF_IDENTIFIER})\.)?(?P{_SF_IDENTIFIER})\.)?(?P{_SF_IDENTIFIER})(?P.*)" + rf"(?:(?:(?P{_SF_IDENTIFIER})\.)?(?P{_SF_IDENTIFIER})\.)?(?P{_SF_IDENTIFIER})" ) +_SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P.*)" _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT) +_SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH) UNQUOTED_CASE_INSENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER})$") UNQUOTED_CASE_SENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_CASE_SENSITIVE_IDENTIFIER})$") @@ -139,29 +141,61 @@ def rename_to_valid_snowflake_identifier(name: str) -> str: def parse_schema_level_object_identifier( + object_name: str, +) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any]]: + """Parse a string which starts with schema level object. + + Args: + object_name: A string starts with a schema level object path, which is in the format + '..'. Here, '', '' and '' are all snowflake identifiers. + + Returns: + A tuple of 3 strings in the form of (db, schema, object_name). + + Raises: + ValueError: If the id is invalid. + """ + res = _SF_SCHEMA_LEVEL_OBJECT_RE.fullmatch(object_name) + if not res: + raise ValueError( + "Invalid identifier because it does not follow the pattern. " + f"It should start with [[database.]schema.]object. Getting {object_name}" + ) + return ( + res.group("db"), + res.group("schema"), + res.group("object"), + ) + + +def parse_snowflake_stage_path( path: str, ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]: - """Parse a string which starts with schema level object. + """Parse a string which represents a snowflake stage path. Args: - path: A string starts with a schema level object path, which is in the format '..'. - Here, '', '' and '' are all snowflake identifiers. + path: A string starts with a schema level object path, which is in the format + '..'. Here, '', '' and '' are all snowflake + identifiers. Returns: - A tuple of 4 strings in the form of (db, schema, object_name, others). 'db', 'schema', 'object_name' are parsed - from the schema level object and 'others' are all the content post to the object. + A tuple of 4 strings in the form of (db, schema, object_name, path). 'db', 'schema', 'object_name' are parsed + from the schema level object and 'path' are all the content post to the object. Raises: ValueError: If the id is invalid. """ - res = _SF_SCHEMA_LEVEL_OBJECT_RE.fullmatch(path) + res = _SF_STAGE_PATH_RE.fullmatch(path) if not res: - raise ValueError(f"Invalid identifier. It should start with database.schema.object. Getting {path}") + raise ValueError( + "Invalid identifier because it does not follow the pattern. " + f"It should start with [[database.]schema.]object. Getting {path}" + ) return ( res.group("db"), res.group("schema"), res.group("object"), - res.group("others"), + res.group("path"), ) @@ -175,8 +209,11 @@ def is_fully_qualified_name(name: str) -> bool: Returns: bool: True if the name is fully qualified, False otherwise. """ - res = parse_schema_level_object_identifier(name) - return res[0] is not None and res[1] is not None and res[2] is not None and not res[3] + try: + res = parse_schema_level_object_identifier(name) + return all(res) + except ValueError: + return False def get_schema_level_object_identifier( diff --git a/snowflake/ml/_internal/utils/identifier_test.py b/snowflake/ml/_internal/utils/identifier_test.py index 057ed92e..21f1bce2 100644 --- a/snowflake/ml/_internal/utils/identifier_test.py +++ b/snowflake/ml/_internal/utils/identifier_test.py @@ -5,6 +5,7 @@ SCHEMA_LEVEL_OBJECT_TEST_CASES = [ ("foo", False, None, None, "foo", ""), ("foo/", False, None, None, "foo", "/"), + ("foo-bar", False, None, None, "foo", "-bar"), ('"foo"', False, None, None, '"foo"', ""), ('"foo"/', False, None, None, '"foo"', "/"), ("foo/bar", False, None, None, "foo", "/bar"), @@ -107,15 +108,28 @@ def test_user_specified_quotes(self) -> None: self.assertEqual('"demo__task1"', identifier.concat_names(['"demo__"', "task1"])) self.assertEqual('"demo__task1"', identifier.concat_names(["demo__", '"task1"'])) - def test_parse_schema_level_object_identifier(self) -> None: + def test_parse_snowflake_stage_path(self) -> None: """Test if the schema level identifiers could be successfully parsed""" for test_case in SCHEMA_LEVEL_OBJECT_TEST_CASES: with self.subTest(): self.assertTupleEqual( - tuple(test_case[2:]), identifier.parse_schema_level_object_identifier(test_case[0]) + tuple(test_case[2:]), + identifier.parse_snowflake_stage_path(test_case[0]), ) + def test_parse_schema_level_object_identifier(self) -> None: + for test_case in SCHEMA_LEVEL_OBJECT_TEST_CASES: + with self.subTest(): + if test_case[5] != "": + with self.assertRaises(ValueError): + identifier.parse_schema_level_object_identifier(test_case[0]) + else: + self.assertTupleEqual( + tuple(test_case[2:5]), + identifier.parse_schema_level_object_identifier(test_case[0]), + ) + def test_get_schema_level_object_identifier(self) -> None: for test_case in SCHEMA_LEVEL_OBJECT_TEST_CASES: with self.subTest(): diff --git a/snowflake/ml/_internal/utils/snowflake_env.py b/snowflake/ml/_internal/utils/snowflake_env.py index 1dc41abe..d091f992 100644 --- a/snowflake/ml/_internal/utils/snowflake_env.py +++ b/snowflake/ml/_internal/utils/snowflake_env.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional, TypedDict, cast from packaging import version -from typing_extensions import Required +from typing_extensions import NotRequired, Required from snowflake.ml._internal.utils import query_result_checker from snowflake.snowpark import session @@ -52,7 +52,7 @@ def from_value(cls, value: str) -> "SnowflakeCloudType": class SnowflakeRegion(TypedDict): - region_group: Required[str] + region_group: NotRequired[str] snowflake_region: Required[str] cloud: Required[SnowflakeCloudType] region: Required[str] @@ -64,23 +64,33 @@ def get_regions( ) -> Dict[str, SnowflakeRegion]: res = ( query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params) - .has_column("region_group") .has_column("snowflake_region") .has_column("cloud") .has_column("region") .has_column("display_name") .validate() ) - return { - f"{r.region_group}.{r.snowflake_region}": SnowflakeRegion( - region_group=r.region_group, - snowflake_region=r.snowflake_region, - cloud=SnowflakeCloudType.from_value(r.cloud), - region=r.region, - display_name=r.display_name, - ) - for r in res - } + res_dict = {} + for r in res: + if hasattr(r, "region_group") and r.region_group: + key = f"{r.region_group}.{r.snowflake_region}" + res_dict[key] = SnowflakeRegion( + region_group=r.region_group, + snowflake_region=r.snowflake_region, + cloud=SnowflakeCloudType.from_value(r.cloud), + region=r.region, + display_name=r.display_name, + ) + else: + key = r.snowflake_region + res_dict[key] = SnowflakeRegion( + snowflake_region=r.snowflake_region, + cloud=SnowflakeCloudType.from_value(r.cloud), + region=r.region, + display_name=r.display_name, + ) + + return res_dict def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str: diff --git a/snowflake/ml/_internal/utils/snowflake_env_test.py b/snowflake/ml/_internal/utils/snowflake_env_test.py index 850cc21f..f03be90d 100644 --- a/snowflake/ml/_internal/utils/snowflake_env_test.py +++ b/snowflake/ml/_internal/utils/snowflake_env_test.py @@ -29,6 +29,12 @@ def test_get_regions(self) -> None: session = mock_session.MockSession(conn=None, test_case=self) query = "SHOW REGIONS" sql_result = [ + Row( + snowflake_region="AWS_US_WEST_1", + cloud="aws", + region="us-west-1", + display_name="US West (Oregon)", + ), Row( region_group="PUBLIC", snowflake_region="AWS_US_WEST_2", @@ -55,6 +61,12 @@ def test_get_regions(self) -> None: actual_result = snowflake_env.get_regions(cast(Session, session)) self.assertDictEqual( { + "AWS_US_WEST_1": snowflake_env.SnowflakeRegion( + snowflake_region="AWS_US_WEST_1", + cloud=snowflake_env.SnowflakeCloudType.AWS, + region="us-west-1", + display_name="US West (Oregon)", + ), "PUBLIC.AWS_US_WEST_2": snowflake_env.SnowflakeRegion( region_group="PUBLIC", snowflake_region="AWS_US_WEST_2", diff --git a/snowflake/ml/_internal/utils/sql_identifier.py b/snowflake/ml/_internal/utils/sql_identifier.py index 5ba4510f..9f9a6ff4 100644 --- a/snowflake/ml/_internal/utils/sql_identifier.py +++ b/snowflake/ml/_internal/utils/sql_identifier.py @@ -84,7 +84,7 @@ def to_sql_identifiers(list_of_str: List[str], *, case_sensitive: bool = False) def parse_fully_qualified_name( name: str, ) -> Tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]: - db, schema, object, _ = identifier.parse_schema_level_object_identifier(name) + db, schema, object = identifier.parse_schema_level_object_identifier(name) assert name is not None, f"Unable parse the input name `{name}` as fully qualified." return ( diff --git a/snowflake/ml/_internal/utils/sql_identifier_test.py b/snowflake/ml/_internal/utils/sql_identifier_test.py index c47a85f0..6dd3000d 100644 --- a/snowflake/ml/_internal/utils/sql_identifier_test.py +++ b/snowflake/ml/_internal/utils/sql_identifier_test.py @@ -59,6 +59,12 @@ def test_parse_fully_qualified_name(self) -> None: ), ) + with self.assertRaises(ValueError): + sql_identifier.parse_fully_qualified_name('db."schema".abc.def') + + with self.assertRaises(ValueError): + sql_identifier.parse_fully_qualified_name("abc-def") + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/_internal/utils/table_manager.py b/snowflake/ml/_internal/utils/table_manager.py index 59d97da5..265e6d44 100644 --- a/snowflake/ml/_internal/utils/table_manager.py +++ b/snowflake/ml/_internal/utils/table_manager.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional, Tuple from snowflake import snowpark -from snowflake.ml._internal.utils import formatting, query_result_checker +from snowflake.ml._internal.utils import formatting, identifier, query_result_checker +from snowflake.snowpark import types """Table_manager is a set of utils that helps create tables. @@ -104,3 +105,20 @@ def get_table_schema(session: snowpark.Session, table_name: str, qualified_schem for row in result: schema_dict[row["name"]] = row["type"] return schema_dict + + +def get_table_schema_types( + session: snowpark.Session, + database: str, + schema: str, + table_name: str, +) -> Dict[str, types.DataType]: + fully_qualified_table_name = identifier.get_schema_level_object_identifier( + db=database, schema=schema, object_name=table_name + ) + struct_fields: List[types.StructField] = session.table(fully_qualified_table_name).schema.fields + + schema_dict: Dict[str, types.DataType] = {} + for field in struct_fields: + schema_dict[field.name] = field.datatype + return schema_dict diff --git a/snowflake/ml/_internal/utils/uri.py b/snowflake/ml/_internal/utils/uri.py index fa891105..6b4794ae 100644 --- a/snowflake/ml/_internal/utils/uri.py +++ b/snowflake/ml/_internal/utils/uri.py @@ -53,7 +53,7 @@ def get_uri_scheme(uri: str) -> str: def get_uri_from_snowflake_stage_path(stage_path: str) -> str: """Generates a URI from Snowflake stage path.""" assert stage_path.startswith("@") - (db, schema, stage, path) = identifier.parse_schema_level_object_identifier( + (db, schema, stage, path) = identifier.parse_snowflake_stage_path( posixpath.normpath(identifier.remove_prefix(stage_path, "@")) ) return urlunparse( @@ -70,7 +70,7 @@ def get_uri_from_snowflake_stage_path(stage_path: str) -> str: def get_stage_and_path(stage_path: str) -> Tuple[str, str]: assert stage_path.startswith("@"), f"stage path should start with @, actual: {stage_path}" - (db, schema, stage, path) = identifier.parse_schema_level_object_identifier( + (db, schema, stage, path) = identifier.parse_snowflake_stage_path( posixpath.normpath(identifier.remove_prefix(stage_path, "@")) ) full_qualified_stage = "@" + identifier.get_schema_level_object_identifier(db, schema, stage) diff --git a/snowflake/ml/data/BUILD.bazel b/snowflake/ml/data/BUILD.bazel index e71ce8f9..5c5b2a8b 100644 --- a/snowflake/ml/data/BUILD.bazel +++ b/snowflake/ml/data/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") +load("//bazel:py_rules.bzl", "py_library", "py_package", "py_test") package(default_visibility = ["//visibility:public"]) @@ -25,8 +25,8 @@ py_library( ) py_library( - name = "torch_dataset", - srcs = ["torch_dataset.py"], + name = "torch_utils", + srcs = ["torch_utils.py"], ) py_library( @@ -34,9 +34,10 @@ py_library( srcs = ["data_connector.py"], deps = [ ":data_ingestor", - ":torch_dataset", + ":torch_utils", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/data/_internal:arrow_ingestor", + "//snowflake/ml/modeling/_internal:constants", ], ) @@ -57,3 +58,12 @@ py_library( ":data_source", ], ) + +py_package( + name = "data_pkg", + packages = ["snowflake.ml"], + deps = [ + ":data", + "//snowflake/ml/dataset", + ], +) diff --git a/snowflake/ml/data/data_connector.py b/snowflake/ml/data/data_connector.py index 777c515c..48730284 100644 --- a/snowflake/ml/data/data_connector.py +++ b/snowflake/ml/data/data_connector.py @@ -1,3 +1,4 @@ +import os from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar import numpy.typing as npt @@ -7,6 +8,10 @@ from snowflake.ml._internal import telemetry from snowflake.ml.data import data_ingestor, data_source from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor +from snowflake.ml.modeling._internal.constants import ( + IN_ML_RUNTIME_ENV_VAR, + USE_OPTIMIZED_DATA_INGESTOR, +) if TYPE_CHECKING: import pandas as pd @@ -142,32 +147,41 @@ def to_torch_datapipe( Returns: A Pytorch iterable datapipe that yield data. """ - from torch.utils.data.datapipes import iter as torch_iter + from snowflake.ml.data import torch_utils - return torch_iter.IterableWrapper( # type: ignore[no-untyped-call] - self._ingestor.to_batches(batch_size, shuffle, drop_last_batch) + return torch_utils.TorchDataPipeWrapper( + self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch ) @telemetry.send_api_usage_telemetry( project=_PROJECT, subproject_extractor=lambda self: type(self).__name__, - func_params_to_log=["shuffle"], + func_params_to_log=["batch_size", "shuffle", "drop_last_batch"], ) - def to_torch_dataset(self, *, shuffle: bool = False) -> "torch_data.IterableDataset": # type: ignore[type-arg] + def to_torch_dataset( + self, *, batch_size: int = 1, shuffle: bool = False, drop_last_batch: bool = True + ) -> "torch_data.IterableDataset": # type: ignore[type-arg] """Transform the Snowflake data into a PyTorch Iterable Dataset to be used with a DataLoader. Return a PyTorch Dataset which iterates on rows of data. Args: + batch_size: It specifies the size of each data batch which will be yielded in the result dataset. + Batching is pushed down to data ingestion level which may be more performant than DataLoader + batching. shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and rows in each file will also be shuffled. + drop_last_batch: Whether the last batch of data should be dropped. If set to be true, + then the last batch will get dropped if its size is smaller than the given batch_size. Returns: A PyTorch Iterable Dataset that yields data. """ - from snowflake.ml.data import torch_dataset + from snowflake.ml.data import torch_utils - return torch_dataset.TorchDataset(self._ingestor, shuffle) + return torch_utils.TorchDatasetWrapper( + self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch + ) @telemetry.send_api_usage_telemetry( project=_PROJECT, @@ -184,3 +198,15 @@ def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame": A Pandas DataFrame. """ return self._ingestor.to_pandas(limit) + + +# Switch to use Runtime's Data Ingester if running in ML runtime +# Fail silently if the data ingester is not found +if os.getenv(IN_ML_RUNTIME_ENV_VAR) and os.getenv(USE_OPTIMIZED_DATA_INGESTOR): + try: + from runtime_external_entities import get_ingester_class + + DataConnector.DEFAULT_INGESTOR_CLASS = get_ingester_class() + except ImportError: + """Runtime Default Ingester not found, ignore""" + pass diff --git a/snowflake/ml/data/data_connector_test.py b/snowflake/ml/data/data_connector_test.py index ddc6809f..ec63b434 100644 --- a/snowflake/ml/data/data_connector_test.py +++ b/snowflake/ml/data/data_connector_test.py @@ -45,7 +45,51 @@ def test_to_torch_datapipe(self) -> None: if col != "col3": self.assertIsInstance(tensor, torch.Tensor) - def test_to_torch_dataset(self) -> None: + # Ensure iterating through a second time (e.g. second epoch) works + count2 = 0 + for batch in dl: + np.testing.assert_array_equal(batch["col1"].numpy(), expected_res[count2]["col1"]) + np.testing.assert_array_equal(batch["col2"].numpy(), expected_res[count2]["col2"]) + np.testing.assert_array_equal(batch["col3"], expected_res[count2]["col3"]) + count2 += 1 + self.assertEqual(count2, len(expected_res)) + + def test_to_torch_datapipe_multiprocessing(self) -> None: + dp = self._sut.to_torch_datapipe(batch_size=2, shuffle=False, drop_last_batch=True) + + # FIXME: This test runs pretty slowly, probably due to multiprocessing overhead + # Make sure dataset works with num_workers > 0 (and doesn't duplicate data) + self.assertEqual( + len(list(torch_data.DataLoader(dp, batch_size=None, num_workers=2))), + 3, + ) + + def test_to_torch_dataset_native_batch(self) -> None: + expected_res = [ + {"col1": np.array([0, 1]), "col2": np.array([10, 11]), "col3": ["a", "ab"]}, + {"col1": np.array([2, 3]), "col2": np.array([12, 13]), "col3": ["abc", "m"]}, + {"col1": np.array([4, 5]), "col2": np.array([14, np.NaN]), "col3": ["mn", "mnm"]}, + ] + ds = self._sut.to_torch_dataset(batch_size=2, shuffle=False, drop_last_batch=True) + count = 0 + loader = torch_data.DataLoader(ds, batch_size=None) + for batch in loader: + np.testing.assert_array_equal(batch["col1"], expected_res[count]["col1"]) # type: ignore[arg-type] + np.testing.assert_array_equal(batch["col2"], expected_res[count]["col2"]) # type: ignore[arg-type] + np.testing.assert_array_equal(batch["col3"], expected_res[count]["col3"]) # type: ignore[arg-type] + count += 1 + self.assertEqual(count, len(expected_res)) + + # Ensure iterating through a second time (e.g. second epoch) works + count2 = 0 + for batch in loader: + np.testing.assert_array_equal(batch["col1"].numpy(), expected_res[count2]["col1"]) # type: ignore[arg-type] + np.testing.assert_array_equal(batch["col2"].numpy(), expected_res[count2]["col2"]) # type: ignore[arg-type] + np.testing.assert_array_equal(batch["col3"], expected_res[count2]["col3"]) # type: ignore[arg-type] + count2 += 1 + self.assertEqual(count2, len(expected_res)) + + def test_to_torch_dataset_loader_batch(self) -> None: expected_res = [ {"col1": np.array([0, 1]), "col2": np.array([10, 11]), "col3": ["a", "ab"]}, {"col1": np.array([2, 3]), "col2": np.array([12, 13]), "col3": ["abc", "m"]}, @@ -53,13 +97,23 @@ def test_to_torch_dataset(self) -> None: ] ds = self._sut.to_torch_dataset(shuffle=False) count = 0 - for batch in torch_data.DataLoader(ds, batch_size=2, shuffle=False, drop_last=True): + loader = torch_data.DataLoader(ds, batch_size=2, shuffle=False, drop_last=True) + for batch in loader: np.testing.assert_array_equal(batch["col1"], expected_res[count]["col1"]) # type: ignore[arg-type] np.testing.assert_array_equal(batch["col2"], expected_res[count]["col2"]) # type: ignore[arg-type] np.testing.assert_array_equal(batch["col3"], expected_res[count]["col3"]) # type: ignore[arg-type] count += 1 self.assertEqual(count, len(expected_res)) + # Ensure iterating through a second time (e.g. second epoch) works + count2 = 0 + for batch in loader: + np.testing.assert_array_equal(batch["col1"].numpy(), expected_res[count2]["col1"]) # type: ignore[arg-type] + np.testing.assert_array_equal(batch["col2"].numpy(), expected_res[count2]["col2"]) # type: ignore[arg-type] + np.testing.assert_array_equal(batch["col3"], expected_res[count2]["col3"]) # type: ignore[arg-type] + count2 += 1 + self.assertEqual(count2, len(expected_res)) + def test_to_torch_dataset_multiprocessing(self) -> None: ds = self._sut.to_torch_dataset(shuffle=False) @@ -89,6 +143,15 @@ def test_to_tf_dataset(self) -> None: count += 1 self.assertEqual(count, len(expected_res)) + # Ensure iterating through a second time (e.g. second epoch) works + count2 = 0 + for batch in dp: + np.testing.assert_array_equal(batch["col1"].numpy(), expected_res[count2]["col1"]) + np.testing.assert_array_equal(batch["col2"].numpy(), expected_res[count2]["col2"]) + np.testing.assert_array_equal(batch["col3"].numpy(), expected_res[count2]["col3"]) + count2 += 1 + self.assertEqual(count2, len(expected_res)) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/data/torch_dataset.py b/snowflake/ml/data/torch_dataset.py deleted file mode 100644 index bc11849f..00000000 --- a/snowflake/ml/data/torch_dataset.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any, Dict, Iterator - -import torch.utils.data - -from snowflake.ml.data import data_ingestor - - -class TorchDataset(torch.utils.data.IterableDataset[Dict[str, Any]]): - """Implementation of PyTorch IterableDataset""" - - def __init__(self, ingestor: data_ingestor.DataIngestor, shuffle: bool = False) -> None: - """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead""" - self._ingestor = ingestor - self._shuffle = shuffle - - def __iter__(self) -> Iterator[Dict[str, Any]]: - max_idx = 0 - filter_idx = 0 - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - max_idx = worker_info.num_workers - 1 - filter_idx = worker_info.id - - counter = 0 - for batch in self._ingestor.to_batches(batch_size=1, shuffle=self._shuffle, drop_last_batch=False): - # Skip indices during multi-process data loading to prevent data duplication - if counter == filter_idx: - yield {k: v.item() for k, v in batch.items()} - - if counter < max_idx: - counter += 1 - else: - counter = 0 diff --git a/snowflake/ml/data/torch_utils.py b/snowflake/ml/data/torch_utils.py new file mode 100644 index 00000000..25dc6d17 --- /dev/null +++ b/snowflake/ml/data/torch_utils.py @@ -0,0 +1,68 @@ +from typing import Any, Dict, Iterator, List, Union + +import numpy as np +import numpy.typing as npt +import torch.utils.data + +from snowflake.ml.data import data_ingestor + + +class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]): + """Wrap a DataIngestor into a PyTorch IterableDataset""" + + def __init__( + self, + ingestor: data_ingestor.DataIngestor, + *, + batch_size: int, + shuffle: bool = False, + drop_last: bool = False, + squeeze_outputs: bool = True + ) -> None: + """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead""" + self._ingestor = ingestor + self._batch_size = batch_size + self._shuffle = shuffle + self._drop_last = drop_last + self._squeeze_outputs = squeeze_outputs + + def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]: + max_idx = 0 + filter_idx = 0 + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + max_idx = worker_info.num_workers - 1 + filter_idx = worker_info.id + + if self._shuffle and worker_info is not None: + raise RuntimeError("Dataset shuffling not currently supported with multithreading") + + counter = 0 + for batch in self._ingestor.to_batches( + batch_size=self._batch_size, shuffle=self._shuffle, drop_last_batch=self._drop_last + ): + # Skip indices during multi-process data loading to prevent data duplication + if counter == filter_idx: + # Basic preprocessing on batch values: squeeze away extra dimensions + # and convert object arrays (e.g. strings) to lists + if self._squeeze_outputs: + yield { + k: (v.squeeze().tolist() if v.dtype == np.object_ else v.squeeze()) for k, v in batch.items() + } + else: + yield batch # type: ignore[misc] + + if counter < max_idx: + counter += 1 + else: + counter = 0 + + +class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Dict[str, Any]]): + """Wrap a DataIngestor into a PyTorch IterDataPipe""" + + def __init__( + self, ingestor: data_ingestor.DataIngestor, *, batch_size: int, shuffle: bool = False, drop_last: bool = False + ) -> None: + """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead""" + super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, squeeze_outputs=False) diff --git a/snowflake/ml/dataset/dataset.py b/snowflake/ml/dataset/dataset.py index aa3914a4..400f561e 100644 --- a/snowflake/ml/dataset/dataset.py +++ b/snowflake/ml/dataset/dataset.py @@ -472,9 +472,7 @@ def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> Tuple[str, str, str]: """Resolve a dataset name into a validated schema-level location identifier""" - db, schema, object_name, others = identifier.parse_schema_level_object_identifier(dataset_name) - if others: - raise ValueError(f"Invalid identifier: unexpected '{others}'") + db, schema, object_name = identifier.parse_schema_level_object_identifier(dataset_name) db = db or session.get_current_database() schema = schema or session.get_current_schema() return str(db), str(schema), str(object_name) diff --git a/snowflake/ml/feature_store/feature_store.py b/snowflake/ml/feature_store/feature_store.py index 4b4ff2dc..f9018c9c 100644 --- a/snowflake/ml/feature_store/feature_store.py +++ b/snowflake/ml/feature_store/feature_store.py @@ -604,7 +604,7 @@ def create_col_desc(col: StructField) -> str: logger.info(f"Registered FeatureView {feature_view.name}/{version} successfully.") return self.get_feature_view(feature_view.name, str(version)) - @dispatch_decorator() + @overload def update_feature_view( self, name: str, @@ -613,13 +613,37 @@ def update_feature_view( refresh_freq: Optional[str] = None, warehouse: Optional[str] = None, desc: Optional[str] = None, + ) -> FeatureView: + ... + + @overload + def update_feature_view( + self, + name: FeatureView, + version: Optional[str] = None, + *, + refresh_freq: Optional[str] = None, + warehouse: Optional[str] = None, + desc: Optional[str] = None, + ) -> FeatureView: + ... + + @dispatch_decorator() # type: ignore[misc] + def update_feature_view( + self, + name: Union[FeatureView, str], + version: Optional[str] = None, + *, + refresh_freq: Optional[str] = None, + warehouse: Optional[str] = None, + desc: Optional[str] = None, ) -> FeatureView: """Update a registered feature view. Check feature_view.py for which fields are allowed to be updated after registration. Args: - name: name of the FeatureView to be updated. - version: version of the FeatureView to be updated. + name: FeatureView object or name to suspend. + version: Optional version of feature view. Must set when argument feature_view is a str. refresh_freq: updated refresh frequency. warehouse: updated warehouse. desc: description of feature view. @@ -661,7 +685,7 @@ def update_feature_view( SnowflakeMLException: [RuntimeError] If FeatureView is not managed and refresh_freq is defined. SnowflakeMLException: [RuntimeError] Failed to update feature view. """ - feature_view = self.get_feature_view(name=name, version=version) + feature_view = self._validate_feature_view_name_and_version_input(name, version) new_desc = desc if desc is not None else feature_view.desc if feature_view.status == FeatureViewStatus.STATIC: @@ -696,7 +720,7 @@ def update_feature_view( f"Update feature view {feature_view.name}/{feature_view.version} failed: {e}" ), ) from e - return self.get_feature_view(name=name, version=version) + return self.get_feature_view(name=feature_view.name, version=str(feature_view.version)) @overload def read_feature_view(self, feature_view: str, version: str) -> DataFrame: @@ -2121,7 +2145,7 @@ def _get_fully_qualified_name(self, name: Union[SqlIdentifier, str]) -> str: if "." not in name: return f"{self._config.full_schema_path}.{name}" - db_name, schema_name, object_name, _ = identifier.parse_schema_level_object_identifier(name) + db_name, schema_name, object_name = identifier.parse_schema_level_object_identifier(name) return "{}.{}.{}".format( db_name or self._config.database, schema_name or self._config.schema, @@ -2186,11 +2210,7 @@ def _optimized_find_feature_views( if len(fv_maps.keys()) == 0: return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA) - filters = ( - [lambda d: d["entityName"].startswith(feature_view_name.resolved())] # type: ignore[union-attr] - if feature_view_name - else None - ) + filters = [lambda d: d["entityName"].startswith(feature_view_name.resolved())] if feature_view_name else None res = self._lookup_tagged_objects(self._get_entity_name(entity_name), filters) output_values: List[List[Any]] = [] @@ -2281,16 +2301,20 @@ def find_and_compose_entity(name: str) -> Entity: timestamp_col=timestamp_col, desc=desc, version=version, - status=FeatureViewStatus(row["scheduling_state"]) - if len(row["scheduling_state"]) > 0 - else FeatureViewStatus.MASKED, + status=( + FeatureViewStatus(row["scheduling_state"]) + if len(row["scheduling_state"]) > 0 + else FeatureViewStatus.MASKED + ), feature_descs=self._fetch_column_descs("DYNAMIC TABLE", fv_name), refresh_freq=row["target_lag"], database=self._config.database.identifier(), schema=self._config.schema.identifier(), - warehouse=SqlIdentifier(row["warehouse"], case_sensitive=True).identifier() - if len(row["warehouse"]) > 0 - else None, + warehouse=( + SqlIdentifier(row["warehouse"], case_sensitive=True).identifier() + if len(row["warehouse"]) > 0 + else None + ), refresh_mode=row["refresh_mode"], refresh_mode_reason=row["refresh_mode_reason"], owner=row["owner"], diff --git a/snowflake/ml/feature_store/feature_view.py b/snowflake/ml/feature_store/feature_view.py index d697f340..021ecd8b 100644 --- a/snowflake/ml/feature_store/feature_view.py +++ b/snowflake/ml/feature_store/feature_view.py @@ -706,7 +706,7 @@ def to_df(self, session: Optional[Session] = None) -> DataFrame: >>> ).attach_feature_desc({"AGE": "my age", "TITLE": '"my title"'}) >>> fv = fs.register_feature_view(draft_fv, '1.0') - fv.to_df().show() + >>> fv.to_df().show() ----------------------------------------------------------------... |"NAME" |"ENTITIES" |"TIMESTAMP_COL" |"DESC" | ----------------------------------------------------------------... @@ -801,7 +801,7 @@ def _load_from_compact_repr(session: Session, serialized_repr: str) -> Union[Fea @staticmethod def _load_from_lineage_node(session: Session, name: str, version: str) -> FeatureView: - db_name, feature_store_name, feature_view_name, _ = identifier.parse_schema_level_object_identifier(name) + db_name, feature_store_name, feature_view_name = identifier.parse_schema_level_object_identifier(name) session_warehouse = session.get_current_warehouse() diff --git a/snowflake/ml/fileset/embedded_stage_fs.py b/snowflake/ml/fileset/embedded_stage_fs.py index dc05f896..4826b664 100644 --- a/snowflake/ml/fileset/embedded_stage_fs.py +++ b/snowflake/ml/fileset/embedded_stage_fs.py @@ -35,7 +35,7 @@ def __init__( **kwargs: Any, ) -> None: - (db, schema, object_name, _) = identifier.parse_schema_level_object_identifier(name) + (db, schema, object_name) = identifier.parse_schema_level_object_identifier(name) self._name = name # TODO: Require or resolve FQN self._domain = domain diff --git a/snowflake/ml/fileset/fileset.py b/snowflake/ml/fileset/fileset.py index a0619ecf..ef337b9d 100644 --- a/snowflake/ml/fileset/fileset.py +++ b/snowflake/ml/fileset/fileset.py @@ -538,7 +538,7 @@ def _validate_target_stage_loc(snowpark_session: snowpark.Session, target_stage_ original_exception=fileset_errors.FileSetLocationError('FileSet location should start with "@".'), ) try: - db, schema, stage, _ = identifier.parse_schema_level_object_identifier(target_stage_loc[1:]) + db, schema, stage, _ = identifier.parse_snowflake_stage_path(target_stage_loc[1:]) if db is None or schema is None: raise ValueError("The stage path should be in the form '@../*'") df_stages = snowpark_session.sql(f"Show stages like '{stage}' in SCHEMA {db}.{schema}") diff --git a/snowflake/ml/fileset/sfcfs.py b/snowflake/ml/fileset/sfcfs.py index d242a6d3..f83a2757 100644 --- a/snowflake/ml/fileset/sfcfs.py +++ b/snowflake/ml/fileset/sfcfs.py @@ -15,6 +15,7 @@ from snowflake.ml._internal.utils import identifier from snowflake.ml.fileset import stage_fs from snowflake.ml.utils import connection_params +from snowflake.snowpark import context, exceptions as snowpark_exceptions PROTOCOL_NAME = "sfc" @@ -84,7 +85,7 @@ def __init__( """ if kwargs.get(_RECREATE_FROM_SERIALIZED): try: - snowpark_session = self._create_default_session() + snowpark_session = self._get_default_session() except Exception as e: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.SNOWML_DESERIALIZATION_FAILED, @@ -103,7 +104,7 @@ def __init__( super().__init__(**kwargs) - def _create_default_session(self) -> snowpark.Session: + def _get_default_session(self) -> snowpark.Session: """Create a Snowpark Session from default login options. Returns: @@ -114,6 +115,11 @@ def _create_default_session(self) -> snowpark.Session: ValueError: Snowflake Connection could not be created. """ + try: + return context.get_active_session() + except snowpark_exceptions.SnowparkSessionException: + pass + try: snowflake_config = connection_params.SnowflakeLoginOptions() except Exception as e: @@ -328,7 +334,7 @@ def _parse_file_path(cls, path: str) -> _SFFilePath: ), ) try: - res = identifier.parse_schema_level_object_identifier(path[1:]) + res = identifier.parse_snowflake_stage_path(path[1:]) if res[1] is None or res[0] is None or (res[3] and not res[3].startswith("/")): raise ValueError("Invalid path. Missing database or schema identifier.") logging.debug(f"Parsed path: {res}") diff --git a/snowflake/ml/fileset/sfcfs_test.py b/snowflake/ml/fileset/sfcfs_test.py index 2e0b218e..afe6a51a 100644 --- a/snowflake/ml/fileset/sfcfs_test.py +++ b/snowflake/ml/fileset/sfcfs_test.py @@ -244,7 +244,7 @@ def test_fs_serializability(self) -> None: assert sffs_deserialized._conn is not None assert sffs_deserialized._kwargs == kwargs_dict - def test_create_default_session_exceptions(self) -> None: + def test_get_default_session_exceptions(self) -> None: """Tests that correct exceptions are raised when the function fails to create a session. Mocks the two session creation functions called by _create_default_connection individually. """ @@ -254,13 +254,13 @@ def test_create_default_session_exceptions(self) -> None: "snowflake.ml.fileset.sfcfs.connection_params.SnowflakeLoginOptions", side_effect=Exception("Error message"), ): - sffs._create_default_session() + sffs._get_default_session() with self.assertRaises(ValueError): with absltest.mock.patch( "snowflake.snowpark.Session.SessionBuilder.create", side_effect=Exception("Error message") ): - sffs._create_default_session() + sffs._get_default_session() def test_set_state_bad_state_dict(self) -> None: """When deserializing, the state dictionary requires a kwargs key that corresponds to a dictionary.""" diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index e8ae6c93..97cc93e0 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -306,6 +306,23 @@ def _get_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]: statement_params=statement_params, ) + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def get_model_objective(self) -> model_types.ModelObjective: + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + return self._model_ops.get_model_objective( + database_name=None, + schema_name=None, + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, @@ -606,8 +623,8 @@ def _load_from_lineage_node(session: Session, name: str, version: str) -> "Model "image_repo_database", "image_repo_schema", "image_repo", - "image_name", "gpu_requests", + "num_workers", ], ) def create_service( @@ -617,11 +634,10 @@ def create_service( image_build_compute_pool: Optional[str] = None, service_compute_pool: str, image_repo: str, - image_name: Optional[str] = None, ingress_enabled: bool = False, - min_instances: int = 1, max_instances: int = 1, gpu_requests: Optional[str] = None, + num_workers: Optional[int] = None, force_rebuild: bool = False, build_external_access_integration: str, ) -> str: @@ -635,12 +651,12 @@ def create_service( service_compute_pool: The name of the compute pool used to run the inference service. image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database or schema of the model will be used. - image_name: The name of the model inference image. Use a generated name if None. ingress_enabled: Whether to enable ingress. - min_instances: The minimum number of inference service instances to run. max_instances: The maximum number of inference service instances to run. gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU if None. + num_workers: The number of workers (replicas of models) to run the inference service. + Auto determined if None. force_rebuild: Whether to force a model inference image rebuild. build_external_access_integration: The external access integration for image build. @@ -670,11 +686,10 @@ def create_service( image_repo_database_name=image_repo_db_id, image_repo_schema_name=image_repo_schema_id, image_repo_name=image_repo_id, - image_name=sql_identifier.SqlIdentifier(image_name) if image_name else None, ingress_enabled=ingress_enabled, - min_instances=min_instances, max_instances=max_instances, gpu_requests=gpu_requests, + num_workers=num_workers, force_rebuild=force_rebuild, build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration), statement_params=statement_params, diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index 9b4e449e..e17c588f 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -223,6 +223,21 @@ def test_get_functions(self) -> None: statement_params=mock.ANY, ) + def test_get_model_objective(self) -> None: + with mock.patch.object( + self.m_mv._model_ops, + attribute="get_model_objective", + return_value=model_types.ModelObjective.REGRESSION, + ) as mock_get_model_objective: + self.assertEqual(model_types.ModelObjective.REGRESSION, self.m_mv.get_model_objective()) + mock_get_model_objective.assert_called_once_with( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + def test_run(self) -> None: m_df = mock_data_frame.MockDataFrame() m_methods = [ @@ -714,10 +729,9 @@ def test_create_service(self) -> None: image_build_compute_pool="IMAGE_BUILD_COMPUTE_POOL", service_compute_pool="SERVICE_COMPUTE_POOL", image_repo="IMAGE_REPO", - image_name="IMAGE_NAME", - min_instances=2, max_instances=3, gpu_requests="GPU", + num_workers=1, force_rebuild=True, build_external_access_integration="EAI", ) @@ -734,11 +748,44 @@ def test_create_service(self) -> None: image_repo_database_name=None, image_repo_schema_name=None, image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), - image_name=sql_identifier.SqlIdentifier("IMAGE_NAME"), ingress_enabled=False, - min_instances=2, max_instances=3, - gpu_requests=sql_identifier.SqlIdentifier("GPU"), + gpu_requests="GPU", + num_workers=1, + force_rebuild=True, + build_external_access_integration=sql_identifier.SqlIdentifier("EAI"), + statement_params=mock.ANY, + ) + + def test_create_service_same_pool(self) -> None: + with mock.patch.object(self.m_mv._service_ops, "create_service") as mock_create_service: + self.m_mv.create_service( + service_name="SERVICE", + service_compute_pool="SERVICE_COMPUTE_POOL", + image_repo="IMAGE_REPO", + max_instances=3, + gpu_requests="GPU", + num_workers=1, + force_rebuild=True, + build_external_access_integration="EAI", + ) + mock_create_service.assert_called_once_with( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier(self.m_mv.model_name), + version_name=sql_identifier.SqlIdentifier(self.m_mv.version_name), + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("SERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=None, + image_repo_schema_name=None, + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=False, + max_instances=3, + gpu_requests="GPU", + num_workers=1, force_rebuild=True, build_external_access_integration=sql_identifier.SqlIdentifier("EAI"), statement_params=mock.ANY, diff --git a/snowflake/ml/model/_client/ops/BUILD.bazel b/snowflake/ml/model/_client/ops/BUILD.bazel index af02b5e3..fef6e76c 100644 --- a/snowflake/ml/model/_client/ops/BUILD.bazel +++ b/snowflake/ml/model/_client/ops/BUILD.bazel @@ -75,3 +75,17 @@ py_library( "//snowflake/ml/model/_client/sql:stage", ], ) + +py_test( + name = "service_ops_test", + srcs = ["service_ops_test.py"], + deps = [ + ":service_ops", + "//snowflake/ml/_internal:file_utils", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_client/service:model_deployment_spec", + "//snowflake/ml/model/_client/sql:service", + "//snowflake/ml/model/_client/sql:stage", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index 67357b52..a5c0aee8 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -554,15 +554,14 @@ def _match_model_spec_with_sql_functions( res[function_name] = target_method return res - def get_functions( + def _fetch_model_spec( self, - *, database_name: Optional[sql_identifier.SqlIdentifier], schema_name: Optional[sql_identifier.SqlIdentifier], model_name: sql_identifier.SqlIdentifier, version_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, - ) -> List[model_manifest_schema.ModelFunctionInfo]: + ) -> model_meta_schema.ModelMetadataDict: raw_model_spec_res = self._model_client.show_versions( database_name=database_name, schema_name=schema_name, @@ -573,6 +572,43 @@ def get_functions( )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME] model_spec_dict = yaml.safe_load(raw_model_spec_res) model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict) + return model_spec + + def get_model_objective( + self, + *, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> type_hints.ModelObjective: + model_spec = self._fetch_model_spec( + database_name=database_name, + schema_name=schema_name, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + model_objective_val = model_spec.get("model_objective", type_hints.ModelObjective.UNKNOWN.value) + return type_hints.ModelObjective(model_objective_val) + + def get_functions( + self, + *, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[model_manifest_schema.ModelFunctionInfo]: + model_spec = self._fetch_model_spec( + database_name=database_name, + schema_name=schema_name, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) show_functions_res = self._model_version_client.show_functions( database_name=database_name, schema_name=schema_name, diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index c6ad955b..cfb9b0de 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -8,7 +8,7 @@ from absl.testing import absltest from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import model_signature +from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._client.ops import model_ops from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema @@ -1349,6 +1349,77 @@ def test_get_functions(self) -> None: ) mock_validate_model_metadata.assert_called_once_with(m_spec) + def test_get_model_objective(self) -> None: + m_spec = { + "signatures": { + "predict": _DUMMY_SIG["predict"].to_dict(), + "predict_table": _DUMMY_SIG["predict_table"].to_dict(), + }, + "model_objective": "binary_classification", + } + m_show_versions_result = [Row(model_spec=yaml.safe_dump(m_spec))] + with mock.patch.object( + self.m_ops._model_client, + "show_versions", + return_value=m_show_versions_result, + ) as mock_show_versions, mock.patch.object( + model_meta.ModelMetadata, + "_validate_model_metadata", + return_value=cast(model_meta_schema.ModelMetadataDict, m_spec), + ) as mock_validate_model_metadata: + res = self.m_ops.get_model_objective( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + statement_params=self.m_statement_params, + ) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + check_model_details=True, + statement_params={**self.m_statement_params, "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True}, + ) + mock_validate_model_metadata.assert_called_once_with(m_spec) + self.assertEqual(res, type_hints.ModelObjective.BINARY_CLASSIFICATION) + + def test_get_model_objective_empty(self) -> None: + m_spec = { + "signatures": { + "predict": _DUMMY_SIG["predict"].to_dict(), + "predict_table": _DUMMY_SIG["predict_table"].to_dict(), + } + } + m_show_versions_result = [Row(model_spec=yaml.safe_dump(m_spec))] + with mock.patch.object( + self.m_ops._model_client, + "show_versions", + return_value=m_show_versions_result, + ) as mock_show_versions, mock.patch.object( + model_meta.ModelMetadata, + "_validate_model_metadata", + return_value=cast(model_meta_schema.ModelMetadataDict, m_spec), + ) as mock_validate_model_metadata: + res = self.m_ops.get_model_objective( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + statement_params=self.m_statement_params, + ) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + check_model_details=True, + statement_params={**self.m_statement_params, "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True}, + ) + mock_validate_model_metadata.assert_called_once_with(m_spec) + self.assertEqual(res, type_hints.ModelObjective.UNKNOWN) + def test_download_files_minimal(self) -> None: m_list_files_res = [ [Row(name="versions/v1/model/model.yaml", size=419, md5="1234", last_modified="")], diff --git a/snowflake/ml/model/_client/ops/service_ops.py b/snowflake/ml/model/_client/ops/service_ops.py index 1010e575..84f80152 100644 --- a/snowflake/ml/model/_client/ops/service_ops.py +++ b/snowflake/ml/model/_client/ops/service_ops.py @@ -1,15 +1,45 @@ +import dataclasses +import hashlib +import logging import pathlib +import queue +import sys import tempfile -from typing import Any, Dict, Optional +import threading +import time +import uuid +from typing import Any, Dict, List, Optional, Tuple, cast +from snowflake import snowpark from snowflake.ml._internal import file_utils from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model._client.service import model_deployment_spec from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql -from snowflake.snowpark import session +from snowflake.snowpark import exceptions, row, session from snowflake.snowpark._internal import utils as snowpark_utils +def get_logger(logger_name: str) -> logging.Logger: + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.INFO) + handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s")) + logger.addHandler(handler) + return logger + + +logger = get_logger(__name__) +logger.propagate = False + + +@dataclasses.dataclass +class ServiceLogInfo: + service_name: str + container_name: str + instance_id: str = "0" + + class ServiceOperator: """Service operator for container services logic.""" @@ -62,11 +92,10 @@ def create_service( image_repo_database_name: Optional[sql_identifier.SqlIdentifier], image_repo_schema_name: Optional[sql_identifier.SqlIdentifier], image_repo_name: sql_identifier.SqlIdentifier, - image_name: Optional[sql_identifier.SqlIdentifier], ingress_enabled: bool, - min_instances: int, max_instances: int, gpu_requests: Optional[str], + num_workers: Optional[int], force_rebuild: bool, build_external_access_integration: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, @@ -96,11 +125,10 @@ def create_service( image_repo_database_name=image_repo_database_name, image_repo_schema_name=image_repo_schema_name, image_repo_name=image_repo_name, - image_name=image_name, ingress_enabled=ingress_enabled, - min_instances=min_instances, max_instances=max_instances, gpu=gpu_requests, + num_workers=num_workers, force_rebuild=force_rebuild, external_access_integration=build_external_access_integration, ) @@ -111,11 +139,174 @@ def create_service( statement_params=statement_params, ) + # check if the inference service is already running + try: + model_inference_service_status, _ = self._service_client.get_service_status( + service_name=service_name, + include_message=False, + statement_params=statement_params, + ) + model_inference_service_exists = model_inference_service_status == service_sql.ServiceStatus.READY + except exceptions.SnowparkSQLException: + model_inference_service_exists = False + # deploy the model service - self._service_client.deploy_model( + query_id, async_job = self._service_client.deploy_model( stage_path=stage_path, model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH, statement_params=statement_params, ) + # stream service logs in a thread + services = [ + ServiceLogInfo(service_name=self._get_model_build_service_name(query_id), container_name="model-build"), + ServiceLogInfo(service_name=service_name, container_name="model-inference"), + ] + exception_queue: queue.Queue = queue.Queue() # type: ignore[type-arg] + log_thread = self._start_service_log_streaming( + async_job, services, model_inference_service_exists, exception_queue, statement_params + ) + log_thread.join() + + try: + # non-blocking check for an exception + exception = exception_queue.get(block=False) + if exception: + raise exception + except queue.Empty: + pass + return service_name + + def _start_service_log_streaming( + self, + async_job: snowpark.AsyncJob, + services: List[ServiceLogInfo], + model_inference_service_exists: bool, + exception_queue: queue.Queue, # type: ignore[type-arg] + statement_params: Optional[Dict[str, Any]] = None, + ) -> threading.Thread: + """Start the service log streaming in a separate thread.""" + log_thread = threading.Thread( + target=self._stream_service_logs, + args=(async_job, services, model_inference_service_exists, exception_queue, statement_params), + ) + log_thread.start() + return log_thread + + def _stream_service_logs( + self, + async_job: snowpark.AsyncJob, + services: List[ServiceLogInfo], + model_inference_service_exists: bool, + exception_queue: queue.Queue, # type: ignore[type-arg] + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + """Stream service logs while the async job is running.""" + + def fetch_logs(service_name: str, container_name: str, offset: int) -> Tuple[str, int]: + service_logs = self._service_client.get_service_logs( + service_name=service_name, + container_name=container_name, + statement_params=statement_params, + ) + + # return only new logs starting after the offset + if len(service_logs) > offset: + new_logs = service_logs[offset:] + new_offset = len(service_logs) + else: + new_logs = "" + new_offset = offset + + return new_logs, new_offset + + is_model_build_service_done = False + log_offset = 0 + model_build_service, model_inference_service = services[0], services[1] + service_name, container_name = model_build_service.service_name, model_build_service.container_name + # BuildJobName + service_logger = get_logger(service_name) + service_logger.propagate = False + while not async_job.is_done(): + if model_inference_service_exists: + time.sleep(5) + continue + + try: + block_size = 180 + service_status, message = self._service_client.get_service_status( + service_name=service_name, include_message=True, statement_params=statement_params + ) + logger.info(f"Inference service {service_name} is {service_status.value}.") + + new_logs, new_offset = fetch_logs(service_name, container_name, log_offset) + if new_logs: + service_logger.info(new_logs) + log_offset = new_offset + + # check if model build service is done + if not is_model_build_service_done: + service_status, _ = self._service_client.get_service_status( + service_name=model_build_service.service_name, + include_message=False, + statement_params=statement_params, + ) + + if service_status == service_sql.ServiceStatus.DONE: + is_model_build_service_done = True + log_offset = 0 + service_name = model_inference_service.service_name + container_name = model_inference_service.container_name + # InferenceServiceName-InstanceId + service_logger = get_logger(f"{service_name}-{model_inference_service.instance_id}") + service_logger.propagate = False + logger.info(f"Model build service {model_build_service.service_name} complete.") + logger.info("-" * block_size) + except ValueError: + logger.warning(f"Unknown service status: {service_status.value}") + except Exception as ex: + logger.warning(f"Caught an exception when logging: {repr(ex)}") + + time.sleep(5) + + if model_inference_service_exists: + logger.info(f"Inference service {model_inference_service.service_name} is already RUNNING.") + else: + self._finalize_logs(service_logger, services[-1], log_offset, statement_params) + + # catch exceptions from the deploy model execution + try: + res = cast(List[row.Row], async_job.result()) + logger.info(f"Model deployment for inference service {model_inference_service.service_name} complete.") + logger.info(res[0][0]) + except Exception as ex: + exception_queue.put(ex) + + def _finalize_logs( + self, + service_logger: logging.Logger, + service: ServiceLogInfo, + offset: int, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + """Fetch service logs after the async job is done to ensure no logs are missed.""" + try: + service_logs = self._service_client.get_service_logs( + service_name=service.service_name, + container_name=service.container_name, + statement_params=statement_params, + ) + + if len(service_logs) > offset: + service_logger.info(service_logs[offset:]) + except Exception as ex: + logger.warning(f"Caught an exception when logging: {repr(ex)}") + + @staticmethod + def _get_model_build_service_name(query_id: str) -> str: + """Get the model build service name through the server-side logic.""" + most_significant_bits = uuid.UUID(query_id).int >> 64 + md5_hash = hashlib.md5(str(most_significant_bits).encode()).hexdigest() + identifier = md5_hash[:6] + return ("model_build_" + identifier).upper() diff --git a/snowflake/ml/model/_client/ops/service_ops_test.py b/snowflake/ml/model/_client/ops/service_ops_test.py new file mode 100644 index 00000000..eee46a37 --- /dev/null +++ b/snowflake/ml/model/_client/ops/service_ops_test.py @@ -0,0 +1,124 @@ +import hashlib +import pathlib +import uuid +from typing import cast +from unittest import mock + +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.ml._internal import file_utils +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.ops import service_ops +from snowflake.ml.model._client.sql import service as service_sql +from snowflake.ml.test_utils import mock_session +from snowflake.snowpark import Session +from snowflake.snowpark._internal import utils as snowpark_utils + + +class ModelOpsTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.m_statement_params = {"test": "1"} + self.c_session = cast(Session, self.m_session) + self.m_ops = service_ops.ServiceOperator( + self.c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ) + + def test_create_service(self) -> None: + with mock.patch.object(self.m_ops._stage_client, "create_tmp_stage",) as mock_create_stage, mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_STAGE_ABCDEF0123" + ), mock.patch.object(self.m_ops._model_deployment_spec, "save",) as mock_save, mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ) as mock_upload_directory_to_stage, mock.patch.object( + self.m_ops._service_client, + "deploy_model", + return_value=(str(uuid.uuid4()), mock.MagicMock(spec=snowpark.AsyncJob)), + ) as mock_deploy_model, mock.patch.object( + self.m_ops._service_client, + "get_service_status", + return_value=(service_sql.ServiceStatus.PENDING, None), + ) as mock_get_service_status: + self.m_ops.create_service( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("SERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=sql_identifier.SqlIdentifier("IMAGE_REPO_DB"), + image_repo_schema_name=sql_identifier.SqlIdentifier("IMAGE_REPO_SCHEMA"), + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=True, + max_instances=1, + gpu_requests="1", + num_workers=1, + force_rebuild=True, + build_external_access_integration=sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION"), + statement_params=self.m_statement_params, + ) + mock_create_stage.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + stage_name=sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + statement_params=self.m_statement_params, + ) + mock_save.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("SERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=sql_identifier.SqlIdentifier("IMAGE_REPO_DB"), + image_repo_schema_name=sql_identifier.SqlIdentifier("IMAGE_REPO_SCHEMA"), + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=True, + max_instances=1, + gpu="1", + num_workers=1, + force_rebuild=True, + external_access_integration=sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION"), + ) + mock_upload_directory_to_stage.assert_called_once_with( + self.c_session, + local_path=self.m_ops._model_deployment_spec.workspace_path, + stage_path=pathlib.PurePosixPath( + self.m_ops._stage_client.fully_qualified_object_name( + sql_identifier.SqlIdentifier("DB"), + sql_identifier.SqlIdentifier("SCHEMA"), + sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + ) + ), + statement_params=self.m_statement_params, + ) + mock_deploy_model.assert_called_once_with( + stage_path="DB.SCHEMA.SNOWPARK_TEMP_STAGE_ABCDEF0123", + model_deployment_spec_file_rel_path=self.m_ops._model_deployment_spec.DEPLOY_SPEC_FILE_REL_PATH, + statement_params=self.m_statement_params, + ) + mock_get_service_status.assert_called_once_with( + service_name="SERVICE", + include_message=False, + statement_params=self.m_statement_params, + ) + + def test_get_model_build_service_name(self) -> None: + query_id = str(uuid.uuid4()) + most_significant_bits = uuid.UUID(query_id).int >> 64 + md5_hash = hashlib.md5(str(most_significant_bits).encode()).hexdigest() + identifier = md5_hash[:6] + service_name = ("model_build_" + identifier).upper() + self.assertEqual(self.m_ops._get_model_build_service_name(query_id), service_name) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_client/service/BUILD.bazel b/snowflake/ml/model/_client/service/BUILD.bazel index 953408b3..e05bb76f 100644 --- a/snowflake/ml/model/_client/service/BUILD.bazel +++ b/snowflake/ml/model/_client/service/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:py_rules.bzl", "py_library") +load("//bazel:py_rules.bzl", "py_library", "py_test") package(default_visibility = [ "//bazel:snowml_public_common", @@ -18,3 +18,12 @@ py_library( ":model_deployment_spec_schema", ], ) + +py_test( + name = "model_deployment_spec_test", + srcs = ["model_deployment_spec_test.py"], + deps = [ + ":model_deployment_spec", + "//snowflake/ml/_internal/utils:sql_identifier", + ], +) diff --git a/snowflake/ml/model/_client/service/model_deployment_spec.py b/snowflake/ml/model/_client/service/model_deployment_spec.py index b3d67b28..a7946366 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec.py @@ -34,11 +34,10 @@ def save( image_repo_database_name: Optional[sql_identifier.SqlIdentifier], image_repo_schema_name: Optional[sql_identifier.SqlIdentifier], image_repo_name: sql_identifier.SqlIdentifier, - image_name: Optional[sql_identifier.SqlIdentifier], ingress_enabled: bool, - min_instances: int, max_instances: int, gpu: Optional[str], + num_workers: Optional[int], force_rebuild: bool, external_access_integration: sql_identifier.SqlIdentifier, ) -> None: @@ -61,8 +60,6 @@ def save( force_rebuild=force_rebuild, external_access_integrations=[external_access_integration.identifier()], ) - if image_name: - image_build_dict["image_name"] = image_name.identifier() # service spec saved_service_database = service_database_name or database_name @@ -74,12 +71,14 @@ def save( name=fq_service_name, compute_pool=service_compute_pool_name.identifier(), ingress_enabled=ingress_enabled, - min_instances=min_instances, max_instances=max_instances, ) if gpu: service_dict["gpu"] = gpu + if num_workers: + service_dict["num_workers"] = num_workers + # model deployment spec model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict( models=[model_dict], diff --git a/snowflake/ml/model/_client/service/model_deployment_spec_schema.py b/snowflake/ml/model/_client/service/model_deployment_spec_schema.py index d77d9d98..2dd58ce6 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec_schema.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec_schema.py @@ -11,7 +11,6 @@ class ModelDict(TypedDict): class ImageBuildDict(TypedDict): compute_pool: Required[str] image_repo: Required[str] - image_name: NotRequired[str] force_rebuild: Required[bool] external_access_integrations: Required[List[str]] @@ -20,9 +19,9 @@ class ServiceDict(TypedDict): name: Required[str] compute_pool: Required[str] ingress_enabled: Required[bool] - min_instances: Required[int] max_instances: Required[int] gpu: NotRequired[str] + num_workers: NotRequired[int] class ModelDeploymentSpecDict(TypedDict): diff --git a/snowflake/ml/model/_client/service/model_deployment_spec_test.py b/snowflake/ml/model/_client/service/model_deployment_spec_test.py new file mode 100644 index 00000000..ae7b7aa1 --- /dev/null +++ b/snowflake/ml/model/_client/service/model_deployment_spec_test.py @@ -0,0 +1,158 @@ +import pathlib +import tempfile + +import yaml +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.service import model_deployment_spec + + +class ModelDeploymentSpecTest(absltest.TestCase): + def test_minimal(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.save( + database_name=sql_identifier.SqlIdentifier("db"), + schema_name=sql_identifier.SqlIdentifier("schema"), + model_name=sql_identifier.SqlIdentifier("model"), + version_name=sql_identifier.SqlIdentifier("version"), + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("service"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), + service_compute_pool_name=sql_identifier.SqlIdentifier("service_compute_pool"), + image_repo_database_name=None, + image_repo_schema_name=None, + image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + ingress_enabled=True, + max_instances=1, + gpu=None, + num_workers=None, + force_rebuild=False, + external_access_integration=sql_identifier.SqlIdentifier("external_access_integration"), + ) + + file_path = mds.workspace_path / mds.DEPLOY_SPEC_FILE_REL_PATH + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": "DB.SCHEMA.MODEL", "version": "VERSION"}], + "image_build": { + "compute_pool": "IMAGE_BUILD_COMPUTE_POOL", + "image_repo": "DB.SCHEMA.IMAGE_REPO", + "force_rebuild": False, + "external_access_integrations": ["EXTERNAL_ACCESS_INTEGRATION"], + }, + "service": { + "name": "DB.SCHEMA.SERVICE", + "compute_pool": "SERVICE_COMPUTE_POOL", + "ingress_enabled": True, + "max_instances": 1, + }, + }, + ) + + def test_minimal_case_sensitive(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.save( + database_name=sql_identifier.SqlIdentifier("db", case_sensitive=True), + schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("model", case_sensitive=True), + version_name=sql_identifier.SqlIdentifier("version", case_sensitive=True), + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("service", case_sensitive=True), + image_build_compute_pool_name=sql_identifier.SqlIdentifier( + "image_build_compute_pool", case_sensitive=True + ), + service_compute_pool_name=sql_identifier.SqlIdentifier("service_compute_pool", case_sensitive=True), + image_repo_database_name=None, + image_repo_schema_name=None, + image_repo_name=sql_identifier.SqlIdentifier("image_repo", case_sensitive=True), + ingress_enabled=True, + max_instances=1, + gpu=None, + num_workers=None, + force_rebuild=False, + external_access_integration=sql_identifier.SqlIdentifier( + "external_access_integration", case_sensitive=True + ), + ) + + file_path = mds.workspace_path / mds.DEPLOY_SPEC_FILE_REL_PATH + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": '"db"."schema"."model"', "version": '"version"'}], + "image_build": { + "compute_pool": '"image_build_compute_pool"', + "image_repo": '"db"."schema"."image_repo"', + "force_rebuild": False, + "external_access_integrations": ['"external_access_integration"'], + }, + "service": { + "name": '"db"."schema"."service"', + "compute_pool": '"service_compute_pool"', + "ingress_enabled": True, + "max_instances": 1, + }, + }, + ) + + def test_full(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.save( + database_name=sql_identifier.SqlIdentifier("db"), + schema_name=sql_identifier.SqlIdentifier("schema"), + model_name=sql_identifier.SqlIdentifier("model"), + version_name=sql_identifier.SqlIdentifier("version"), + service_database_name=sql_identifier.SqlIdentifier("service_db"), + service_schema_name=sql_identifier.SqlIdentifier("service_schema"), + service_name=sql_identifier.SqlIdentifier("service"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), + service_compute_pool_name=sql_identifier.SqlIdentifier("service_compute_pool"), + image_repo_database_name=sql_identifier.SqlIdentifier("image_repo_db"), + image_repo_schema_name=sql_identifier.SqlIdentifier("image_repo_schema"), + image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + ingress_enabled=True, + max_instances=10, + gpu="1", + num_workers=10, + force_rebuild=True, + external_access_integration=sql_identifier.SqlIdentifier("external_access_integration"), + ) + + file_path = mds.workspace_path / mds.DEPLOY_SPEC_FILE_REL_PATH + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": "DB.SCHEMA.MODEL", "version": "VERSION"}], + "image_build": { + "compute_pool": "IMAGE_BUILD_COMPUTE_POOL", + "image_repo": "IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", + "force_rebuild": True, + "external_access_integrations": ["EXTERNAL_ACCESS_INTEGRATION"], + }, + "service": { + "name": "SERVICE_DB.SERVICE_SCHEMA.SERVICE", + "compute_pool": "SERVICE_COMPUTE_POOL", + "ingress_enabled": True, + "max_instances": 10, + "gpu": "1", + "num_workers": 10, + }, + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_client/sql/BUILD.bazel b/snowflake/ml/model/_client/sql/BUILD.bazel index d012fa0b..9eca12d0 100644 --- a/snowflake/ml/model/_client/sql/BUILD.bazel +++ b/snowflake/ml/model/_client/sql/BUILD.bazel @@ -111,3 +111,14 @@ py_library( "//snowflake/ml/_internal/utils:sql_identifier", ], ) + +py_test( + name = "service_test", + srcs = ["service_test.py"], + deps = [ + ":service", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/model/_client/sql/service.py b/snowflake/ml/model/_client/sql/service.py index b6acaeb8..063e853b 100644 --- a/snowflake/ml/model/_client/sql/service.py +++ b/snowflake/ml/model/_client/sql/service.py @@ -1,6 +1,9 @@ +import enum +import json import textwrap from typing import Any, Dict, List, Optional, Tuple +from snowflake import snowpark from snowflake.ml._internal.utils import ( identifier, query_result_checker, @@ -11,6 +14,17 @@ from snowflake.snowpark._internal import utils as snowpark_utils +class ServiceStatus(enum.Enum): + UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet. + PENDING = "PENDING" # resource set is being created, can't be used yet + READY = "READY" # resource set has been deployed. + DELETING = "DELETING" # resource set is being deleted + FAILED = "FAILED" # resource set has failed and cannot be used anymore + DONE = "DONE" # resource set has finished running + NOT_FOUND = "NOT_FOUND" # not found or deleted + INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error. + + class ServiceSQLClient(_base._BaseSQLClient): def build_model_container( self, @@ -30,20 +44,21 @@ def build_model_container( ) -> None: actual_image_repo_database = image_repo_database_name or self._database_name actual_image_repo_schema = image_repo_schema_name or self._schema_name - fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name) - fq_image_repo_name = "/" + "/".join( - [ - actual_image_repo_database.identifier(), - actual_image_repo_schema.identifier(), - image_repo_name.identifier(), - ] + actual_model_database = database_name or self._database_name + actual_model_schema = schema_name or self._schema_name + fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name) + fq_image_repo_name = identifier.get_schema_level_object_identifier( + actual_image_repo_database.identifier(), + actual_image_repo_schema.identifier(), + image_repo_name.identifier(), ) - is_gpu = gpu is not None + is_gpu_str = "TRUE" if gpu else "FALSE" + force_rebuild_str = "TRUE" if force_rebuild else "FALSE" query_result_checker.SqlResultValidator( self._session, ( f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}'," - f" '{fq_image_repo_name}', '{is_gpu}', '{force_rebuild}', '', '{external_access_integration}')" + f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')" ), statement_params=statement_params, ).has_dimensions(expected_rows=1, expected_cols=1).validate() @@ -54,12 +69,12 @@ def deploy_model( stage_path: str, model_deployment_spec_file_rel_path: str, statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - query_result_checker.SqlResultValidator( - self._session, - f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')", - statement_params=statement_params, - ).has_dimensions(expected_rows=1, expected_cols=1).validate() + ) -> Tuple[str, snowpark.AsyncJob]: + async_job = self._session.sql( + f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')" + ).collect(block=False, statement_params=statement_params) + assert isinstance(async_job, snowpark.AsyncJob) + return async_job.query_id, async_job def invoke_function_method( self, @@ -74,12 +89,20 @@ def invoke_function_method( statement_params: Optional[Dict[str, Any]] = None, ) -> dataframe.DataFrame: with_statements = [] + actual_database_name = database_name or self._database_name + actual_schema_name = schema_name or self._schema_name + + function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()]) + fully_qualified_function_name = identifier.get_schema_level_object_identifier( + actual_database_name.identifier(), + actual_schema_name.identifier(), + function_name, + ) + if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0: INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT" with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})") else: - actual_database_name = database_name or self._database_name - actual_schema_name = schema_name or self._schema_name tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE) INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier( actual_database_name.identifier(), @@ -104,7 +127,7 @@ def invoke_function_method( sql = textwrap.dedent( f"""{with_sql} SELECT *, - {service_name.identifier()}_{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME} + {fully_qualified_function_name}({args_sql}) AS {INTERMEDIATE_OBJ_NAME} FROM {INTERMEDIATE_TABLE_NAME}""" ) @@ -127,3 +150,47 @@ def invoke_function_method( output_df._statement_params = statement_params # type: ignore[assignment] return output_df + + def get_service_logs( + self, + *, + service_name: str, + instance_id: str = "0", + container_name: str, + statement_params: Optional[Dict[str, Any]] = None, + ) -> str: + system_func = "SYSTEM$GET_SERVICE_LOGS" + rows = ( + query_result_checker.SqlResultValidator( + self._session, + f"CALL {system_func}('{service_name}', '{instance_id}', '{container_name}')", + statement_params=statement_params, + ) + .has_dimensions(expected_rows=1, expected_cols=1) + .validate() + ) + return str(rows[0][system_func]) + + def get_service_status( + self, + *, + service_name: str, + include_message: bool = False, + statement_params: Optional[Dict[str, Any]] = None, + ) -> Tuple[ServiceStatus, Optional[str]]: + system_func = "SYSTEM$GET_SERVICE_STATUS" + rows = ( + query_result_checker.SqlResultValidator( + self._session, + f"CALL {system_func}('{service_name}')", + statement_params=statement_params, + ) + .has_dimensions(expected_rows=1, expected_cols=1) + .validate() + ) + metadata = json.loads(rows[0][system_func])[0] + if metadata and metadata["status"]: + service_status = ServiceStatus(metadata["status"]) + message = metadata["message"] if include_message else None + return service_status, message + return ServiceStatus.UNKNOWN, None diff --git a/snowflake/ml/model/_client/sql/service_test.py b/snowflake/ml/model/_client/sql/service_test.py new file mode 100644 index 00000000..7827e108 --- /dev/null +++ b/snowflake/ml/model/_client/sql/service_test.py @@ -0,0 +1,327 @@ +import copy +import uuid +from typing import cast +from unittest import mock + +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.sql import service as service_sql +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import DataFrame, Row, Session, functions as F, types as spt +from snowflake.snowpark._internal import utils as snowpark_utils + + +class ServiceSQLTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + + def test_build_model_container(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Image built successfully.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql( + """ + CALL SYSTEM$BUILD_MODEL_CONTAINER('TEMP."test".MODEL', 'V1', '"my_pool"', + 'TEMP."test"."image_repo"', 'FALSE', 'FALSE', '','MY_EAI')""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).build_model_container( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + compute_pool_name=sql_identifier.SqlIdentifier("my_pool", case_sensitive=True), + image_repo_database_name=None, + image_repo_schema_name=None, + image_repo_name=sql_identifier.SqlIdentifier("image_repo", case_sensitive=True), + gpu=None, + force_rebuild=False, + external_access_integration=sql_identifier.SqlIdentifier("MY_EAI"), + statement_params=m_statement_params, + ) + + self.m_session.add_mock_sql( + """ + CALL SYSTEM$BUILD_MODEL_CONTAINER('DB_1."sch_1"."model"', '"v1"', 'MY_POOL', + '"db_2".SCH_2.IMAGE_REPO', 'TRUE', 'TRUE', '', '"my_eai"')""", + copy.deepcopy(m_df), + ) + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).build_model_container( + database_name=sql_identifier.SqlIdentifier("DB_1"), + schema_name=sql_identifier.SqlIdentifier("sch_1", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("model", case_sensitive=True), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + compute_pool_name=sql_identifier.SqlIdentifier("my_pool"), + image_repo_database_name=sql_identifier.SqlIdentifier("db_2", case_sensitive=True), + image_repo_schema_name=sql_identifier.SqlIdentifier("SCH_2"), + image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + gpu="1", + force_rebuild=True, + external_access_integration=sql_identifier.SqlIdentifier("my_eai", case_sensitive=True), + statement_params=m_statement_params, + ) + + def test_deploy_model(self) -> None: + m_statement_params = {"test": "1"} + m_async_job = mock.MagicMock(spec=snowpark.AsyncJob) + m_async_job.query_id = uuid.uuid4() + m_df = mock_data_frame.MockDataFrame( + collect_block=False, + collect_result=m_async_job, + collect_statement_params=m_statement_params, + ) + + self.m_session.add_mock_sql( + """CALL SYSTEM$DEPLOY_MODEL('@stage_path/model_deployment_spec_file_rel_path')""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).deploy_model( + stage_path="stage_path", + model_deployment_spec_file_rel_path="model_deployment_spec_file_rel_path", + statement_params=m_statement_params, + ) + + def test_invoke_function_method(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame() + self.m_session.add_mock_sql( + """SELECT *, + TEMP."test".SERVICE_PREDICT(COL1, COL2) AS TMP_RESULT + FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123""", + m_df, + ) + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + c_session = cast(Session, self.m_session) + mock_writer = mock.MagicMock() + m_df.__setattr__("write", mock_writer) + m_df.add_query("queries", "query_1") + m_df.add_query("queries", "query_2") + with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" + ) as mock_random_name_for_temp_object: + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).invoke_function_method( + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("SERVICE"), + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=cast(DataFrame, m_df), + input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], + returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], + statement_params=m_statement_params, + ) + mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) + mock_save_as_table.assert_called_once_with( + table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', + mode="errorifexists", + table_type="temporary", + statement_params=m_statement_params, + ) + + def test_invoke_function_method_1(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame() + self.m_session.add_mock_sql( + """SELECT *, + FOO."bar"."service_PREDICT"(COL1, COL2) AS TMP_RESULT + FROM FOO."bar".SNOWPARK_TEMP_TABLE_ABCDEF0123""", + m_df, + ) + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + c_session = cast(Session, self.m_session) + mock_writer = mock.MagicMock() + m_df.__setattr__("write", mock_writer) + m_df.add_query("queries", "query_1") + m_df.add_query("queries", "query_2") + with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" + ) as mock_random_name_for_temp_object: + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).invoke_function_method( + database_name=sql_identifier.SqlIdentifier("FOO"), + schema_name=sql_identifier.SqlIdentifier("bar", case_sensitive=True), + service_name=sql_identifier.SqlIdentifier("service", case_sensitive=True), + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=cast(DataFrame, m_df), + input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], + returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], + statement_params=m_statement_params, + ) + mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) + mock_save_as_table.assert_called_once_with( + table_name='FOO."bar".SNOWPARK_TEMP_TABLE_ABCDEF0123', + mode="errorifexists", + table_type="temporary", + statement_params=m_statement_params, + ) + + def test_invoke_function_method_2(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame() + self.m_session.add_mock_sql( + """WITH SNOWPARK_ML_MODEL_INFERENCE_INPUT AS (query_1) + SELECT *, + TEMP."test".SERVICE_PREDICT(COL1, COL2) AS TMP_RESULT + FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT""", + m_df, + ) + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + c_session = cast(Session, self.m_session) + m_df.add_query("queries", "query_1") + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).invoke_function_method( + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("SERVICE"), + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=cast(DataFrame, m_df), + input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], + returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], + statement_params=m_statement_params, + ) + + def test_get_service_logs(self) -> None: + m_statement_params = {"test": "1"} + row = Row("SYSTEM$GET_SERVICE_LOGS") + m_res = "INFO: Test" + m_df = mock_data_frame.MockDataFrame(collect_result=[row(m_res)], collect_statement_params=m_statement_params) + + self.m_session.add_mock_sql( + """CALL SYSTEM$GET_SERVICE_LOGS('SERVICE', '0', 'model-container')""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + + res = service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).get_service_logs( + service_name="SERVICE", + instance_id="0", + container_name="model-container", + statement_params=m_statement_params, + ) + self.assertEqual(res, m_res) + + def test_get_service_status_include_message(self) -> None: + m_statement_params = {"test": "1"} + m_service_status = service_sql.ServiceStatus("READY") + m_message = "test message" + m_res = (m_service_status, m_message) + status_res = ( + f'[{{"status":"{m_service_status.value}","message":"{m_message}",' + '"containerName":"model-inference","instanceId":"0","serviceName":"SERVICE",' + '"image":"image_url","restartCount":0,"startTime":""}]' + ) + row = Row("SYSTEM$GET_SERVICE_STATUS") + m_df = mock_data_frame.MockDataFrame( + collect_result=[row(status_res)], collect_statement_params=m_statement_params + ) + + self.m_session.add_mock_sql( + """CALL SYSTEM$GET_SERVICE_STATUS('SERVICE')""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + res = service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).get_service_status( + service_name="SERVICE", + include_message=True, + statement_params=m_statement_params, + ) + self.assertEqual(res, m_res) + + def test_get_service_status_exclude_message(self) -> None: + m_statement_params = {"test": "1"} + m_service_status = service_sql.ServiceStatus("READY") + m_message = "test message" + m_res = (m_service_status, None) + status_res = ( + f'[{{"status":"{m_service_status.value}","message":"{m_message}",' + '"containerName":"model-inference","instanceId":"0","serviceName":"SERVICE",' + '"image":"image_url","restartCount":0,"startTime":""}]' + ) + row = Row("SYSTEM$GET_SERVICE_STATUS") + m_df = mock_data_frame.MockDataFrame( + collect_result=[row(status_res)], collect_statement_params=m_statement_params + ) + + self.m_session.add_mock_sql( + """CALL SYSTEM$GET_SERVICE_STATUS('SERVICE')""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + res = service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).get_service_status( + service_name="SERVICE", + include_message=False, + statement_params=m_statement_params, + ) + self.assertEqual(res, m_res) + + def test_get_service_status_no_status(self) -> None: + m_statement_params = {"test": "1"} + m_message = "test message" + m_res = (service_sql.ServiceStatus.UNKNOWN, None) + status_res = ( + f'[{{"status":"","message":"{m_message}","containerName":"model-inference","instanceId":"0",' + '"serviceName":"SERVICE","image":"image_url","restartCount":0,"startTime":""}]' + ) + row = Row("SYSTEM$GET_SERVICE_STATUS") + m_df = mock_data_frame.MockDataFrame( + collect_result=[row(status_res)], collect_statement_params=m_statement_params + ) + + self.m_session.add_mock_sql( + """CALL SYSTEM$GET_SERVICE_STATUS('SERVICE')""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + res = service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).get_service_status( + service_name="SERVICE", + include_message=False, + statement_params=m_statement_params, + ) + self.assertEqual(res, m_res) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/inference_server/BUILD.bazel b/snowflake/ml/model/_deploy_client/image_builds/inference_server/BUILD.bazel index 788c3488..e9fff968 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/inference_server/BUILD.bazel +++ b/snowflake/ml/model/_deploy_client/image_builds/inference_server/BUILD.bazel @@ -9,7 +9,6 @@ exports_files([ py_library( name = "main", srcs = ["main.py"], - compatible_with_snowpark = False, deps = [ "//snowflake/ml/model:_api", "//snowflake/ml/model:custom_model", @@ -20,7 +19,6 @@ py_library( py_test( name = "main_test", srcs = ["main_test.py"], - compatible_with_snowpark = False, deps = [ ":main", "//snowflake/ml/_internal:file_utils", diff --git a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py index ec0fae69..3bc5828c 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +++ b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py @@ -182,7 +182,7 @@ def _construct_and_upload_job_spec(self, base_image: str, kaniko_shell_script_st with file_utils.open_file(spec_file_path, "w+") as spec_file: assert self.artifact_stage_location.startswith("@") normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@")) - (db, schema, stage, path) = identifier.parse_schema_level_object_identifier(normed_artifact_stage_path) + (db, schema, stage, path) = identifier.parse_snowflake_stage_path(normed_artifact_stage_path) content = Template(spec_template).safe_substitute( { "base_image": base_image, diff --git a/snowflake/ml/model/_deploy_client/snowservice/deploy.py b/snowflake/ml/model/_deploy_client/snowservice/deploy.py index 5c6ecd0a..e28cc38a 100644 --- a/snowflake/ml/model/_deploy_client/snowservice/deploy.py +++ b/snowflake/ml/model/_deploy_client/snowservice/deploy.py @@ -280,7 +280,7 @@ def _sanitize_dns_url(url: str) -> str: conn = session._conn._conn # We try to use the same db and schema as the service function locates, as we could retrieve those information # if that is a fully qualified one. If not we use the current session one. - (_db, _schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name) + (_db, _schema, _) = identifier.parse_schema_level_object_identifier(service_func_name) db = _db if _db is not None else conn._database schema = _schema if _schema is not None else conn._schema assert isinstance(db, str) and isinstance(schema, str) @@ -343,7 +343,7 @@ def __init__( self.model_zip_stage_path = model_zip_stage_path self.options = options self.target_method = target_method - (db, schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name) + (db, schema, _) = identifier.parse_schema_level_object_identifier(service_func_name) self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}") self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}") @@ -503,7 +503,7 @@ def _prepare_and_upload_artifacts_to_stage(self, image: str) -> None: norm_stage_path = posixpath.normpath(identifier.remove_prefix(self.model_zip_stage_path, "@")) # Ensure model stage path has root prefix as stage mount will it mount it to root. absolute_model_stage_path = os.path.join("/", norm_stage_path) - (db, schema, stage, path) = identifier.parse_schema_level_object_identifier(norm_stage_path) + (db, schema, stage, path) = identifier.parse_snowflake_stage_path(norm_stage_path) substitutes = { "image": image, "predict_endpoint_name": constants.PREDICT, diff --git a/snowflake/ml/model/_model_composer/model_composer.py b/snowflake/ml/model/_model_composer/model_composer.py index f80c5760..2946deea 100644 --- a/snowflake/ml/model/_model_composer/model_composer.py +++ b/snowflake/ml/model/_model_composer/model_composer.py @@ -92,6 +92,7 @@ def save( python_version: Optional[str] = None, ext_modules: Optional[List[ModuleType]] = None, code_paths: Optional[List[str]] = None, + model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, options: Optional[model_types.ModelSaveOption] = None, ) -> model_meta.ModelMetadata: if not options: @@ -120,6 +121,7 @@ def save( python_version=python_version, ext_modules=ext_modules, code_paths=code_paths, + model_objective=model_objective, options=options, ) assert self.packager.meta is not None diff --git a/snowflake/ml/model/_model_composer/model_composer_test.py b/snowflake/ml/model/_model_composer/model_composer_test.py index efb9c4cb..6359160a 100644 --- a/snowflake/ml/model/_model_composer/model_composer_test.py +++ b/snowflake/ml/model/_model_composer/model_composer_test.py @@ -77,15 +77,17 @@ def test_save_interface(self) -> None: name="model1", model=linear_model.LinearRegression(), sample_input_data=d, + model_objective=model_types.ModelObjective.REGRESSION, ) + + mock_upload_directory_to_stage.assert_called_once_with( + c_session, + local_path=mock.ANY, + stage_path=pathlib.PurePosixPath(stage_path), + statement_params=None, + ) mock_save.assert_called_once() mock_manifest_save.assert_called_once() - mock_upload_directory_to_stage.assert_called_once_with( - c_session, - local_path=mock.ANY, - stage_path=pathlib.PurePosixPath(stage_path), - statement_params=None, - ) def test_load(self) -> None: m_options = model_types.PyTorchLoadOptions(use_gpu=False) diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py index 6fcc89bb..7074818c 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py @@ -1,7 +1,6 @@ import collections import copy import pathlib -import warnings from typing import List, Optional, cast import yaml @@ -78,13 +77,9 @@ def save( ) dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"]) - if options.get("include_pip_dependencies"): - warnings.warn( - "`include_pip_dependencies` specified as True: pip dependencies will be included and may not" - "be warehouse-compabible. The model may need to be run in SPCS.", - category=UserWarning, - stacklevel=1, - ) + + # We only want to include pip dependencies file if there are any pip requirements. + if len(model_meta.env.pip_requirements) > 0: dependencies["pip"] = runtime_dict["dependencies"]["pip"] manifest_dict = model_manifest_schema.ModelManifestDict( diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py index 7c19ca5f..2ee07e5c 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py @@ -282,12 +282,13 @@ def test_model_manifest_pip(self) -> None: name="model1", model_type="custom", signatures={"predict": _DUMMY_SIG["predict"]}, + pip_requirements=["xgboost"], python_version="3.8", embed_local_ml_library=True, ) as meta: meta.models["model1"] = _DUMMY_BLOB - mm.save(meta, pathlib.PurePosixPath("model"), options={"include_pip_dependencies": True}) + mm.save(meta, pathlib.PurePosixPath("model")) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: self.assertEqual( ( diff --git a/snowflake/ml/model/_packager/model_handlers/BUILD.bazel b/snowflake/ml/model/_packager/model_handlers/BUILD.bazel index 6d9bf8ae..091a0269 100644 --- a/snowflake/ml/model/_packager/model_handlers/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_handlers/BUILD.bazel @@ -24,6 +24,14 @@ py_library( ], ) +py_library( + name = "model_objective_utils", + srcs = ["model_objective_utils.py"], + deps = [ + ":_utils", + ], +) + py_library( name = "catboost", srcs = ["catboost.py"], @@ -89,6 +97,7 @@ py_library( "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_packager/model_env", + "//snowflake/ml/model/_packager/model_handlers:model_objective_utils", "//snowflake/ml/model/_packager/model_handlers_migrator:base_migrator", "//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_packager/model_meta:model_blob_meta", @@ -108,6 +117,7 @@ py_library( "//snowflake/ml/model:custom_model", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_packager/model_env", + "//snowflake/ml/model/_packager/model_handlers:model_objective_utils", "//snowflake/ml/model/_packager/model_handlers_migrator:base_migrator", "//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_packager/model_meta:model_blob_meta", @@ -127,6 +137,7 @@ py_library( "//snowflake/ml/model:custom_model", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_packager/model_env", + "//snowflake/ml/model/_packager/model_handlers:model_objective_utils", "//snowflake/ml/model/_packager/model_handlers_migrator:base_migrator", "//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_packager/model_meta:model_blob_meta", diff --git a/snowflake/ml/model/_packager/model_handlers/_utils.py b/snowflake/ml/model/_packager/model_handlers/_utils.py index fbd811bf..445e3e78 100644 --- a/snowflake/ml/model/_packager/model_handlers/_utils.py +++ b/snowflake/ml/model/_packager/model_handlers/_utils.py @@ -1,9 +1,11 @@ import json +import warnings from typing import Any, Callable, Iterable, Optional, Sequence, cast import numpy as np import numpy.typing as npt import pandas as pd +from absl import logging from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager.model_meta import model_meta @@ -11,6 +13,17 @@ from snowflake.snowpark import DataFrame as SnowparkDataFrame +class NumpyEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool: return callable(getattr(model, method_name, None)) @@ -93,23 +106,42 @@ def convert_explanations_to_2D_df( return pd.DataFrame(explanations) if hasattr(model, "classes_"): - classes_list = [cl for cl in model.classes_] # type:ignore[union-attr] + classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr] len_classes = len(classes_list) if explanations.shape[2] != len_classes: raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}") else: - classes_list = [i for i in range(explanations.shape[2])] - exp_2d = [] - # TODO (SNOW-1549044): Optimize this - for row in explanations: - col_list = [] - for column in row: - class_explanations = {} - for cl, cl_exp in zip(classes_list, column): - if isinstance(cl, (int, np.integer)): - cl = int(cl) - class_explanations[cl] = cl_exp - col_list.append(json.dumps(class_explanations)) - exp_2d.append(col_list) + classes_list = [str(i) for i in range(explanations.shape[2])] + + def row_to_dict(row: npt.NDArray[Any]) -> npt.NDArray[Any]: + """Converts a single row to a dictionary.""" + # convert to object or numpy creates strings of fixed length + return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object) + + exp_2d = np.apply_along_axis(row_to_dict, -1, explanations) return pd.DataFrame(exp_2d) + + +def validate_model_objective( + passed_model_objective: model_types.ModelObjective, inferred_model_objective: model_types.ModelObjective +) -> model_types.ModelObjective: + if ( + passed_model_objective != model_types.ModelObjective.UNKNOWN + and inferred_model_objective != model_types.ModelObjective.UNKNOWN + ): + if passed_model_objective != inferred_model_objective: + warnings.warn( + f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model " + f"version and passed argument ModelObjective: {passed_model_objective.name} is ignored", + category=UserWarning, + stacklevel=1, + ) + return inferred_model_objective + elif inferred_model_objective != model_types.ModelObjective.UNKNOWN: + logging.info( + f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model " + f"version" + ) + return inferred_model_objective + return passed_model_objective diff --git a/snowflake/ml/model/_packager/model_handlers/catboost.py b/snowflake/ml/model/_packager/model_handlers/catboost.py index 6177c843..badf1df6 100644 --- a/snowflake/ml/model/_packager/model_handlers/catboost.py +++ b/snowflake/ml/model/_packager/model_handlers/catboost.py @@ -34,20 +34,20 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]): DEFAULT_TARGET_METHODS = ["predict", "predict_proba"] @classmethod - def get_model_objective(cls, model: "catboost.CatBoost") -> model_meta_schema.ModelObjective: + def get_model_objective_and_output_type(cls, model: "catboost.CatBoost") -> model_types.ModelObjective: import catboost if isinstance(model, catboost.CatBoostClassifier): num_classes = handlers_utils.get_num_classes_if_exists(model) if num_classes == 2: - return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION - return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION + return model_types.ModelObjective.BINARY_CLASSIFICATION + return model_types.ModelObjective.MULTI_CLASSIFICATION if isinstance(model, catboost.CatBoostRanker): - return model_meta_schema.ModelObjective.RANKING + return model_types.ModelObjective.RANKING if isinstance(model, catboost.CatBoostRegressor): - return model_meta_schema.ModelObjective.REGRESSION + return model_types.ModelObjective.REGRESSION # TODO: Find out model type from the generic Catboost Model - return model_meta_schema.ModelObjective.UNKNOWN + return model_types.ModelObjective.UNKNOWN @classmethod def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]: @@ -77,6 +77,8 @@ def save_model( is_sub_model: Optional[bool] = False, **kwargs: Unpack[model_types.CatBoostModelSaveOptions], ) -> None: + enable_explainability = kwargs.get("enable_explainability", True) + import catboost assert isinstance(model, catboost.CatBoost) @@ -105,11 +107,14 @@ def get_prediction( sample_input_data=sample_input_data, get_prediction_fn=get_prediction, ) - model_objective = cls.get_model_objective(model) - model_meta.model_objective = model_objective - if kwargs.get("enable_explainability", True): + inferred_model_objective = cls.get_model_objective_and_output_type(model) + model_meta.model_objective = handlers_utils.validate_model_objective( + model_meta.model_objective, inferred_model_objective + ) + model_objective = model_meta.model_objective + if enable_explainability: output_type = model_signature.DataType.DOUBLE - if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION: + if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION: output_type = model_signature.DataType.STRING model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, @@ -143,11 +148,8 @@ def get_prediction( ], check_local_version=True, ) - if kwargs.get("enable_explainability", True): - model_meta.env.include_if_absent( - [model_env.ModelDependency(requirement="shap", pip_name="shap")], - check_local_version=True, - ) + if enable_explainability: + model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")]) model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION) diff --git a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py index e7f973cc..f96fc28e 100644 --- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py @@ -369,7 +369,9 @@ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: else: # For others, we could offer the whole dataframe as a list. # Some of them may need some conversion - if isinstance(raw_model, transformers.ConversationalPipeline): + if hasattr(transformers, "ConversationalPipeline") and isinstance( + raw_model, transformers.ConversationalPipeline + ): input_data = [ transformers.Conversation( text=conv_data["user_inputs"][0], @@ -391,27 +393,33 @@ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: # Making it not aligned with the auto-inferred signature. # If the output is a dict, we could blindly create a list containing that. # Otherwise, creating pandas DataFrame won't succeed. - if isinstance(temp_res, (dict, transformers.Conversation)) or ( - # For some pipeline that is expected to generate a list of dict per input - # When it omit outer list, it becomes list of dict instead of list of list of dict. - # We need to distinguish them from those pipelines that designed to output a dict per input - # So we need to check the pipeline type. - isinstance( - raw_model, - ( - transformers.FillMaskPipeline, - transformers.QuestionAnsweringPipeline, - ), + if ( + (hasattr(transformers, "Conversation") and isinstance(temp_res, transformers.Conversation)) + or isinstance(temp_res, dict) + or ( + # For some pipeline that is expected to generate a list of dict per input + # When it omit outer list, it becomes list of dict instead of list of list of dict. + # We need to distinguish them from those pipelines that designed to output a dict per input + # So we need to check the pipeline type. + isinstance( + raw_model, + ( + transformers.FillMaskPipeline, + transformers.QuestionAnsweringPipeline, + ), + ) + and X.shape[0] == 1 + and isinstance(temp_res[0], dict) ) - and X.shape[0] == 1 - and isinstance(temp_res[0], dict) ): temp_res = [temp_res] if len(temp_res) == 0: return pd.DataFrame() - if isinstance(raw_model, transformers.ConversationalPipeline): + if hasattr(transformers, "ConversationalPipeline") and isinstance( + raw_model, transformers.ConversationalPipeline + ): temp_res = [[conv.generated_responses] for conv in temp_res] # To concat those who outputs a list with one input. diff --git a/snowflake/ml/model/_packager/model_handlers/lightgbm.py b/snowflake/ml/model/_packager/model_handlers/lightgbm.py index 83461abf..779944f0 100644 --- a/snowflake/ml/model/_packager/model_handlers/lightgbm.py +++ b/snowflake/ml/model/_packager/model_handlers/lightgbm.py @@ -19,7 +19,11 @@ from snowflake.ml._internal import type_utils from snowflake.ml.model import custom_model, model_signature, type_hints as model_types from snowflake.ml.model._packager.model_env import model_env -from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils +from snowflake.ml.model._packager.model_handlers import ( + _base, + _utils as handlers_utils, + model_objective_utils, +) from snowflake.ml.model._packager.model_handlers_migrator import base_migrator from snowflake.ml.model._packager.model_meta import ( model_blob_meta, @@ -43,47 +47,6 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb MODEL_BLOB_FILE_OR_DIR = "model.pkl" DEFAULT_TARGET_METHODS = ["predict", "predict_proba"] - _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"] - _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"] - _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"] - _REGRESSION_OBJECTIVES = [ - "regression", - "regression_l1", - "huber", - "fair", - "poisson", - "quantile", - "tweedie", - "mape", - "gamma", - ] - - @classmethod - def get_model_objective( - cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"] - ) -> model_meta_schema.ModelObjective: - import lightgbm - - # does not account for cross-entropy and custom - if isinstance(model, lightgbm.LGBMClassifier): - num_classes = handlers_utils.get_num_classes_if_exists(model) - if num_classes == 2: - return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION - return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION - if isinstance(model, lightgbm.LGBMRanker): - return model_meta_schema.ModelObjective.RANKING - if isinstance(model, lightgbm.LGBMRegressor): - return model_meta_schema.ModelObjective.REGRESSION - model_objective = model.params["objective"] - if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES: - return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION - if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES: - return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION - if model_objective in cls._RANKING_OBJECTIVES: - return model_meta_schema.ModelObjective.RANKING - if model_objective in cls._REGRESSION_OBJECTIVES: - return model_meta_schema.ModelObjective.REGRESSION - return model_meta_schema.ModelObjective.UNKNOWN @classmethod def can_handle( @@ -118,6 +81,8 @@ def save_model( is_sub_model: Optional[bool] = False, **kwargs: Unpack[model_types.LGBMModelSaveOptions], ) -> None: + enable_explainability = kwargs.get("enable_explainability", True) + import lightgbm assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel) @@ -146,20 +111,16 @@ def get_prediction( sample_input_data=sample_input_data, get_prediction_fn=get_prediction, ) - model_objective = cls.get_model_objective(model) - model_meta.model_objective = model_objective - if kwargs.get("enable_explainability", True): - output_type = model_signature.DataType.DOUBLE - if model_objective in [ - model_meta_schema.ModelObjective.BINARY_CLASSIFICATION, - model_meta_schema.ModelObjective.MULTI_CLASSIFICATION, - ]: - output_type = model_signature.DataType.STRING + model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model) + model_meta.model_objective = handlers_utils.validate_model_objective( + model_meta.model_objective, model_objective_and_output.objective + ) + if enable_explainability: model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, explain_method="explain", target_method="predict", - output_return_type=output_type, + output_return_type=model_objective_and_output.output_type, ) model_meta.function_properties = { "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False} @@ -189,11 +150,8 @@ def get_prediction( ], check_local_version=True, ) - if kwargs.get("enable_explainability", True): - model_meta.env.include_if_absent( - [model_env.ModelDependency(requirement="shap", pip_name="shap")], - check_local_version=True, - ) + if enable_explainability: + model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")]) model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP return None diff --git a/snowflake/ml/model/_packager/model_handlers/llm.py b/snowflake/ml/model/_packager/model_handlers/llm.py index 591bb048..0082c760 100644 --- a/snowflake/ml/model/_packager/model_handlers/llm.py +++ b/snowflake/ml/model/_packager/model_handlers/llm.py @@ -205,7 +205,9 @@ def _prepare_for_lora(self) -> None: "token": raw_model.token, } model_dir_path = raw_model.model_id_or_path - peft_config = peft.PeftConfig.from_pretrained(model_dir_path) # type: ignore[attr-defined] + peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined] + model_dir_path + ) base_model_path = peft_config.base_model_name_or_path tokenizer = transformers.AutoTokenizer.from_pretrained( base_model_path, @@ -221,7 +223,7 @@ def _prepare_for_lora(self) -> None: model_dir_path, device_map="auto", torch_dtype="auto", - **hub_kwargs, + **hub_kwargs, # type: ignore[arg-type] ) hf_model.eval() hf_model = hf_model.merge_and_unload() diff --git a/snowflake/ml/model/_packager/model_handlers/model_objective_utils.py b/snowflake/ml/model/_packager/model_handlers/model_objective_utils.py new file mode 100644 index 00000000..ad20c45c --- /dev/null +++ b/snowflake/ml/model/_packager/model_handlers/model_objective_utils.py @@ -0,0 +1,116 @@ +import json +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Union + +from snowflake.ml.model import model_signature, type_hints +from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils + +if TYPE_CHECKING: + import lightgbm + import xgboost + + +@dataclass +class ModelObjectiveAndOutputType: + objective: type_hints.ModelObjective + output_type: model_signature.DataType + + +def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.ModelObjective: + + import lightgbm + + _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"] + _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"] + _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"] + _REGRESSION_OBJECTIVES = [ + "regression", + "regression_l1", + "huber", + "fair", + "poisson", + "quantile", + "tweedie", + "mape", + "gamma", + ] + + # does not account for cross-entropy and custom + if isinstance(model, lightgbm.LGBMClassifier): + num_classes = handlers_utils.get_num_classes_if_exists(model) + if num_classes == 2: + return type_hints.ModelObjective.BINARY_CLASSIFICATION + return type_hints.ModelObjective.MULTI_CLASSIFICATION + if isinstance(model, lightgbm.LGBMRanker): + return type_hints.ModelObjective.RANKING + if isinstance(model, lightgbm.LGBMRegressor): + return type_hints.ModelObjective.REGRESSION + model_objective = model.params["objective"] + if model_objective in _BINARY_CLASSIFICATION_OBJECTIVES: + return type_hints.ModelObjective.BINARY_CLASSIFICATION + if model_objective in _MULTI_CLASSIFICATION_OBJECTIVES: + return type_hints.ModelObjective.MULTI_CLASSIFICATION + if model_objective in _RANKING_OBJECTIVES: + return type_hints.ModelObjective.RANKING + if model_objective in _REGRESSION_OBJECTIVES: + return type_hints.ModelObjective.REGRESSION + return type_hints.ModelObjective.UNKNOWN + + +def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective: + + import xgboost + + _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"] + _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"] + _RANKING_OBJECTIVE_PREFIX = ["rank:"] + _REGRESSION_OBJECTIVE_PREFIX = ["reg:"] + + model_objective = "" + if isinstance(model, xgboost.Booster): + model_params = json.loads(model.save_config()) + model_objective = model_params.get("learner", {}).get("objective", "") + else: + if hasattr(model, "get_params"): + model_objective = model.get_params().get("objective", "") + + if isinstance(model_objective, dict): + model_objective = model_objective.get("name", "") + for classification_objective in _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX: + if classification_objective in model_objective: + return type_hints.ModelObjective.BINARY_CLASSIFICATION + for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX: + if classification_objective in model_objective: + return type_hints.ModelObjective.MULTI_CLASSIFICATION + for ranking_objective in _RANKING_OBJECTIVE_PREFIX: + if ranking_objective in model_objective: + return type_hints.ModelObjective.RANKING + for regression_objective in _REGRESSION_OBJECTIVE_PREFIX: + if regression_objective in model_objective: + return type_hints.ModelObjective.REGRESSION + return type_hints.ModelObjective.UNKNOWN + + +def get_model_objective_and_output_type(model: Any) -> ModelObjectiveAndOutputType: + import xgboost + + if isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel): + model_objective = get_model_objective_xgb(model) + output_type = model_signature.DataType.DOUBLE + if model_objective == type_hints.ModelObjective.MULTI_CLASSIFICATION: + output_type = model_signature.DataType.STRING + return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type) + + import lightgbm + + if isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel): + model_objective = get_model_objective_lightgbm(model) + output_type = model_signature.DataType.DOUBLE + if model_objective in [ + type_hints.ModelObjective.BINARY_CLASSIFICATION, + type_hints.ModelObjective.MULTI_CLASSIFICATION, + ]: + output_type = model_signature.DataType.STRING + return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type) + + raise ValueError(f"Model type {type(model)} is not supported") diff --git a/snowflake/ml/model/_packager/model_handlers/sklearn.py b/snowflake/ml/model/_packager/model_handlers/sklearn.py index d9ab8d5d..1c9e2f11 100644 --- a/snowflake/ml/model/_packager/model_handlers/sklearn.py +++ b/snowflake/ml/model/_packager/model_handlers/sklearn.py @@ -45,23 +45,23 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", @classmethod def get_model_objective( cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"] - ) -> model_meta_schema.ModelObjective: + ) -> model_types.ModelObjective: import sklearn.pipeline from sklearn.base import is_classifier, is_regressor if isinstance(model, sklearn.pipeline.Pipeline): - return model_meta_schema.ModelObjective.UNKNOWN + return model_types.ModelObjective.UNKNOWN if is_regressor(model): - return model_meta_schema.ModelObjective.REGRESSION + return model_types.ModelObjective.REGRESSION if is_classifier(model): classes_list = getattr(model, "classes_", []) num_classes = getattr(model, "n_classes_", None) or len(classes_list) if isinstance(num_classes, int): if num_classes > 2: - return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION - return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION - return model_meta_schema.ModelObjective.UNKNOWN - return model_meta_schema.ModelObjective.UNKNOWN + return model_types.ModelObjective.MULTI_CLASSIFICATION + return model_types.ModelObjective.BINARY_CLASSIFICATION + return model_types.ModelObjective.UNKNOWN + return model_types.ModelObjective.UNKNOWN @classmethod def can_handle( @@ -95,6 +95,18 @@ def cast_model( return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model) + @staticmethod + def get_explainability_supported_background( + sample_input_data: Optional[model_types.SupportedDataType] = None, + ) -> Optional[pd.DataFrame]: + if isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame): + return ( + sample_input_data + if isinstance(sample_input_data, pd.DataFrame) + else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data) + ) + return None + @classmethod def save_model( cls, @@ -106,32 +118,30 @@ def save_model( is_sub_model: Optional[bool] = False, **kwargs: Unpack[model_types.SKLModelSaveOptions], ) -> None: - enable_explainability = kwargs.get("enable_explainability", False) + # setting None by default to distinguish if users did not set it + enable_explainability = kwargs.get("enable_explainability", None) import sklearn.base import sklearn.pipeline assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline) - enable_explainability = kwargs.get("enable_explainability", False) + background_data = cls.get_explainability_supported_background(sample_input_data) + + # if users did not ask then we enable if we have background data + if enable_explainability is None and background_data is not None: + enable_explainability = True if enable_explainability: - # TODO: Currently limited to pandas df, need to extend to other types. - if sample_input_data is None or not ( - isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame) - ): + # if users set it explicitly but no background data then error out + if background_data is None: raise ValueError( "Sample input data is required to enable explainability. Currently we only support this for " + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`." ) - sample_input_data_pandas = ( - sample_input_data - if isinstance(sample_input_data, pd.DataFrame) - else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data) - ) data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR) os.makedirs(data_blob_path, exist_ok=True) with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f: - sample_input_data_pandas.to_parquet(f) + background_data.to_parquet(f) if not is_sub_model: target_methods = handlers_utils.get_target_methods( @@ -159,9 +169,13 @@ def get_prediction( get_prediction_fn=get_prediction, ) + model_objective = cls.get_model_objective(model) + model_meta.model_objective = model_objective + if enable_explainability: output_type = model_signature.DataType.DOUBLE - if cls.get_model_objective(model) == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION: + + if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION: output_type = model_signature.DataType.STRING model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, @@ -184,10 +198,8 @@ def get_prediction( model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION if enable_explainability: - model_meta.env.include_if_absent( - [model_env.ModelDependency(requirement="shap", pip_name="shap")], - check_local_version=True, - ) + model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")]) + model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP model_meta.env.include_if_absent( [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True diff --git a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py index 0152298b..d6d51fb7 100644 --- a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +++ b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py @@ -1,20 +1,27 @@ import os import warnings -from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final import cloudpickle import numpy as np import pandas as pd +from packaging import version from typing_extensions import TypeGuard, Unpack from snowflake.ml._internal import type_utils +from snowflake.ml._internal.exceptions import exceptions from snowflake.ml.model import custom_model, model_signature, type_hints as model_types from snowflake.ml.model._packager.model_env import model_env -from snowflake.ml.model._packager.model_handlers import _base +from snowflake.ml.model._packager.model_handlers import ( + _base, + _utils as handlers_utils, + model_objective_utils, +) from snowflake.ml.model._packager.model_handlers_migrator import base_migrator from snowflake.ml.model._packager.model_meta import ( model_blob_meta, model_meta as model_meta_api, + model_meta_schema, ) from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils @@ -62,6 +69,53 @@ def cast_model( return cast("BaseEstimator", model) + @classmethod + def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]: + import importlib_metadata + from packaging import version + + local_version = None + + try: + local_dist = importlib_metadata.distribution(pkg_name) # type: ignore[no-untyped-call] + local_version = version.parse(local_dist.version) + except importlib_metadata.PackageNotFoundError: + pass + + return local_version + + @classmethod + def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool: + + local_xgb_version = cls._get_local_version_package("xgboost") + + if local_xgb_version and local_xgb_version >= version.parse("2.1.0"): + if enable_explainability: + warnings.warn( + f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1." + + "If you want model explanations, lower the xgboost version to <2.1.0.", + category=UserWarning, + stacklevel=1, + ) + return False + return True + + @classmethod + def _get_supported_object_for_explainability( + cls, estimator: "BaseEstimator", enable_explainability: Optional[bool] + ) -> Any: + methods = ["to_xgboost", "to_lightgbm"] + for method_name in methods: + if hasattr(estimator, method_name): + try: + result = getattr(estimator, method_name)() + if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability): + return None + return result + except exceptions.SnowflakeMLException: + pass # Do nothing and continue to the next method + return None + @classmethod def save_model( cls, @@ -73,9 +127,8 @@ def save_model( is_sub_model: Optional[bool] = False, **kwargs: Unpack[model_types.SNOWModelSaveOptions], ) -> None: - enable_explainability = kwargs.get("enable_explainability", False) - if enable_explainability: - raise NotImplementedError("Explainability is not supported for Snowpark ML model.") + + enable_explainability = kwargs.get("enable_explainability", None) from snowflake.ml.modeling.framework.base import BaseEstimator @@ -105,6 +158,26 @@ def save_model( raise ValueError(f"Target method {method_name} does not exist in the model.") model_meta.signatures = temp_model_signature_dict + if enable_explainability or enable_explainability is None: + python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability) + if python_base_obj is None: + if enable_explainability: # if user set enable_explainability to True, throw error else silently skip + raise ValueError("Explain only support for xgboost or lightgbm Snowpark ML models.") + # set None to False so we don't include shap in the environment + enable_explainability = False + else: + model_objective_and_output_type = model_objective_utils.get_model_objective_and_output_type( + python_base_obj + ) + model_meta.model_objective = model_objective_and_output_type.objective + model_meta = handlers_utils.add_explain_method_signature( + model_meta=model_meta, + explain_method="explain", + target_method="predict", + output_return_type=model_objective_and_output_type.output_type, + ) + enable_explainability = True + model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f: @@ -122,7 +195,29 @@ def save_model( model_dependencies = model._get_dependencies() for dep in model_dependencies: pkg_name = dep.split("==")[0] - _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name)) + if pkg_name != "xgboost": + _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name)) + continue + + local_xgb_version = cls._get_local_version_package("xgboost") + if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability: + model_meta.env.include_if_absent( + [ + model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"), + ], + check_local_version=False, + ) + else: + model_meta.env.include_if_absent( + [ + model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"), + ], + check_local_version=True, + ) + + if enable_explainability: + model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")]) + model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True) @classmethod @@ -177,6 +272,24 @@ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: return model_signature_utils.rename_pandas_df(df, signature.outputs) + @custom_model.inference_api + def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: + import shap + + methods = ["to_xgboost", "to_lightgbm"] + for method_name in methods: + try: + base_model = getattr(raw_model, method_name)() + explainer = shap.TreeExplainer(base_model) + df = pd.DataFrame(explainer(X).values) + return model_signature_utils.rename_pandas_df(df, signature.outputs) + except exceptions.SnowflakeMLException: + pass # Do nothing and continue to the next method + raise ValueError("The model must be an xgboost or lightgbm estimator.") + + if target_method == "explain": + return explain_fn + return fn type_method_dict = {} diff --git a/snowflake/ml/model/_packager/model_handlers/torchscript.py b/snowflake/ml/model/_packager/model_handlers/torchscript.py index 9dc6bb43..318abe80 100644 --- a/snowflake/ml/model/_packager/model_handlers/torchscript.py +++ b/snowflake/ml/model/_packager/model_handlers/torchscript.py @@ -111,7 +111,7 @@ def get_prediction( model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f: - torch.jit.save(model, f) # type:ignore[attr-defined] + torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined] base_meta = model_blob_meta.ModelBlobMeta( name=name, model_type=cls.HANDLER_TYPE, @@ -141,7 +141,7 @@ def load_model( model_blob_metadata = model_blobs_metadata[name] model_blob_filename = model_blob_metadata.path with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f: - m = torch.jit.load( # type:ignore[attr-defined] + m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined] f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu" ) assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined] diff --git a/snowflake/ml/model/_packager/model_handlers/xgboost.py b/snowflake/ml/model/_packager/model_handlers/xgboost.py index f1b5e009..145ddc9f 100644 --- a/snowflake/ml/model/_packager/model_handlers/xgboost.py +++ b/snowflake/ml/model/_packager/model_handlers/xgboost.py @@ -1,6 +1,6 @@ # mypy: disable-error-code="import" -import json import os +import warnings from typing import ( TYPE_CHECKING, Any, @@ -13,14 +13,20 @@ final, ) +import importlib_metadata import numpy as np import pandas as pd +from packaging import version from typing_extensions import TypeGuard, Unpack from snowflake.ml._internal import type_utils from snowflake.ml.model import custom_model, model_signature, type_hints as model_types from snowflake.ml.model._packager.model_env import model_env -from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils +from snowflake.ml.model._packager.model_handlers import ( + _base, + _utils as handlers_utils, + model_objective_utils, +) from snowflake.ml.model._packager.model_handlers_migrator import base_migrator from snowflake.ml.model._packager.model_meta import ( model_blob_meta, @@ -47,41 +53,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X MODEL_BLOB_FILE_OR_DIR = "model.ubj" DEFAULT_TARGET_METHODS = ["predict", "predict_proba"] - _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"] - _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"] - _RANKING_OBJECTIVE_PREFIX = ["rank:"] - _REGRESSION_OBJECTIVE_PREFIX = ["reg:"] - - @classmethod - def get_model_objective( - cls, model: Union["xgboost.Booster", "xgboost.XGBModel"] - ) -> model_meta_schema.ModelObjective: - import xgboost - - if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier): - num_classes = handlers_utils.get_num_classes_if_exists(model) - if num_classes == 2: - return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION - return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION - if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor): - return model_meta_schema.ModelObjective.REGRESSION - if isinstance(model, xgboost.XGBRanker): - return model_meta_schema.ModelObjective.RANKING - model_params = json.loads(model.save_config()) - model_objective = model_params["learner"]["objective"] - for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX: - if classification_objective in model_objective: - return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION - for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX: - if classification_objective in model_objective: - return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION - for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX: - if ranking_objective in model_objective: - return model_meta_schema.ModelObjective.RANKING - for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX: - if regression_objective in model_objective: - return model_meta_schema.ModelObjective.REGRESSION - return model_meta_schema.ModelObjective.UNKNOWN @classmethod def can_handle( @@ -116,10 +87,29 @@ def save_model( is_sub_model: Optional[bool] = False, **kwargs: Unpack[model_types.XGBModelSaveOptions], ) -> None: + enable_explainability = kwargs.get("enable_explainability", True) + import xgboost assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel) + local_xgb_version = None + + try: + local_dist = importlib_metadata.distribution("xgboost") # type: ignore[no-untyped-call] + local_xgb_version = version.parse(local_dist.version) + except importlib_metadata.PackageNotFoundError: + pass + + if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability: + warnings.warn( + f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1." + + "If you want model explanations, lower the xgboost version to <2.1.0.", + category=UserWarning, + stacklevel=1, + ) + enable_explainability = False + if not is_sub_model: target_methods = handlers_utils.get_target_methods( model=model, @@ -148,17 +138,16 @@ def get_prediction( sample_input_data=sample_input_data, get_prediction_fn=get_prediction, ) - model_objective = cls.get_model_objective(model) - model_meta.model_objective = model_objective - if kwargs.get("enable_explainability", True): - output_type = model_signature.DataType.DOUBLE - if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION: - output_type = model_signature.DataType.STRING + model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model) + model_meta.model_objective = handlers_utils.validate_model_objective( + model_meta.model_objective, model_objective_and_output.objective + ) + if enable_explainability: model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, explain_method="explain", target_method="predict", - output_return_type=output_type, + output_return_type=model_objective_and_output.output_type, ) model_meta.function_properties = { "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False} @@ -180,15 +169,26 @@ def get_prediction( model_meta.env.include_if_absent( [ model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"), - model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"), ], check_local_version=True, ) - if kwargs.get("enable_explainability", True): + if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability: model_meta.env.include_if_absent( - [model_env.ModelDependency(requirement="shap", pip_name="shap")], + [ + model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"), + ], + check_local_version=False, + ) + else: + model_meta.env.include_if_absent( + [ + model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"), + ], check_local_version=True, ) + + if enable_explainability: + model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")]) model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION) @@ -269,7 +269,7 @@ def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: import shap explainer = shap.TreeExplainer(raw_model) - df = pd.DataFrame(explainer(X).values) + df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values) return model_signature_utils.rename_pandas_df(df, signature.outputs) if target_method == "explain": diff --git a/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel b/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel index cc8d104f..1e8de101 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel @@ -21,6 +21,16 @@ py_test( ], ) +py_test( + name = "model_objective_utils_test", + srcs = ["model_objective_utils_test.py"], + deps = [ + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model/_packager/model_handlers:_utils", + "//snowflake/ml/model/_packager/model_handlers:model_objective_utils", + ], +) + py_test( name = "catboost_test", srcs = ["catboost_test.py"], @@ -79,6 +89,7 @@ py_test( "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_packager:model_packager", "//snowflake/ml/modeling/linear_model:linear_regression", + "//snowflake/ml/modeling/xgboost:xgb_regressor", ], ) @@ -110,6 +121,7 @@ py_test( deps = [ "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_packager:model_packager", + "//snowflake/ml/model/_packager/model_handlers_test:test_utils", ], ) diff --git a/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py b/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py index 1e02e8f0..df627ab6 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/_utils_test.py @@ -5,7 +5,7 @@ import pandas as pd from absl.testing import absltest -from snowflake.ml.model import model_signature +from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._packager.model_env import model_env from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils from snowflake.ml.model._packager.model_meta import model_meta @@ -106,6 +106,39 @@ def test_convert_explanations_to_2D_df_multi_value_no_class_attr(self) -> None: ) pd.testing.assert_frame_equal(explanations_df, expected_df) + def test_validate_model_objective(self) -> None: + + model_objective_list = list(type_hints.ModelObjective) + for model_objective in model_objective_list: + for inferred_model_objective in model_objective_list: + expected_model_objective = ( + inferred_model_objective + if inferred_model_objective != type_hints.ModelObjective.UNKNOWN + else model_objective + ) + self.assertEqual( + expected_model_objective, + handlers_utils.validate_model_objective(model_objective, inferred_model_objective), + ) + if inferred_model_objective != type_hints.ModelObjective.UNKNOWN: + if model_objective == type_hints.ModelObjective.UNKNOWN: + with self.assertLogs(level="INFO") as cm: + handlers_utils.validate_model_objective(model_objective, inferred_model_objective) + assert len(cm.output) == 1, "expecting only 1 log" + log = cm.output[0] + self.assertEqual( + f"INFO:absl:Inferred ModelObjective: {inferred_model_objective.name} is used as model " + f"objective for this model version", + log, + ) + elif inferred_model_objective != model_objective: + with self.assertWarnsRegex( + UserWarning, + f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for " + f"this model version and passed argument ModelObjective: {model_objective.name} is ignored", + ): + handlers_utils.validate_model_objective(model_objective, inferred_model_objective) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py b/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py index d14ff98b..e0aae3b3 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py @@ -11,6 +11,7 @@ from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager import model_packager +from snowflake.ml.model._packager.model_handlers import catboost as catboost_handler from snowflake.ml.model._packager.model_handlers_test import test_utils @@ -214,6 +215,40 @@ def test_catboost_multiclass_explainablity_enabled(self) -> None: test_utils.convert2D_json_to_3D(explain_method(cal_X_test).to_numpy()), explanations ) + def test_model_objective_catboost_binary_classifier(self) -> None: + cal_data = datasets.load_breast_cancer() + cal_X = pd.DataFrame(cal_data.data, columns=cal_data.feature_names) + cal_y = pd.Series(cal_data.target) + catboost_binary_classifier = catboost.CatBoostClassifier() + catboost_binary_classifier.fit(cal_X, cal_y) + self.assertEqual( + model_types.ModelObjective.BINARY_CLASSIFICATION, + catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost_binary_classifier), + ) + + def test_model_objective_catboost_multi_classifier(self) -> None: + cal_data = datasets.load_iris() + cal_X = pd.DataFrame(cal_data.data, columns=cal_data.feature_names) + cal_y = pd.Series(cal_data.target) + catboost_multi_classifier = catboost.CatBoostClassifier() + catboost_multi_classifier.fit(cal_X, cal_y) + self.assertEqual( + model_types.ModelObjective.MULTI_CLASSIFICATION, + catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost_multi_classifier), + ) + + def test_model_objective_catboost_ranking(self) -> None: + self.assertEqual( + model_types.ModelObjective.RANKING, + catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost.CatBoostRanker()), + ) + + def test_model_objective_catboost_regressor(self) -> None: + self.assertEqual( + model_types.ModelObjective.REGRESSION, + catboost_handler.CatBoostModelHandler.get_model_objective_and_output_type(catboost.CatBoostRegressor()), + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/huggingface_pipeline_test.py b/snowflake/ml/model/_packager/model_handlers_test/huggingface_pipeline_test.py index f4373cfa..c97371b4 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/huggingface_pipeline_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/huggingface_pipeline_test.py @@ -8,6 +8,7 @@ import pandas as pd import torch from absl.testing import absltest +from packaging import version from snowflake.ml.model._packager import model_packager from snowflake.ml.model._packager.model_handlers.huggingface_pipeline import ( @@ -220,6 +221,9 @@ def _basic_test_case( def test_conversational_pipeline(self) -> None: import transformers + if version.parse(transformers.__version__) >= version.parse("4.42.0"): + self.skipTest("This test is not compatible with transformers>=4.42.0") + x = transformers.Conversation( text="Do you know how to say Snowflake in French?", past_user_inputs=["Do you speak French?"], diff --git a/snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py b/snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py new file mode 100644 index 00000000..d06b6fd0 --- /dev/null +++ b/snowflake/ml/model/_packager/model_handlers_test/model_objective_utils_test.py @@ -0,0 +1,127 @@ +from typing import Any + +import lightgbm +import numpy as np +import pandas as pd +import xgboost +from absl.testing import absltest +from sklearn import datasets + +from snowflake.ml.model import model_signature, type_hints +from snowflake.ml.model._packager.model_handlers import model_objective_utils + +binary_dataset = datasets.load_breast_cancer() +binary_data_X = pd.DataFrame(binary_dataset.data, columns=binary_dataset.feature_names) +binary_data_y = pd.Series(binary_dataset.target) + +multiclass_data = datasets.load_iris() +multiclass_data_X = pd.DataFrame(multiclass_data.data, columns=multiclass_data.feature_names) +multiclass_data_y = pd.Series(multiclass_data.target) + + +class ModelObjectiveUtilsTest(absltest.TestCase): + def _validate_model_objective_and_output( + self, + model: Any, + expected_objective: type_hints.ModelObjective, + expected_output: model_signature.DataType, + ) -> None: + model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model) + self.assertEqual(expected_objective, model_objective_and_output.objective) + self.assertEqual(expected_output, model_objective_and_output.output_type) + + def test_model_objective_and_output_xgb_binary_classifier(self) -> None: + classifier = xgboost.XGBClassifier(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) + classifier.fit(binary_data_X, binary_data_y) + self._validate_model_objective_and_output( + classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + ) + + def test_model_objective_and_output_xgb_for_single_class(self) -> None: + single_class_y = pd.Series([0] * len(binary_dataset.target)) + # without objective + classifier = xgboost.XGBClassifier() + classifier.fit(binary_data_X, single_class_y) + self._validate_model_objective_and_output( + classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + ) + # with binary objective + classifier = xgboost.XGBClassifier(objective="binary:logistic") + classifier.fit(binary_data_X, single_class_y) + self._validate_model_objective_and_output( + classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + ) + # with multiclass objective + params = {"objective": "multi:softmax", "num_class": 3} + classifier = xgboost.XGBClassifier(**params) + classifier.fit(binary_data_X, single_class_y) + self._validate_model_objective_and_output( + classifier, type_hints.ModelObjective.MULTI_CLASSIFICATION, model_signature.DataType.STRING + ) + + def test_model_objective_and_output_xgb_multiclass_classifier(self) -> None: + classifier = xgboost.XGBClassifier() + classifier.fit(multiclass_data_X, multiclass_data_y) + self._validate_model_objective_and_output( + classifier, type_hints.ModelObjective.MULTI_CLASSIFICATION, model_signature.DataType.STRING + ) + + def test_model_objective_and_output_xgb_regressor(self) -> None: + regressor = xgboost.XGBRegressor() + regressor.fit(multiclass_data_X, multiclass_data_y) + self._validate_model_objective_and_output( + regressor, type_hints.ModelObjective.REGRESSION, model_signature.DataType.DOUBLE + ) + + def test_model_objective_and_output_xgb_booster(self) -> None: + params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") + booster = xgboost.train(params, xgboost.DMatrix(data=binary_data_X, label=binary_data_y)) + self._validate_model_objective_and_output( + booster, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE + ) + + def test_model_objective_and_output_xgb_ranker(self) -> None: + # Make a synthetic ranking dataset for demonstration + seed = 1994 + X, y = datasets.make_classification(random_state=seed) + rng = np.random.default_rng(seed) + n_query_groups = 3 + qid = rng.integers(0, n_query_groups, size=X.shape[0]) + + # Sort the inputs based on query index + sorted_idx = np.argsort(qid) + X = X[sorted_idx, :] + y = y[sorted_idx] + qid = qid[sorted_idx] + ranker = xgboost.XGBRanker( + tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk" + ) + ranker.fit(X, y, qid=qid) + self._validate_model_objective_and_output( + ranker, type_hints.ModelObjective.RANKING, model_signature.DataType.DOUBLE + ) + + def test_model_objective_and_output_lightgbm_classifier(self) -> None: + classifier = lightgbm.LGBMClassifier() + classifier.fit(binary_data_X, binary_data_y) + self._validate_model_objective_and_output( + classifier, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.STRING + ) + + def test_model_objective_and_output_lightgbm_booster(self) -> None: + booster = lightgbm.train({"objective": "binary"}, lightgbm.Dataset(binary_data_X, label=binary_data_y)) + self._validate_model_objective_and_output( + booster, type_hints.ModelObjective.BINARY_CLASSIFICATION, model_signature.DataType.STRING + ) + + def test_model_objective_and_output_unknown_model(self) -> None: + def unknown_model(x: int) -> int: + return x + 1 + + with self.assertRaises(ValueError) as e: + model_objective_utils.get_model_objective_and_output_type(unknown_model) + self.assertEqual(str(e.exception), "Model type is not supported") + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py b/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py index 652c1e23..6e4d9a1e 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py @@ -35,7 +35,7 @@ def _prepare_torch_model( n_input, n_hidden, n_out, batch_size, learning_rate = 10, 15, 1, 100, 0.01 x = np.random.rand(batch_size, n_input) data_x = torch.from_numpy(x).to(dtype=dtype) - data_y = (torch.rand(size=(batch_size, 1)) < 0.5).to(dtype=dtype) + data_y = (torch.rand(size=(batch_size, 1)) < 0.5).to(dtype=dtype) # type: ignore[attr-defined] model = TorchModel(n_input, n_hidden, n_out, dtype=dtype) loss_function = torch.nn.MSELoss() diff --git a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py index d49cd837..73c33bc9 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py @@ -13,7 +13,7 @@ class SKLearnHandlerTest(absltest.TestCase): - def test_skl_multiple_output_proba(self) -> None: + def test_skl_multiple_output_proba_no_explain(self) -> None: iris_X, iris_y = datasets.load_iris(return_X_y=True) target2 = np.random.randint(0, 6, size=iris_y.shape) dual_target = np.vstack([iris_y, target2]).T @@ -57,7 +57,11 @@ def test_skl_multiple_output_proba(self) -> None: model=model, sample_input_data=iris_X_df, metadata={"author": "halu", "version": "1"}, - options=model_types.SKLModelSaveOptions({"target_methods": ["random"]}), + options=model_types.SKLModelSaveOptions( + { + "target_methods": ["random"], + } + ), ) model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( @@ -65,6 +69,11 @@ def test_skl_multiple_output_proba(self) -> None: model=model, sample_input_data=iris_X_df, metadata={"author": "halu", "version": "1"}, + options=model_types.SKLModelSaveOptions( + { + "enable_explainability": False, + } + ), ) pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) @@ -123,7 +132,6 @@ def test_skl_unsupported_explain(self) -> None: model=model, sample_input_data=iris_X_df, metadata={"author": "halu", "version": "1"}, - options=model_types.SKLModelSaveOptions(enable_explainability=True), ) pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) @@ -159,7 +167,7 @@ def test_skl_unsupported_explain(self) -> None: with self.assertRaises(ValueError): explain_method(iris_X_df[-10:]) - def test_skl(self) -> None: + def test_skl_no_explain(self) -> None: iris_X, iris_y = datasets.load_iris(return_X_y=True) regr = linear_model.LinearRegression() iris_X_df = pd.DataFrame(iris_X, columns=["c1", "c2", "c3", "c4"]) @@ -172,6 +180,7 @@ def test_skl(self) -> None: model=regr, signatures={**s, "another_predict": s["predict"]}, metadata={"author": "halu", "version": "1"}, + options=model_types.SKLModelSaveOptions(enable_explainability=False), ) model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( @@ -179,6 +188,7 @@ def test_skl(self) -> None: model=regr, signatures=s, metadata={"author": "halu", "version": "1"}, + options=model_types.SKLModelSaveOptions(enable_explainability=False), ) with warnings.catch_warnings(): @@ -204,6 +214,7 @@ def test_skl(self) -> None: model=regr, sample_input_data=iris_X_df, metadata={"author": "halu", "version": "1"}, + options=model_types.SKLModelSaveOptions(enable_explainability=False), ) pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) @@ -248,7 +259,6 @@ def test_skl_explain(self) -> None: model=regr, sample_input_data=iris_X_df, metadata={"author": "halu", "version": "1"}, - options=model_types.SKLModelSaveOptions(enable_explainability=True), ) with warnings.catch_warnings(): @@ -265,6 +275,33 @@ def test_skl_explain(self) -> None: np.testing.assert_allclose(np.array([[-0.08254936]]), predict_method(iris_X_df[:1])) np.testing.assert_allclose(explain_method(iris_X_df), explanations) + def test_skl_no_default_explain_without_background_data(self) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + regr = linear_model.LinearRegression() + iris_X_df = pd.DataFrame(iris_X, columns=["c1", "c2", "c3", "c4"]) + regr.fit(iris_X_df, iris_y) + with tempfile.TemporaryDirectory() as tmpdir: + s = {"predict": model_signature.infer_signature(iris_X_df, regr.predict(iris_X_df))} + + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1", + model=regr, + signatures=s, + metadata={"author": "halu", "version": "1"}, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + + pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) + pk.load(as_custom_model=True) + assert pk.model + assert pk.meta + predict_method = getattr(pk.model, "predict", None) + explain_method = getattr(pk.model, "explain", None) + assert callable(predict_method) + self.assertEqual(explain_method, None) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py index 5fbb8147..93b7a4e6 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +import shap from absl.testing import absltest from sklearn import datasets @@ -12,10 +13,11 @@ from snowflake.ml.modeling.linear_model import ( # type:ignore[attr-defined] LinearRegression, ) +from snowflake.ml.modeling.xgboost import XGBRegressor class SnowMLModelHandlerTest(absltest.TestCase): - def test_snowml_all_input(self) -> None: + def test_snowml_all_input_no_explain(self) -> None: iris = datasets.load_iris() df = pd.DataFrame(data=np.c_[iris["data"], iris["target"]], columns=iris["feature_names"] + ["target"]) @@ -31,7 +33,7 @@ def test_snowml_all_input(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: s = {"predict": model_signature.infer_signature(df[INPUT_COLUMNS], regr.predict(df)[[OUTPUT_COLUMNS]])} - with self.assertRaises(NotImplementedError): + with self.assertRaisesRegex(ValueError, "Explain only support for xgboost or lightgbm Snowpark ML models."): model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", model=regr, @@ -45,6 +47,7 @@ def test_snowml_all_input(self) -> None: model=regr, signatures=s, metadata={"author": "halu", "version": "1"}, + options={"enable_explainability": False}, ) with self.assertWarnsRegex(UserWarning, "Model signature will automatically be inferred during fitting"): model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( @@ -52,6 +55,7 @@ def test_snowml_all_input(self) -> None: model=regr, sample_input_data=df[INPUT_COLUMNS], metadata={"author": "halu", "version": "1"}, + options={"enable_explainability": False}, ) with tempfile.TemporaryDirectory() as tmpdir: @@ -59,6 +63,7 @@ def test_snowml_all_input(self) -> None: name="model1", model=regr, metadata={"author": "halu", "version": "1"}, + options={"enable_explainability": False}, ) with warnings.catch_warnings(): warnings.simplefilter("error") @@ -97,6 +102,7 @@ def test_snowml_signature_partial_input(self) -> None: name="model1", model=regr, metadata={"author": "halu", "version": "1"}, + options={"enable_explainability": False}, ) with warnings.catch_warnings(): @@ -138,6 +144,7 @@ def test_snowml_signature_drop_input_cols(self) -> None: name="model1", model=regr, metadata={"author": "halu", "version": "1"}, + options={"enable_explainability": False}, ) with warnings.catch_warnings(): @@ -158,6 +165,42 @@ def test_snowml_signature_drop_input_cols(self) -> None: assert callable(predict_method) np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) + def test_snowml_xgboost_explain_default(self) -> None: + iris = datasets.load_iris() + + df = pd.DataFrame(data=np.c_[iris["data"], iris["target"]], columns=iris["feature_names"] + ["target"]) + df.columns = [s.replace(" (CM)", "").replace(" ", "") for s in df.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + regr.fit(df) + + predictions = regr.predict(df[:1])[[OUTPUT_COLUMNS]] + + explanations = shap.TreeExplainer(regr.to_xgboost())(df[INPUT_COLUMNS]).values + + with tempfile.TemporaryDirectory() as tmpdir: + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1", + model=regr, + metadata={"author": "halu", "version": "1"}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + + pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) + pk.load(as_custom_model=True) + assert pk.model + assert pk.meta + predict_method = getattr(pk.model, "predict", None) + explain_method = getattr(pk.model, "explain", None) + assert callable(predict_method) + assert callable(explain_method) + np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) + np.testing.assert_allclose(explanations, explain_method(df[INPUT_COLUMNS]).values) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py b/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py index 4ad227ab..db0702fb 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py @@ -35,7 +35,7 @@ def _prepare_torch_model( n_input, n_hidden, n_out, batch_size, learning_rate = 10, 15, 1, 100, 0.01 x = np.random.rand(batch_size, n_input) data_x = torch.from_numpy(x).to(dtype=dtype) - data_y = (torch.rand(size=(batch_size, 1)) < 0.5).to(dtype=dtype) + data_y = (torch.rand(size=(batch_size, 1)) < 0.5).to(dtype=dtype) # type: ignore[attr-defined] model = TorchModel(n_input, n_hidden, n_out, dtype=dtype) loss_function = torch.nn.MSELoss() diff --git a/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py b/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py index 2886d869..49ad9363 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py @@ -11,6 +11,7 @@ from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager import model_packager +from snowflake.ml.model._packager.model_handlers_test import test_utils class XgboostHandlerTest(absltest.TestCase): @@ -204,6 +205,56 @@ def test_xgb_explainablity_enabled(self) -> None: assert callable(explain_method) np.testing.assert_allclose(explain_method(cal_X_test), explanations) + def test_xgb_explainablity_multiclass(self) -> None: + cal_data = datasets.load_iris() + cal_X = pd.DataFrame(cal_data.data, columns=cal_data.feature_names) + cal_y = pd.Series(cal_data.target) + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + classifier = xgboost.XGBClassifier(reg_lambda=1, gamma=0, max_depth=3) + classifier.fit(cal_X_train, cal_y_train) + y_pred = classifier.predict(cal_X_test) + explanations = shap.TreeExplainer(classifier)(cal_X_test).values + with tempfile.TemporaryDirectory() as tmpdir: + + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1", + model=classifier, + signatures={"predict": model_signature.infer_signature(cal_X_test, y_pred)}, + metadata={"author": "halu", "version": "1"}, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + + pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) + pk.load(as_custom_model=True) + predict_method = getattr(pk.model, "predict", None) + explain_method = getattr(pk.model, "explain", None) + assert callable(predict_method) + assert callable(explain_method) + np.testing.assert_allclose(predict_method(cal_X_test), np.expand_dims(y_pred, axis=1)) + np.testing.assert_allclose( + test_utils.convert2D_json_to_3D(explain_method(cal_X_test).to_numpy()), explanations + ) + + model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( + name="model1_no_sig", + model=classifier, + sample_input_data=cal_X_test, + metadata={"author": "halu", "version": "1"}, + ) + + pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) + pk.load(as_custom_model=True) + predict_method = getattr(pk.model, "predict", None) + assert callable(predict_method) + np.testing.assert_allclose(predict_method(cal_X_test), np.expand_dims(y_pred, axis=1)) + explain_method = getattr(pk.model, "explain", None) + assert callable(explain_method) + np.testing.assert_allclose( + test_utils.convert2D_json_to_3D(explain_method(cal_X_test).to_numpy()), explanations + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_meta/model_meta.py b/snowflake/ml/model/_packager/model_meta/model_meta.py index 78704618..d4e7460b 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta.py @@ -55,6 +55,7 @@ def create_model_metadata( conda_dependencies: Optional[List[str]] = None, pip_requirements: Optional[List[str]] = None, python_version: Optional[str] = None, + model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, **kwargs: Any, ) -> Generator["ModelMetadata", None, None]: """Create a generator for model metadata object. Use generator to ensure correct register and unregister for @@ -74,6 +75,9 @@ def create_model_metadata( pip_requirements: List of pip Python packages requirements for running the model. Defaults to None. python_version: A string of python version where model is run. Used for user override. If specified as None, current version would be captured. Defaults to None. + model_objective: The objective of the Model Version. It is an enum class ModelObjective with values REGRESSION, + BINARY_CLASSIFICATION, MULTI_CLASSIFICATION, RANKING, or UNKNOWN. By default it is set to + ModelObjective.UNKNOWN and may be overridden by inferring from the Model Object. **kwargs: Dict of attributes and values of the metadata. Used when loading from file. Raises: @@ -131,6 +135,7 @@ def create_model_metadata( model_type=model_type, signatures=signatures, function_properties=function_properties, + model_objective=model_objective, ) code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR) @@ -261,7 +266,7 @@ def __init__( min_snowpark_ml_version: Optional[str] = None, models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None, original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION, - model_objective: Optional[model_meta_schema.ModelObjective] = model_meta_schema.ModelObjective.UNKNOWN, + model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None, ) -> None: self.name = name @@ -287,9 +292,7 @@ def __init__( self.original_metadata_version = original_metadata_version - self.model_objective: model_meta_schema.ModelObjective = ( - model_objective or model_meta_schema.ModelObjective.UNKNOWN - ) + self.model_objective: model_types.ModelObjective = model_objective self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm @property @@ -387,7 +390,7 @@ def _validate_model_metadata(loaded_meta: Any) -> model_meta_schema.ModelMetadat signatures=loaded_meta["signatures"], version=original_loaded_meta_version, min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version, - model_objective=loaded_meta.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value), + model_objective=loaded_meta.get("model_objective", model_types.ModelObjective.UNKNOWN.value), explainability=loaded_meta.get("explainability", None), function_properties=loaded_meta.get("function_properties", {}), ) @@ -442,8 +445,8 @@ def load(cls, model_dir_path: str) -> "ModelMetadata": min_snowpark_ml_version=model_dict["min_snowpark_ml_version"], models=models, original_metadata_version=model_dict["version"], - model_objective=model_meta_schema.ModelObjective( - model_dict.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value) + model_objective=model_types.ModelObjective( + model_dict.get("model_objective", model_types.ModelObjective.UNKNOWN.value) ), explain_algorithm=explanation_algorithm, function_properties=model_dict.get("function_properties", {}), diff --git a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py index 22efeb01..8b46a19b 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py @@ -101,13 +101,5 @@ class ModelMetadataDict(TypedDict): function_properties: NotRequired[Dict[str, Dict[str, Any]]] -class ModelObjective(Enum): - UNKNOWN = "unknown" - BINARY_CLASSIFICATION = "binary_classification" - MULTI_CLASSIFICATION = "multi_classification" - REGRESSION = "regression" - RANKING = "ranking" - - class ModelExplainAlgorithm(Enum): SHAP = "shap" diff --git a/snowflake/ml/model/_packager/model_meta/model_meta_test.py b/snowflake/ml/model/_packager/model_meta/model_meta_test.py index be3f9432..0b2b7f4e 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta_test.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta_test.py @@ -7,7 +7,7 @@ from packaging import requirements, version from snowflake.ml._internal import env as snowml_env, env_utils -from snowflake.ml.model import model_signature +from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._packager.model_env import model_env from snowflake.ml.model._packager.model_meta import ( model_blob_meta, @@ -645,7 +645,7 @@ def test_model_meta_metadata(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - self.assertEqual(meta.model_objective, model_meta_schema.ModelObjective.UNKNOWN) + self.assertEqual(meta.model_objective, type_hints.ModelObjective.UNKNOWN) self.assertEqual(meta.explain_algorithm, None) saved_meta = meta @@ -686,10 +686,10 @@ def test_model_meta_model_specified_objective(self) -> None: metadata={"foo": "bar"}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - meta.model_objective = model_meta_schema.ModelObjective.REGRESSION + meta.model_objective = type_hints.ModelObjective.REGRESSION loaded_meta = model_meta.ModelMetadata.load(tmpdir) - self.assertEqual(loaded_meta.model_objective, model_meta_schema.ModelObjective.REGRESSION) + self.assertEqual(loaded_meta.model_objective, type_hints.ModelObjective.REGRESSION) def test_model_meta_explain_algorithm(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: @@ -701,11 +701,11 @@ def test_model_meta_explain_algorithm(self) -> None: metadata={"foo": "bar"}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - meta.model_objective = model_meta_schema.ModelObjective.REGRESSION + meta.model_objective = type_hints.ModelObjective.REGRESSION meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP loaded_meta = model_meta.ModelMetadata.load(tmpdir) - self.assertEqual(loaded_meta.model_objective, model_meta_schema.ModelObjective.REGRESSION) + self.assertEqual(loaded_meta.model_objective, type_hints.ModelObjective.REGRESSION) self.assertEqual(loaded_meta.explain_algorithm, model_meta_schema.ModelExplainAlgorithm.SHAP) def test_model_meta_new_fields(self) -> None: diff --git a/snowflake/ml/model/_packager/model_packager.py b/snowflake/ml/model/_packager/model_packager.py index 2f06f385..984381ef 100644 --- a/snowflake/ml/model/_packager/model_packager.py +++ b/snowflake/ml/model/_packager/model_packager.py @@ -47,6 +47,7 @@ def save( ext_modules: Optional[List[ModuleType]] = None, code_paths: Optional[List[str]] = None, options: Optional[model_types.ModelSaveOption] = None, + model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, ) -> model_meta.ModelMetadata: if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model): raise snowml_exceptions.SnowflakeMLException( @@ -84,6 +85,7 @@ def save( conda_dependencies=conda_dependencies, pip_requirements=pip_requirements, python_version=python_version, + model_objective=model_objective, **options, ) as meta: model_blobs_path = os.path.join(self.local_dir_path, ModelPackager.MODEL_BLOBS_DIR) diff --git a/snowflake/ml/model/_packager/model_packager_test.py b/snowflake/ml/model/_packager/model_packager_test.py index 62103fdb..9d83f4e5 100644 --- a/snowflake/ml/model/_packager/model_packager_test.py +++ b/snowflake/ml/model/_packager/model_packager_test.py @@ -10,7 +10,7 @@ from sklearn import datasets, linear_model from snowflake.ml._internal import file_utils -from snowflake.ml.model import custom_model, model_signature +from snowflake.ml.model import custom_model, model_signature, type_hints from snowflake.ml.model._packager import model_packager from snowflake.ml.modeling.linear_model import ( # type:ignore[attr-defined] LinearRegression, @@ -173,12 +173,14 @@ def test_save_validation_2(self) -> None: name="model1", model=regr, metadata={"author": "halu", "version": "1"}, + model_objective=type_hints.ModelObjective.REGRESSION, ) pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) pk.load() assert pk.model assert pk.meta + self.assertEqual(type_hints.ModelObjective.REGRESSION, pk.meta.model_objective) assert isinstance(pk.model, LinearRegression) np.testing.assert_allclose(predictions, desired=pk.model.predict(df[:1])[[OUTPUT_COLUMNS]]) diff --git a/snowflake/ml/model/_signatures/pytorch_handler.py b/snowflake/ml/model/_signatures/pytorch_handler.py index af8d9043..30b7d9ce 100644 --- a/snowflake/ml/model/_signatures/pytorch_handler.py +++ b/snowflake/ml/model/_signatures/pytorch_handler.py @@ -30,7 +30,7 @@ def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Sequence["torch @staticmethod def count(data: Sequence["torch.Tensor"]) -> int: - return min(data_col.shape[0] for data_col in data) + return min(data_col.shape[0] for data_col in data) # type: ignore[no-any-return] @staticmethod def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]: diff --git a/snowflake/ml/model/_signatures/utils.py b/snowflake/ml/model/_signatures/utils.py index 1335bbd0..354df78f 100644 --- a/snowflake/ml/model/_signatures/utils.py +++ b/snowflake/ml/model/_signatures/utils.py @@ -110,6 +110,15 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any]) # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline # Needs to convert to conversation object. if task == "conversational": + warnings.warn( + ( + "Conversational pipeline is removed from transformers since 4.42.0. " + "Support will be removed from snowflake-ml-python soon." + ), + category=DeprecationWarning, + stacklevel=1, + ) + return core.ModelSignature( inputs=[ core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)), diff --git a/snowflake/ml/model/models/llm.py b/snowflake/ml/model/models/llm.py index 7ad9f42f..21aec7ba 100644 --- a/snowflake/ml/model/models/llm.py +++ b/snowflake/ml/model/models/llm.py @@ -70,7 +70,9 @@ def __init__( import peft - peft_config = peft.PeftConfig.from_pretrained(model_id_or_path, **hub_kwargs) # type: ignore[attr-defined] + peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined] + model_id_or_path, **hub_kwargs + ) if peft_config.peft_type != peft.PeftType.LORA: # type: ignore[attr-defined] raise ValueError("Only LORA is supported.") if peft_config.task_type != peft.TaskType.CAUSAL_LM: # type: ignore[attr-defined] diff --git a/snowflake/ml/model/type_hints.py b/snowflake/ml/model/type_hints.py index 1726baec..c8873c6f 100644 --- a/snowflake/ml/model/type_hints.py +++ b/snowflake/ml/model/type_hints.py @@ -1,4 +1,5 @@ # mypy: disable-error-code="import" +from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -232,7 +233,6 @@ class BaseModelSaveOption(TypedDict): _legacy_save: NotRequired[bool] function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]] method_options: NotRequired[Dict[str, ModelMethodSaveOptions]] - include_pip_dependencies: NotRequired[bool] enable_explainability: NotRequired[bool] @@ -431,3 +431,11 @@ class Deployment(TypedDict): signature: core.ModelSignature options: Required[DeployOptions] details: NotRequired[DeployDetails] + + +class ModelObjective(Enum): + UNKNOWN = "unknown" + BINARY_CLASSIFICATION = "binary_classification" + MULTI_CLASSIFICATION = "multi_classification" + REGRESSION = "regression" + RANKING = "ranking" diff --git a/snowflake/ml/modeling/_internal/constants.py b/snowflake/ml/modeling/_internal/constants.py index c62b1143..15814185 100644 --- a/snowflake/ml/modeling/_internal/constants.py +++ b/snowflake/ml/modeling/_internal/constants.py @@ -1 +1,2 @@ IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME" +USE_OPTIMIZED_DATA_INGESTOR = "USE_OPTIMIZED_DATA_INGESTOR" diff --git a/snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py b/snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py index 3f3a97a1..91dda255 100644 --- a/snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +++ b/snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py @@ -166,10 +166,10 @@ def score( SnowflakeMLException: The input column list does not have one of `X` and `X_test`. """ assert hasattr(self.estimator, "score") # make type checker happy - argspec = inspect.getfullargspec(self.estimator.score) - if "X" in argspec.args: + params = inspect.signature(self.estimator.score).parameters + if "X" in params: score_args = {"X": self.dataset[input_cols]} - elif "X_test" in argspec.args: + elif "X_test" in params: score_args = {"X_test": self.dataset[input_cols]} else: raise exceptions.SnowflakeMLException( @@ -178,10 +178,10 @@ def score( ) if len(label_cols) > 0: - label_arg_name = "Y" if "Y" in argspec.args else "y" + label_arg_name = "Y" if "Y" in params else "y" score_args[label_arg_name] = self.dataset[label_cols].squeeze() - if sample_weight_col is not None and "sample_weight" in argspec.args: + if sample_weight_col is not None and "sample_weight" in params: score_args["sample_weight"] = self.dataset[sample_weight_col].squeeze() score = self.estimator.score(**score_args) diff --git a/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py b/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py index 151ad2b3..df706fe3 100644 --- a/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +++ b/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py @@ -43,14 +43,14 @@ def train(self) -> object: Trained model """ assert hasattr(self.estimator, "fit") # Keep mypy happy - argspec = inspect.getfullargspec(self.estimator.fit) + params = inspect.signature(self.estimator.fit).parameters args = {"X": self.dataset[self.input_cols]} if self.label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" + label_arg_name = "Y" if "Y" in params else "y" args[label_arg_name] = self.dataset[self.label_cols].squeeze() - if self.sample_weight_col is not None and "sample_weight" in argspec.args: + if self.sample_weight_col is not None and "sample_weight" in params: args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze() return self.estimator.fit(**args) @@ -59,6 +59,7 @@ def train_fit_predict( self, expected_output_cols_list: List[str], drop_input_cols: Optional[bool] = False, + example_output_pd_df: Optional[pd.DataFrame] = None, ) -> Tuple[pd.DataFrame, object]: """Trains the model using specified features and target columns from the dataset. This API is different from fit itself because it would also provide the predict @@ -69,6 +70,8 @@ def train_fit_predict( name as a list. Defaults to None. drop_input_cols (Optional[bool]): Boolean to determine whether to drop the input columns from the output dataset. + example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe + This is not used in PandasModelTrainer. It is used in SnowparkModelTrainer. Returns: Tuple[pd.DataFrame, object]: [predicted dataset, estimator] @@ -108,13 +111,13 @@ def train_fit_transform( assert hasattr(self.estimator, "fit") # make type checker happy assert hasattr(self.estimator, "fit_transform") # make type checker happy - argspec = inspect.getfullargspec(self.estimator.fit) + params = inspect.signature(self.estimator.fit).parameters args = {"X": self.dataset[self.input_cols]} if self.label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" + label_arg_name = "Y" if "Y" in params else "y" args[label_arg_name] = self.dataset[self.label_cols].squeeze() - if self.sample_weight_col is not None and "sample_weight" in argspec.args: + if self.sample_weight_col is not None and "sample_weight" in params: args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze() inference_res = self.estimator.fit_transform(**args) diff --git a/snowflake/ml/modeling/_internal/model_specifications.py b/snowflake/ml/modeling/_internal/model_specifications.py index 10f5b85f..af7b1ec4 100644 --- a/snowflake/ml/modeling/_internal/model_specifications.py +++ b/snowflake/ml/modeling/_internal/model_specifications.py @@ -53,11 +53,13 @@ def __init__(self) -> None: class XGBoostModelSpecifications(ModelSpecifications): def __init__(self) -> None: + import sklearn import xgboost imports: List[str] = ["xgboost"] pkgDependencies: List[str] = [ f"numpy=={np.__version__}", + f"scikit-learn=={sklearn.__version__}", f"xgboost=={xgboost.__version__}", f"cloudpickle=={cp.__version__}", ] diff --git a/snowflake/ml/modeling/_internal/model_trainer.py b/snowflake/ml/modeling/_internal/model_trainer.py index 7896121b..9ebe02a3 100644 --- a/snowflake/ml/modeling/_internal/model_trainer.py +++ b/snowflake/ml/modeling/_internal/model_trainer.py @@ -20,6 +20,7 @@ def train_fit_predict( self, expected_output_cols_list: List[str], drop_input_cols: Optional[bool] = False, + example_output_pd_df: Optional[pd.DataFrame] = None, ) -> Tuple[Union[DataFrame, pd.DataFrame], object]: raise NotImplementedError diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py index 8679fa63..d0e2d065 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py @@ -495,7 +495,7 @@ def _load_data_into_udf() -> Tuple[ label_arg_name = "Y" if "Y" in argspec.args else "y" args[label_arg_name] = df[label_cols].squeeze() - if sample_weight_col is not None and "sample_weight" in argspec.args: + if sample_weight_col is not None: args["sample_weight"] = df[sample_weight_col].squeeze() return args, estimator, indices, len(df), params_to_evaluate @@ -1061,7 +1061,7 @@ def _distributed_search( if label_cols: label_arg_name = "Y" if "Y" in argspec.args else "y" args[label_arg_name] = y - if sample_weight_col is not None and "sample_weight" in argspec.args: + if sample_weight_col is not None: args["sample_weight"] = df[sample_weight_col].squeeze() # estimator.refit = original_refit refit_start_time = time.time() diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py index 7a28fdf4..dc8b1e41 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py @@ -318,19 +318,19 @@ def score_wrapper_sproc( with open(local_score_file_name_path, mode="r+b") as local_score_file_obj: estimator = cp.load(local_score_file_obj) - argspec = inspect.getfullargspec(estimator.score) - if "X" in argspec.args: + params = inspect.signature(estimator.score).parameters + if "X" in params: args = {"X": df[input_cols]} - elif "X_test" in argspec.args: + elif "X_test" in params: args = {"X_test": df[input_cols]} else: raise RuntimeError("Neither 'X' or 'X_test' exist in argument") if label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" + label_arg_name = "Y" if "Y" in params else "y" args[label_arg_name] = df[label_cols].squeeze() - if sample_weight_col is not None and "sample_weight" in argspec.args: + if sample_weight_col is not None and "sample_weight" in params: args["sample_weight"] = df[sample_weight_col].squeeze() result: float = estimator.score(**args) diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py index 8c809674..cc37c338 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py @@ -35,6 +35,7 @@ _PROJECT = "ModelDevelopment" _ENABLE_ANONYMOUS_SPROC = False +_ENABLE_TRACER = True class SnowparkModelTrainer: @@ -119,6 +120,8 @@ def _build_fit_wrapper_sproc( A callable that can be registered as a stored procedure. """ imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml + method_name = "fit" + tracer_name = f"snowpark.ml.modeling.{self._class_name.lower()}.{method_name}" def fit_wrapper_function( session: Session, @@ -138,110 +141,98 @@ def fit_wrapper_function( for import_name in imports: importlib.import_module(import_name) - # Execute snowpark queries and obtain the results as pandas dataframe - # NB: this implies that the result data must fit into memory. - for query in sql_queries[:-1]: - _ = session.sql(query).collect(statement_params=statement_params) - sp_df = session.sql(sql_queries[-1]) - df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params) - df.columns = sp_df.columns + def fit_and_return_estimator() -> str: + """This is a helper function within the sproc to download the data, fit the model, and upload the model. + + Returns: + The name of the file in session's temp stage (temp_stage_name) that contains the serialized model. + """ + # Execute snowpark queries and obtain the results as pandas dataframe + # NB: this implies that the result data must fit into memory. + for query in sql_queries[:-1]: + _ = session.sql(query).collect(statement_params=statement_params) + sp_df = session.sql(sql_queries[-1]) + df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params) + df.columns = sp_df.columns + + local_transform_file_name = temp_file_utils.get_temp_file_path() + + session.file.get( + stage_location=temp_stage_name, + target_directory=local_transform_file_name, + statement_params=statement_params, + ) - local_transform_file_name = temp_file_utils.get_temp_file_path() + local_transform_file_path = os.path.join( + local_transform_file_name, os.listdir(local_transform_file_name)[0] + ) + with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: + estimator = cp.load(local_transform_file_obj) - session.file.get( - stage_location=temp_stage_name, - target_directory=local_transform_file_name, - statement_params=statement_params, - ) + params = inspect.signature(estimator.fit).parameters + args = {"X": df[input_cols]} + if label_cols: + label_arg_name = "Y" if "Y" in params else "y" + args[label_arg_name] = df[label_cols].squeeze() - local_transform_file_path = os.path.join( - local_transform_file_name, os.listdir(local_transform_file_name)[0] - ) - with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: - estimator = cp.load(local_transform_file_obj) + if sample_weight_col is not None and "sample_weight" in params: + args["sample_weight"] = df[sample_weight_col].squeeze() - argspec = inspect.getfullargspec(estimator.fit) - args = {"X": df[input_cols]} - if label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" - args[label_arg_name] = df[label_cols].squeeze() + estimator.fit(**args) - if sample_weight_col is not None and "sample_weight" in argspec.args: - args["sample_weight"] = df[sample_weight_col].squeeze() + local_result_file_name = temp_file_utils.get_temp_file_path() - estimator.fit(**args) + with open(local_result_file_name, mode="w+b") as local_result_file_obj: + cp.dump(estimator, local_result_file_obj) - local_result_file_name = temp_file_utils.get_temp_file_path() + session.file.put( + local_file_name=local_result_file_name, + stage_location=temp_stage_name, + auto_compress=False, + overwrite=True, + statement_params=statement_params, + ) + return local_result_file_name - with open(local_result_file_name, mode="w+b") as local_result_file_obj: - cp.dump(estimator, local_result_file_obj) + if _ENABLE_TRACER: - session.file.put( - local_file_name=local_result_file_name, - stage_location=temp_stage_name, - auto_compress=False, - overwrite=True, - statement_params=statement_params, - ) + # Use opentelemetry to trace the dist and span of the fit operation. + # This would allow user to see the trace in the Snowflake UI. + from opentelemetry import trace - # Note: you can add something like + "|" + str(df) to the return string - # to pass debug information to the caller. - return str(os.path.basename(local_result_file_name)) + tracer = trace.get_tracer(tracer_name) + with tracer.start_as_current_span("fit"): + local_result_file_name = fit_and_return_estimator() + # Note: you can add something like + "|" + str(df) to the return string + # to pass debug information to the caller. + return str(os.path.basename(local_result_file_name)) + else: + local_result_file_name = fit_and_return_estimator() + return str(os.path.basename(local_result_file_name)) return fit_wrapper_function - def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure: + def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure: model_spec = ModelSpecificationsBuilder.build(model=self.estimator) - fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) - - relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=model_spec.pkgDependencies, session=self.session - ) - - fit_wrapper_sproc = self.session.sproc.register( - func=self._build_fit_wrapper_sproc(model_spec=model_spec), - is_permanent=False, - name=fit_sproc_name, - packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type] - replace=True, - session=self.session, - statement_params=statement_params, - anonymous=True, - execute_as="caller", - ) - - return fit_wrapper_sproc - - def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure: - # If the sproc already exists, don't register. - if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"): - self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] - - model_spec = ModelSpecificationsBuilder.build(model=self.estimator) - fit_sproc_key = model_spec.__class__.__name__ - if fit_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined] - fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined] - return fit_sproc fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( pkg_versions=model_spec.pkgDependencies, session=self.session ) + packages = ["snowflake-snowpark-python", "snowflake-telemetry-python"] + relaxed_dependencies fit_wrapper_sproc = self.session.sproc.register( func=self._build_fit_wrapper_sproc(model_spec=model_spec), is_permanent=False, name=fit_sproc_name, - packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type] + packages=packages, # type: ignore[arg-type] replace=True, session=self.session, statement_params=statement_params, execute_as="caller", + anonymous=anonymous, ) - - self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined] - return fit_wrapper_sproc def _build_fit_predict_wrapper_sproc( @@ -333,7 +324,9 @@ def fit_predict_wrapper_function( # write into a temp table in sproc and load the table from outside session.write_pandas( - fit_predict_result_pd, fit_predict_result_name, auto_create_table=True, table_type="temp" + fit_predict_result_pd, + fit_predict_result_name, + overwrite=True, ) # Note: you can add something like + "|" + str(df) to the return string @@ -414,13 +407,13 @@ def fit_transform_wrapper_function( with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: estimator = cp.load(local_transform_file_obj) - argspec = inspect.getfullargspec(estimator.fit) + params = inspect.signature(estimator.fit).parameters args = {"X": df[input_cols]} if label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" + label_arg_name = "Y" if "Y" in params else "y" args[label_arg_name] = df[label_cols].squeeze() - if sample_weight_col is not None and "sample_weight" in argspec.args: + if sample_weight_col is not None and "sample_weight" in params: args["sample_weight"] = df[sample_weight_col].squeeze() fit_transform_result = estimator.fit_transform(**args) @@ -477,7 +470,7 @@ def fit_transform_wrapper_function( return fit_transform_wrapper_function - def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure: + def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure: model_spec = ModelSpecificationsBuilder.build(model=self.estimator) fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) @@ -494,82 +487,14 @@ def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, s replace=True, session=self.session, statement_params=statement_params, - anonymous=True, + anonymous=anonymous, execute_as="caller", ) return fit_predict_wrapper_sproc - def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure: - # If the sproc already exists, don't register. - if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"): - self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] - - model_spec = ModelSpecificationsBuilder.build(model=self.estimator) - fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict" - if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined] - fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] - fit_predict_sproc_key - ] - return fit_sproc - - fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) - - relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=model_spec.pkgDependencies, session=self.session - ) - - fit_predict_wrapper_sproc = self.session.sproc.register( - func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec), - is_permanent=False, - name=fit_predict_sproc_name, - packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type] - replace=True, - session=self.session, - statement_params=statement_params, - execute_as="caller", - ) - - self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] - fit_predict_sproc_key - ] = fit_predict_wrapper_sproc - - return fit_predict_wrapper_sproc - - def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure: - model_spec = ModelSpecificationsBuilder.build(model=self.estimator) - - fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) - - relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=model_spec.pkgDependencies, session=self.session - ) - - fit_transform_wrapper_sproc = self.session.sproc.register( - func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec), - is_permanent=False, - name=fit_transform_sproc_name, - packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type] - replace=True, - session=self.session, - statement_params=statement_params, - anonymous=True, - execute_as="caller", - ) - return fit_transform_wrapper_sproc - - def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure: - # If the sproc already exists, don't register. - if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"): - self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] - + def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure: model_spec = ModelSpecificationsBuilder.build(model=self.estimator) - fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform" - if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined] - fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] - fit_transform_sproc_key - ] - return fit_sproc fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) @@ -586,12 +511,9 @@ def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> session=self.session, statement_params=statement_params, execute_as="caller", + anonymous=anonymous, ) - self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] - fit_transform_sproc_key - ] = fit_transform_wrapper_sproc - return fit_transform_wrapper_sproc def train(self) -> object: @@ -629,9 +551,9 @@ def train(self) -> object: # Call fit sproc if _ENABLE_ANONYMOUS_SPROC: - fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params) + fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=True) else: - fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params) + fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=False) try: sproc_export_file_name: str = fit_wrapper_sproc( @@ -665,6 +587,7 @@ def train_fit_predict( self, expected_output_cols_list: List[str], drop_input_cols: Optional[bool] = False, + example_output_pd_df: Optional[pd.DataFrame] = None, ) -> Tuple[Union[DataFrame, pd.DataFrame], object]: """Trains the model by pushing down the compute into Snowflake using stored procedures. This API is different from fit itself because it would also provide the predict @@ -675,6 +598,11 @@ def train_fit_predict( name as a list. Defaults to None. drop_input_cols (Optional[bool]): Boolean to determine drop the input columns from the output dataset or not + example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe + This is to create a temp table in the client side with df_one_row. This can maintain the same column + name and data type as the output dataframe. Within the sproc, we don't need to create another temp table + again - instead, we overwrite into this table without changing the schema. + This is not used in PandasModelTrainer. Returns: Tuple[Union[DataFrame, pd.DataFrame], object]: [predicted dataset, estimator] @@ -702,12 +630,35 @@ def train_fit_predict( # Call fit sproc if _ENABLE_ANONYMOUS_SPROC: - fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params) + fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc( + statement_params=statement_params, anonymous=True + ) else: - fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params) + fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc( + statement_params=statement_params, anonymous=False + ) fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE) + # Create a temp table in advance to store the output + # This would allow us to use the same table outside the stored procedure + if not drop_input_cols: + assert example_output_pd_df is not None + remove_dataset_col_name_exist_in_output_col = list(set(dataset.columns) - set(example_output_pd_df.columns)) + pd_df_one_row = ( + dataset.select(remove_dataset_col_name_exist_in_output_col) + .limit(1) + .to_pandas(statement_params=statement_params) + ) + example_output_pd_df = pd.concat([pd_df_one_row, example_output_pd_df], axis=1) + + self.session.write_pandas( + example_output_pd_df, + fit_predict_result_name, + auto_create_table=True, + table_type="temp", + ) + sproc_export_file_name: str = fit_predict_wrapper_sproc( self.session, queries, @@ -769,11 +720,13 @@ def train_fit_transform( # Call fit sproc if _ENABLE_ANONYMOUS_SPROC: - fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous( - statement_params=statement_params + fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc( + statement_params=statement_params, anonymous=True ) else: - fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params) + fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc( + statement_params=statement_params, anonymous=False + ) fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE) diff --git a/snowflake/ml/modeling/parameters/BUILD.bazel b/snowflake/ml/modeling/parameters/BUILD.bazel index 99fe5d38..2ed7f6bf 100644 --- a/snowflake/ml/modeling/parameters/BUILD.bazel +++ b/snowflake/ml/modeling/parameters/BUILD.bazel @@ -12,6 +12,20 @@ py_library( ], ) +py_test( + name = "disable_distributed_hpo_test", + srcs = [ + "disable_distributed_hpo_test.py", + ], + deps = [ + ":disable_distributed_hpo", + "//snowflake/ml/modeling/_internal:model_trainer_builder", + "//snowflake/ml/modeling/_internal/snowpark_implementations:distributed_hpo_trainer", + "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_trainer", + "//snowflake/ml/modeling/xgboost:xgb_classifier", + ], +) + py_library( name = "enable_anonymous_sproc", srcs = [ @@ -33,17 +47,24 @@ py_test( ], ) +py_library( + name = "disable_model_tracer", + srcs = [ + "disable_model_tracer.py", + ], + deps = [ + "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_trainer", + ], +) + py_test( - name = "disable_distributed_hpo_test", + name = "disable_model_tracer_test", srcs = [ - "disable_distributed_hpo_test.py", + "disable_model_tracer_test.py", ], deps = [ - ":disable_distributed_hpo", - "//snowflake/ml/modeling/_internal:model_trainer_builder", - "//snowflake/ml/modeling/_internal/snowpark_implementations:distributed_hpo_trainer", + ":disable_model_tracer", "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_trainer", - "//snowflake/ml/modeling/xgboost:xgb_classifier", ], ) @@ -52,6 +73,7 @@ py_package( packages = ["snowflake.ml"], deps = [ ":disable_distributed_hpo", + ":disable_model_tracer", ":enable_anonymous_sproc", ], ) diff --git a/snowflake/ml/modeling/parameters/disable_model_tracer.py b/snowflake/ml/modeling/parameters/disable_model_tracer.py new file mode 100644 index 00000000..93ebc48f --- /dev/null +++ b/snowflake/ml/modeling/parameters/disable_model_tracer.py @@ -0,0 +1,5 @@ +"""Disables the snowpark observability tracer when running modeling fit""" + +from snowflake.ml.modeling._internal.snowpark_implementations import snowpark_trainer + +snowpark_trainer._ENABLE_TRACER = False diff --git a/snowflake/ml/modeling/parameters/disable_model_tracer_test.py b/snowflake/ml/modeling/parameters/disable_model_tracer_test.py new file mode 100644 index 00000000..8686b95c --- /dev/null +++ b/snowflake/ml/modeling/parameters/disable_model_tracer_test.py @@ -0,0 +1,24 @@ +from unittest import mock + +from absl.testing import absltest + +from snowflake.ml.modeling._internal.snowpark_implementations import snowpark_trainer +from snowflake.snowpark import DataFrame, Session + + +class EnableAnonymousSPROC(absltest.TestCase): + def test_disable_distributed_hpo(self) -> None: + mock_session = mock.MagicMock(spec=Session) + mock_dataframe = mock.MagicMock(spec=DataFrame) + mock_dataframe._session = mock_session + + self.assertTrue(snowpark_trainer._ENABLE_TRACER) + + # Disable distributed HPO + import snowflake.ml.modeling.parameters.disable_model_tracer # noqa: F401 + + self.assertFalse(snowpark_trainer._ENABLE_TRACER) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/pipeline/BUILD.bazel b/snowflake/ml/modeling/pipeline/BUILD.bazel index b7a2de6d..85e4aae1 100644 --- a/snowflake/ml/modeling/pipeline/BUILD.bazel +++ b/snowflake/ml/modeling/pipeline/BUILD.bazel @@ -46,6 +46,9 @@ py_test( "//snowflake/ml/modeling/lightgbm:lgbm_classifier", "//snowflake/ml/modeling/linear_model:linear_regression", "//snowflake/ml/modeling/preprocessing:min_max_scaler", + "//snowflake/ml/modeling/preprocessing:one_hot_encoder", + "//snowflake/ml/modeling/preprocessing:standard_scaler", + "//snowflake/ml/modeling/xgboost:xgb_classifier", "//snowflake/ml/modeling/xgboost:xgb_regressor", ], ) diff --git a/snowflake/ml/modeling/pipeline/pipeline.py b/snowflake/ml/modeling/pipeline/pipeline.py index e4baf152..039b5e60 100644 --- a/snowflake/ml/modeling/pipeline/pipeline.py +++ b/snowflake/ml/modeling/pipeline/pipeline.py @@ -418,9 +418,6 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame], squash: Optional Returns: Fitted pipeline. - - Raises: - ValueError: A pipeline incompatible with sklearn is used on MLRS """ self._validate_steps() @@ -437,8 +434,6 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame], squash: Optional lineage_utils.set_data_sources(self, data_sources) if self._can_be_trained_in_ml_runtime(dataset): - if not self._is_convertible_to_sklearn: - raise ValueError("This pipeline cannot be converted to an sklearn pipeline.") self._fit_ml_runtime(dataset) elif squash and isinstance(dataset, snowpark.DataFrame): @@ -611,14 +606,8 @@ def predict(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[sno Returns: Output dataset. - - Raises: - ValueError: An sklearn object has not been fit and stored before calling this function. """ - if os.environ.get(IN_ML_RUNTIME_ENV_VAR): - if self._sklearn_object is None: - raise ValueError("Model must be fit before inference.") - + if os.environ.get(IN_ML_RUNTIME_ENV_VAR) and self._sklearn_object is not None: expected_output_cols = self._infer_output_cols() handler = ModelTransformerBuilder.build( dataset=dataset, diff --git a/snowflake/ml/modeling/pipeline/pipeline_test.py b/snowflake/ml/modeling/pipeline/pipeline_test.py index ade2f0f6..dddf08b6 100644 --- a/snowflake/ml/modeling/pipeline/pipeline_test.py +++ b/snowflake/ml/modeling/pipeline/pipeline_test.py @@ -1,5 +1,6 @@ import os +import numpy as np import pandas as pd from absl.testing import absltest from sklearn.compose import ColumnTransformer @@ -10,8 +11,12 @@ from snowflake.ml.modeling.lightgbm import LGBMClassifier from snowflake.ml.modeling.linear_model import LinearRegression from snowflake.ml.modeling.pipeline.pipeline import IN_ML_RUNTIME_ENV_VAR, Pipeline -from snowflake.ml.modeling.preprocessing import MinMaxScaler -from snowflake.ml.modeling.xgboost import XGBRegressor +from snowflake.ml.modeling.preprocessing import ( + MinMaxScaler, + OneHotEncoder, + StandardScaler, +) +from snowflake.ml.modeling.xgboost import XGBClassifier, XGBRegressor from snowflake.snowpark import DataFrame @@ -51,6 +56,23 @@ def setUp(self) -> None: ] ) + # Sample data for the new test + self.categorical_columns = ["AGE", "CAMPAIGN", "DEFAULT"] + self.numerical_columns = ["CONS_CONF_IDX"] + self.label_column = ["LABEL"] + + # Create a small sample dataset + self._test_data = pd.DataFrame( + { + "AGE": ["30", "40", "50"], + "CAMPAIGN": ["1", "2", "1"], + "CONS_CONF_IDX": [-42.7, -50.8, -36.1], + "DEFAULT": ["no", "yes", "no"], + "LABEL": [0, 1, 0], + } + ) + self._test_data["ROW_INDEX"] = self._test_data.index + return super().setUp() def test_dataset_can_be_trained_in_ml_runtime(self) -> None: @@ -233,6 +255,92 @@ def test_to_sklearn(self) -> None: """Tests behavior for converting the pipeline to an sklearn pipeline""" assert isinstance(self.simple_pipeline.to_sklearn(), sklearn_Pipeline) + def test_fit_and_compare_results_pandas_dataframe(self) -> None: + with absltest.mock.patch.dict(os.environ, {IN_ML_RUNTIME_ENV_VAR: ""}, clear=True): + raw_data_pandas = self._test_data + + pipeline = Pipeline( + steps=[ + ( + "OHE", + OneHotEncoder( + input_cols=self.categorical_columns, + output_cols=self.categorical_columns, + drop_input_cols=True, + ), + ), + ( + "MMS", + MinMaxScaler( + clip=True, + input_cols=self.numerical_columns, + output_cols=self.numerical_columns, + ), + ), + ("regression", XGBClassifier(label_cols=self.label_column)), + ] + ) + + pipeline.fit(raw_data_pandas) + pipeline.predict(raw_data_pandas) + + def test_pipeline_export(self): + raw_data_pandas = self._test_data + + # Simulate the creation of the pipeline + pipeline = Pipeline( + steps=[ + ( + "OHE", + OneHotEncoder( + input_cols=self.categorical_columns, output_cols=self.categorical_columns, drop_input_cols=True + ), + ), + ("MMS", MinMaxScaler(clip=True, input_cols=self.numerical_columns, output_cols=self.numerical_columns)), + ("SS", StandardScaler(input_cols=self.numerical_columns, output_cols=self.numerical_columns)), + ("regression", XGBClassifier(label_cols=self.label_column, passthrough_cols="ROW_INDEX")), + ] + ) + + pipeline.fit(raw_data_pandas) + snow_results = pipeline.predict(raw_data_pandas).sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() + + # Create a similar scikit-learn pipeline for comparison + sk_pipeline = pipeline.to_sklearn() + sk_results = sk_pipeline.predict(raw_data_pandas.drop(columns=["LABEL"])) + + # Assert the results are close + np.testing.assert_allclose(snow_results, sk_results, rtol=1.0e-1, atol=1.0e-2) + + def test_pipeline_with_limited_number_of_columns_in_estimator_export(self) -> None: + raw_data_pandas = self._test_data + snow_raw_data_pandas = raw_data_pandas.drop("DEFAULT", axis=1) + + pipeline = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler( + clip=True, + input_cols=self.numerical_columns, + output_cols=self.numerical_columns, + ), + ), + ( + "SS", + StandardScaler(input_cols=(self.numerical_columns[0:2]), output_cols=(self.numerical_columns[0:2])), + ), + ("regression", XGBClassifier(input_cols=self.numerical_columns, label_cols=self.label_column)), + ] + ) + + pipeline.fit(snow_raw_data_pandas) + snow_results = pipeline.predict(snow_raw_data_pandas).sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() + + sk_pipeline = pipeline.to_sklearn() + sk_results = sk_pipeline.predict(raw_data_pandas.drop(columns=["LABEL"])) + np.testing.assert_allclose(snow_results, sk_results, rtol=1.0e-1, atol=1.0e-2) + def tearDown(self) -> None: os.environ.pop(IN_ML_RUNTIME_ENV_VAR, None) self.send_custom_usage_mock.stop() diff --git a/snowflake/ml/registry/_manager/model_manager.py b/snowflake/ml/registry/_manager/model_manager.py index eb347ed2..ae34c95c 100644 --- a/snowflake/ml/registry/_manager/model_manager.py +++ b/snowflake/ml/registry/_manager/model_manager.py @@ -50,6 +50,7 @@ def log_model( sample_input_data: Optional[model_types.SupportedDataType] = None, code_paths: Optional[List[str]] = None, ext_modules: Optional[List[ModuleType]] = None, + model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, options: Optional[model_types.ModelSaveOption] = None, statement_params: Optional[Dict[str, Any]] = None, ) -> model_version_impl.ModelVersion: @@ -89,6 +90,7 @@ def log_model( sample_input_data=sample_input_data, code_paths=code_paths, ext_modules=ext_modules, + model_objective=model_objective, options=options, statement_params=statement_params, ) @@ -108,6 +110,7 @@ def _log_model( sample_input_data: Optional[model_types.SupportedDataType] = None, code_paths: Optional[List[str]] = None, ext_modules: Optional[List[ModuleType]] = None, + model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN, options: Optional[model_types.ModelSaveOption] = None, statement_params: Optional[Dict[str, Any]] = None, ) -> model_version_impl.ModelVersion: @@ -156,6 +159,7 @@ def _log_model( code_paths=code_paths, ext_modules=ext_modules, options=options, + model_objective=model_objective, ) statement_params = telemetry.add_statement_params_custom_tags( statement_params, model_metadata.telemetry_metadata() diff --git a/snowflake/ml/registry/_manager/model_manager_test.py b/snowflake/ml/registry/_manager/model_manager_test.py index 6b988284..2a29985e 100644 --- a/snowflake/ml/registry/_manager/model_manager_test.py +++ b/snowflake/ml/registry/_manager/model_manager_test.py @@ -6,6 +6,7 @@ from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import type_hints from snowflake.ml.model._client.model import model_impl, model_version_impl from snowflake.ml.model._client.ops import service_ops from snowflake.ml.model._client.ops.model_ops import ModelOperator @@ -211,6 +212,7 @@ def test_log_model_minimal(self) -> None: code_paths=None, ext_modules=None, options=None, + model_objective=type_hints.ModelObjective.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -279,6 +281,7 @@ def test_log_model_1(self) -> None: code_paths=None, ext_modules=None, options=None, + model_objective=type_hints.ModelObjective.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -332,6 +335,7 @@ def test_log_model_2(self) -> None: code_paths=None, ext_modules=None, options=m_options, + model_objective=type_hints.ModelObjective.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -388,6 +392,7 @@ def test_log_model_3(self) -> None: code_paths=m_code_paths, ext_modules=m_ext_modules, options=None, + model_objective=type_hints.ModelObjective.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -444,6 +449,7 @@ def test_log_model_4(self) -> None: code_paths=None, ext_modules=None, options=None, + model_objective=type_hints.ModelObjective.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, @@ -544,6 +550,7 @@ def test_log_model_fully_qualified(self) -> None: code_paths=None, ext_modules=None, options=None, + model_objective=type_hints.ModelObjective.UNKNOWN, ) mock_create_from_stage.assert_called_once_with( composed_model=mock.ANY, diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py index 3e6a9862..e5b04df8 100644 --- a/snowflake/ml/registry/model_registry.py +++ b/snowflake/ml/registry/model_registry.py @@ -576,7 +576,7 @@ def _get_fully_qualified_stage_name_from_uri(self, model_uri: str) -> Optional[s raw_stage_path = uri.get_snowflake_stage_path_from_uri(model_uri) if not raw_stage_path: return None - (db, schema, stage, _) = identifier.parse_schema_level_object_identifier(raw_stage_path) + (db, schema, stage, _) = identifier.parse_snowflake_stage_path(raw_stage_path) return identifier.get_schema_level_object_identifier(db, schema, stage) def _list_selected_models( diff --git a/snowflake/ml/registry/registry.py b/snowflake/ml/registry/registry.py index 710e603c..ae36e430 100644 --- a/snowflake/ml/registry/registry.py +++ b/snowflake/ml/registry/registry.py @@ -244,8 +244,7 @@ def log_model( warnings.warn( "Models logged specifying `pip_requirements` can not be executed " "in Snowflake Warehouse where all dependencies are required to be retrieved " - "from Snowflake Anaconda Channel. Specify model save option `include_pip_dependencies`" - "to log model with pip dependencies.", + "from Snowflake Anaconda Channel.", category=UserWarning, stacklevel=1, ) diff --git a/snowflake/ml/registry/registry_test.py b/snowflake/ml/registry/registry_test.py index e78c1f73..f562e1d8 100644 --- a/snowflake/ml/registry/registry_test.py +++ b/snowflake/ml/registry/registry_test.py @@ -129,6 +129,7 @@ def test_log_model(self) -> None: comment=m_comment, metrics=m_metrics, conda_dependencies=m_conda_dependency, + pip_requirements=None, python_version=m_python_version, signatures=m_signatures, sample_input_data=m_sample_input_data, diff --git a/snowflake/ml/test_utils/mock_session.py b/snowflake/ml/test_utils/mock_session.py index c955a8d4..fe6bbcef 100644 --- a/snowflake/ml/test_utils/mock_session.py +++ b/snowflake/ml/test_utils/mock_session.py @@ -3,6 +3,7 @@ from typing import Any, Type from unittest import TestCase +from snowflake import snowpark from snowflake.ml._internal.utils.string_matcher import StringMatcherSql from snowflake.ml.test_utils import mock_data_frame, mock_snowml_base from snowflake.snowpark import Session @@ -94,3 +95,18 @@ def sql(self, query: str) -> Any: check_kwargs=True, ) return mo.result + + def add_mock_query_history(self, result: snowpark.QueryHistory) -> mock_snowml_base.MockSnowMLBase: + """Add an expected query history to the session.""" + return self.add_operation(operation="query_history", args=(), kwargs={}, result=result) + + def query_history(self) -> Any: + """Execute a mock query_history call.""" + mo = self._check_operation( + operation="query_history", + args=(), + kwargs={}, + check_args=False, + check_kwargs=False, + ) + return mo.result diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index 442fac15..d69b564f 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. -VERSION = "1.6.1" +VERSION = "1.6.2" diff --git a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel index 9470ef8c..21b2237a 100644 --- a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel +++ b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel @@ -85,12 +85,12 @@ py_test( data = ["//tests/integ/snowflake/ml/test_data:UCI_BANK_MARKETING_20COLUMNS.csv"], shard_count = 4, deps = [ + "//snowflake/ml/modeling/compose:column_transformer", "//snowflake/ml/modeling/framework", "//snowflake/ml/modeling/impute:knn_imputer", "//snowflake/ml/modeling/pipeline", "//snowflake/ml/modeling/preprocessing:min_max_scaler", "//snowflake/ml/modeling/preprocessing:one_hot_encoder", - "//snowflake/ml/modeling/preprocessing:standard_scaler", "//snowflake/ml/modeling/xgboost:xgb_classifier", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/test_utils:test_env_utils", @@ -159,6 +159,7 @@ py_test( py_test( name = "xgboost_external_memory_training_test", srcs = ["xgboost_external_memory_training_test.py"], + data = ["//tests/integ/snowflake/ml/test_data:UCI_BANK_MARKETING_20COLUMNS.csv"], deps = [ "//snowflake/ml/modeling/metrics:classification", "//snowflake/ml/modeling/xgboost:xgb_classifier", @@ -176,3 +177,16 @@ py_test( "//snowflake/ml/utils:connection_params", ], ) + +py_test( + name = "sample_weight_col_test", + srcs = ["sample_weight_col_test.py"], + data = ["//tests/integ/snowflake/ml/test_data:UCI_BANK_MARKETING_20COLUMNS.csv"], + deps = [ + "//snowflake/ml/modeling/framework", + "//snowflake/ml/modeling/model_selection:grid_search_cv", + "//snowflake/ml/modeling/xgboost:xgb_classifier", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) diff --git a/tests/integ/snowflake/ml/extra_tests/batch_inference_with_nan_data_test.py b/tests/integ/snowflake/ml/extra_tests/batch_inference_with_nan_data_test.py index d286e615..3dd623ee 100644 --- a/tests/integ/snowflake/ml/extra_tests/batch_inference_with_nan_data_test.py +++ b/tests/integ/snowflake/ml/extra_tests/batch_inference_with_nan_data_test.py @@ -46,12 +46,13 @@ def test_nan_data(self) -> None: input_df_pandas[input_cols], input_df_pandas[label_cols], ) - training_predictions = ( - classifier.predict_proba(input_df_sp) - .to_pandas() - .sort_values(by="INDEX")[['"PREDICT_PROBA_0.0"', '"PREDICT_PROBA_1.0"']] - .to_numpy() - ) + training_predictions = classifier.predict_proba(input_df_sp).to_pandas().sort_values(by="INDEX") + PREDICT_PROBA_COLS = [] + for c in training_predictions.columns: + if "PREDICT_PROBA_" in c: + PREDICT_PROBA_COLS.append(c) + + training_predictions = training_predictions[PREDICT_PROBA_COLS].to_numpy() native_predictions = native_classifier.predict_proba(input_df_pandas[input_cols]) np.testing.assert_allclose( training_predictions.flatten(), native_predictions.flatten(), rtol=1.0e-1, atol=1.0e-2 diff --git a/tests/integ/snowflake/ml/extra_tests/column_name_inference_test.py b/tests/integ/snowflake/ml/extra_tests/column_name_inference_test.py index 934d6d0f..52862e1a 100644 --- a/tests/integ/snowflake/ml/extra_tests/column_name_inference_test.py +++ b/tests/integ/snowflake/ml/extra_tests/column_name_inference_test.py @@ -53,7 +53,7 @@ def _test_column_name_inference( sklearn_results = sklearn_reg.predict(input_df_pandas[input_cols]) np.testing.assert_array_equal(reg.get_input_cols(), input_cols) - np.testing.assert_allclose(actual_results.flatten(), sklearn_results.flatten()) + np.testing.assert_allclose(actual_results.flatten(), sklearn_results.flatten(), rtol=1.0e-3, atol=1.0e-3) def test_snowpark_interface_with_passthrough_cols(self): self._test_column_name_inference(use_snowpark_interface=True, use_passthrough_cols=True) diff --git a/tests/integ/snowflake/ml/extra_tests/decimal_type_test.py b/tests/integ/snowflake/ml/extra_tests/decimal_type_test.py index 7b4f8959..8324923b 100644 --- a/tests/integ/snowflake/ml/extra_tests/decimal_type_test.py +++ b/tests/integ/snowflake/ml/extra_tests/decimal_type_test.py @@ -49,7 +49,7 @@ def test_decimal_type(self) -> None: actual_results = reg.predict(input_df_pandas)[reg.get_output_cols()].to_numpy() sklearn_results = sklearn_reg.predict(input_df_pandas[input_cols]) - np.testing.assert_allclose(actual_results.flatten(), sklearn_results.flatten()) + np.testing.assert_allclose(actual_results.flatten(), sklearn_results.flatten(), rtol=1.0e-3, atol=1.0e-3) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/extra_tests/multi_label_column_name_test.py b/tests/integ/snowflake/ml/extra_tests/multi_label_column_name_test.py index e9704401..c9045331 100644 --- a/tests/integ/snowflake/ml/extra_tests/multi_label_column_name_test.py +++ b/tests/integ/snowflake/ml/extra_tests/multi_label_column_name_test.py @@ -1,4 +1,7 @@ +import os + import numpy as np +import pytest from absl.testing.absltest import TestCase, main from sklearn.datasets import make_multilabel_classification from sklearn.ensemble import RandomForestClassifier as SkRandomForestClassifier @@ -84,6 +87,13 @@ def test_random_forest_regressor_with_five_label_cols(self): training_predictions_log_proba.flatten(), concatenated_array_log_proba.flatten(), rtol=1.0e-1, atol=1.0e-2 ) + @pytest.mark.skipif( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + reason=( + "Skipping test, xgboost_ray doesn't support multi-output" + "See: https://github.com/ray-project/xgboost_ray/issues/286" + ), + ) def test_xgb_regressor_with_five_label_cols(self): snf_df = self._session.create_dataframe(self.mult_cl_data.tolist(), schema=self.feature_cols + self.target_cols) snf_df.write.save_as_table("multi_target_cl", mode="overwrite") diff --git a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py index 7a3a9244..33dd0d60 100644 --- a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py +++ b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py @@ -1,5 +1,8 @@ +import os + import numpy as np import pandas as pd +import pytest from absl.testing import absltest from importlib_resources import files from sklearn.compose import ColumnTransformer as SkColumnTransformer @@ -11,13 +14,10 @@ ) from xgboost import XGBClassifier as XGB_XGBClassifier +from snowflake.ml.modeling.compose import ColumnTransformer from snowflake.ml.modeling.impute import KNNImputer from snowflake.ml.modeling.pipeline import Pipeline -from snowflake.ml.modeling.preprocessing import ( - MinMaxScaler, - OneHotEncoder, - StandardScaler, -) +from snowflake.ml.modeling.preprocessing import MinMaxScaler, OneHotEncoder from snowflake.ml.modeling.xgboost import XGBClassifier from snowflake.ml.utils.connection_params import SnowflakeLoginOptions from snowflake.snowpark import Session @@ -48,6 +48,7 @@ "PREVIOUS", ] label_column = ["LABEL"] +IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME" feature_cols = categorical_columns + numerical_columns @@ -61,32 +62,65 @@ def setUp(self): def tearDown(self): self._session.close() + def _get_preprocessor(self, categorical_columns, numerical_columns, use_knn_imputer=True): + """Helper method to create the ColumnTransformer for preprocessing.""" + transformers = [ + ("ohe", OneHotEncoder(drop_input_cols=True), categorical_columns), + ("mms", MinMaxScaler(clip=True), numerical_columns), + ] + + if use_knn_imputer: + transformers.append(("knn_imputer", KNNImputer(), numerical_columns)) + + return ColumnTransformer( + transformers=transformers, + remainder="passthrough", # Ensures columns not specified are passed through without transformation + ) + + def _get_pipeline(self, categorical_columns, numerical_columns, label_column, use_knn_imputer=True): + """Helper method to create the Pipeline with the appropriate preprocessor and XGBClassifier.""" + + # Check if the environment variable is set to True + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + # Create the preprocessor using the helper method + preprocessor = self._get_preprocessor(categorical_columns, numerical_columns, use_knn_imputer) + + # Create and return the pipeline with the preprocessor + return Pipeline( + steps=[ + ("preprocessor", preprocessor), + ("regression", XGBClassifier(label_cols=label_column, passthrough_cols="ROW_INDEX")), + ] + ) + + # When the environment variable is not set + steps = [ + ( + "OHE", + OneHotEncoder(input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True), + ), + ( + "MMS", + MinMaxScaler( + clip=True, + input_cols=numerical_columns, + output_cols=numerical_columns, + ), + ), + ("regression", XGBClassifier(label_cols=label_column, passthrough_cols="ROW_INDEX")), + ] + + if use_knn_imputer: + steps.insert(2, ("KNNImputer", KNNImputer(input_cols=numerical_columns, output_cols=numerical_columns))) + + return Pipeline(steps=steps) + def test_fit_and_compare_results(self) -> None: pd_data = self._test_data pd_data["ROW_INDEX"] = pd_data.reset_index().index raw_data = self._session.create_dataframe(pd_data) - pipeline = Pipeline( - steps=[ - ( - "OHE", - OneHotEncoder( - input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True - ), - ), - ( - "MMS", - MinMaxScaler( - clip=True, - input_cols=numerical_columns, - output_cols=numerical_columns, - ), - ), - ("KNNImputer", KNNImputer(input_cols=numerical_columns, output_cols=numerical_columns)), - ("regression", XGBClassifier(label_cols=label_column, passthrough_cols="ROW_INDEX")), - ] - ) - + pipeline = self._get_pipeline(categorical_columns, numerical_columns, label_column) pipeline.fit(raw_data) results = pipeline.predict(raw_data).to_pandas().sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() @@ -115,30 +149,10 @@ def test_fit_predict_proba_and_compare_results(self) -> None: pd_data["ROW_INDEX"] = pd_data.reset_index().index raw_data = self._session.create_dataframe(pd_data) - pipeline = Pipeline( - steps=[ - ( - "OHE", - OneHotEncoder( - input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True - ), - ), - ( - "MMS", - MinMaxScaler( - clip=True, - input_cols=numerical_columns, - output_cols=numerical_columns, - ), - ), - ("KNNImputer", KNNImputer(input_cols=numerical_columns, output_cols=numerical_columns)), - ("regression", XGBClassifier(label_cols=label_column, passthrough_cols="ROW_INDEX")), - ] - ) - + pipeline = self._get_pipeline(categorical_columns, numerical_columns, label_column) pipeline.fit(raw_data) results = pipeline.predict_proba(raw_data).to_pandas().sort_values(by=["ROW_INDEX"]) - proba_cols = [c for c in results.columns if c.startswith("PREDICT_PROBA_")] + proba_cols = [c for c in results.columns if c.startswith("PREDICT_PROBA")] proba_results = results[proba_cols].to_numpy() sk_pipeline = SkPipeline( @@ -161,152 +175,29 @@ def test_fit_predict_proba_and_compare_results(self) -> None: np.testing.assert_allclose(proba_results.flatten(), sk_proba_results.flatten(), rtol=1.0e-1, atol=1.0e-2) - def test_fit_and_compare_results_pandas_dataframe(self) -> None: - raw_data_pandas = self._test_data - - pipeline = Pipeline( - steps=[ - ( - "OHE", - OneHotEncoder( - input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True - ), - ), - ( - "MMS", - MinMaxScaler( - clip=True, - input_cols=numerical_columns, - output_cols=numerical_columns, - ), - ), - ("regression", XGBClassifier(label_cols=label_column)), - ] - ) - - pipeline.fit(raw_data_pandas) - pipeline.predict(raw_data_pandas) - + @pytest.mark.skipif( + os.getenv("IN_SPCS_ML_RUNTIME") == "True", + reason=( + "Skipping this test, as we go ahead with this PR" + "See: https://github.com/snowflakedb/snowml/pull/2651/files" + ), + ) def test_fit_and_compare_results_pandas(self) -> None: pd_data = self._test_data + pd_data["ROW_INDEX"] = pd_data.reset_index().index raw_data = self._session.create_dataframe(pd_data) - pipeline = Pipeline( - steps=[ - ( - "OHE", - OneHotEncoder( - input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True - ), - ), - ( - "MMS", - MinMaxScaler( - clip=True, - input_cols=numerical_columns, - output_cols=numerical_columns, - ), - ), - ("regression", XGBClassifier(label_cols=label_column)), - ] - ) + pipeline = self._get_pipeline(categorical_columns, numerical_columns, label_column, use_knn_imputer=False) pipeline.fit(raw_data) pipeline.predict(raw_data.to_pandas()) - def test_pipeline_export(self) -> None: - pd_data = self._test_data - pd_data["ROW_INDEX"] = pd_data.reset_index().index - snow_df = self._session.create_dataframe(pd_data) - pd_df = pd_data.drop("LABEL", axis=1) - - pipeline = Pipeline( - steps=[ - ( - "OHE", - OneHotEncoder( - input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True - ), - ), - ( - "MMS", - MinMaxScaler( - clip=True, - input_cols=numerical_columns, - output_cols=numerical_columns, - ), - ), - ( - "SS", - StandardScaler(input_cols=(numerical_columns[0:2]), output_cols=(numerical_columns[0:2])), - ), - ("regression", XGBClassifier(label_cols=label_column, passthrough_cols="ROW_INDEX")), - ] - ) - - pipeline.fit(snow_df) - snow_results = pipeline.predict(snow_df).to_pandas().sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() - - sk_pipeline = pipeline.to_sklearn() - sk_results = sk_pipeline.predict(pd_df) - np.testing.assert_allclose(snow_results.flatten(), sk_results.flatten(), rtol=1.0e-1, atol=1.0e-2) - - def test_pipeline_with_limited_number_of_columns_in_estimator_export(self) -> None: - pd_data = self._test_data - pd_data["ROW_INDEX"] = pd_data.reset_index().index - snow_df = self._session.create_dataframe(pd_data.drop("DEFAULT", axis=1)) - pd_df = pd_data.drop("LABEL", axis=1) - - pipeline = Pipeline( - steps=[ - ( - "MMS", - MinMaxScaler( - clip=True, - input_cols=numerical_columns, - output_cols=numerical_columns, - ), - ), - ( - "SS", - StandardScaler(input_cols=(numerical_columns[0:2]), output_cols=(numerical_columns[0:2])), - ), - ("regression", XGBClassifier(input_cols=numerical_columns, label_cols=label_column)), - ] - ) - - pipeline.fit(snow_df) - snow_results = pipeline.predict(snow_df).to_pandas().sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() - - sk_pipeline = pipeline.to_sklearn() - sk_results = sk_pipeline.predict(pd_df) - np.testing.assert_allclose(snow_results.flatten(), sk_results.flatten(), rtol=1.0e-1, atol=1.0e-2) - def test_pipeline_squash(self) -> None: pd_data = self._test_data pd_data["ROW_INDEX"] = pd_data.reset_index().index raw_data = self._session.create_dataframe(pd_data) - pipeline = Pipeline( - steps=[ - ( - "OHE", - OneHotEncoder( - input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True - ), - ), - ( - "MMS", - MinMaxScaler( - clip=True, - input_cols=numerical_columns, - output_cols=numerical_columns, - ), - ), - ("KNNImputer", KNNImputer(input_cols=numerical_columns, output_cols=numerical_columns)), - ("regression", XGBClassifier(label_cols=label_column, passthrough_cols="ROW_INDEX")), - ] - ) + pipeline = self._get_pipeline(categorical_columns, numerical_columns, label_column) pipeline._deps.append( test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") diff --git a/tests/integ/snowflake/ml/extra_tests/sample_weight_col_test.py b/tests/integ/snowflake/ml/extra_tests/sample_weight_col_test.py new file mode 100644 index 00000000..66f97162 --- /dev/null +++ b/tests/integ/snowflake/ml/extra_tests/sample_weight_col_test.py @@ -0,0 +1,104 @@ +import random + +import numpy as np +import pandas as pd +from absl.testing import absltest +from importlib_resources import files +from sklearn.model_selection import GridSearchCV as SkGridSearchCV +from xgboost import XGBClassifier as XGB_XGBClassifier + +from snowflake.ml.modeling.model_selection import GridSearchCV +from snowflake.ml.modeling.xgboost import XGBClassifier +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +from snowflake.snowpark import Session + +numerical_columns = [ + "CONS_CONF_IDX", + "CONS_PRICE_IDX", + "DURATION", + "EMP_VAR_RATE", + "EURIBOR3M", + "NR_EMPLOYED", + "PDAYS", + "PREVIOUS", +] +label_column = ["LABEL"] +feature_cols = numerical_columns + + +class XGBSampleWeightTest(absltest.TestCase): + def setUp(self): + """Creates Snowpark and Snowflake environments for testing.""" + self._session = Session.builder.configs(SnowflakeLoginOptions()).create() + data_file = files("tests.integ.snowflake.ml.test_data").joinpath("UCI_BANK_MARKETING_20COLUMNS.csv") + self._test_data = pd.read_csv(data_file, index_col=0) + + def tearDown(self): + self._session.close() + + def test_fit_and_compare_results(self) -> None: + pd_data = self._test_data + pd_data["ROW_INDEX"] = pd_data.reset_index().index + sample_weight_col = "SAMPLE_WEIGHT" + pd_data[sample_weight_col] = np.array([random.randint(0, 100) for _ in range(pd_data.shape[0])]) + + snowml_classifier = XGBClassifier( + input_cols=feature_cols, + label_cols=label_column, + passthrough_cols="ROW_INDEX", + sample_weight_col=sample_weight_col, + ) + xgb_classifier = XGB_XGBClassifier() + + xgb_classifier.fit(pd_data[feature_cols], pd_data[label_column], sample_weight=pd_data[sample_weight_col]) + predictions = xgb_classifier.predict(pd_data[feature_cols]) + + raw_data = self._session.create_dataframe(pd_data) + snowml_classifier.fit(raw_data) + snowml_predictions = ( + snowml_classifier.predict(raw_data).to_pandas().sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() + ) + + np.testing.assert_allclose(predictions.flatten(), snowml_predictions.flatten(), rtol=1.0e-3, atol=1.0e-3) + + def test_grid_search_on_xgboost_sample_weight(self) -> None: + pd_data = self._test_data + pd_data["ROW_INDEX"] = pd_data.reset_index().index + sample_weight_col = "SAMPLE_WEIGHT" + pd_data[sample_weight_col] = np.array([random.randint(0, 100) for _ in range(pd_data.shape[0])]) + + snowml_classifier = XGBClassifier( + input_cols=feature_cols, + label_cols=label_column, + passthrough_cols="ROW_INDEX", + ) + xgb_classifier = XGB_XGBClassifier() + + param_grid = { + "max_depth": [80, 100], + } + + grid_search = GridSearchCV( + param_grid=param_grid, + estimator=snowml_classifier, + input_cols=feature_cols, + label_cols=label_column, + passthrough_cols="ROW_INDEX", + sample_weight_col=sample_weight_col, + ) + sk_grid_search = SkGridSearchCV(param_grid=param_grid, estimator=xgb_classifier) + + sk_grid_search.fit(pd_data[feature_cols], pd_data[label_column], sample_weight=pd_data[sample_weight_col]) + predictions = sk_grid_search.predict(pd_data[feature_cols]) + + raw_data = self._session.create_dataframe(pd_data) + grid_search.fit(raw_data) + snowml_predictions = ( + grid_search.predict(raw_data).to_pandas().sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() + ) + + np.testing.assert_allclose(predictions.flatten(), snowml_predictions.flatten(), rtol=1.0e-3, atol=1.0e-3) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/extra_tests/xgboost_external_memory_training_test.py b/tests/integ/snowflake/ml/extra_tests/xgboost_external_memory_training_test.py index 1ab9100d..87517078 100644 --- a/tests/integ/snowflake/ml/extra_tests/xgboost_external_memory_training_test.py +++ b/tests/integ/snowflake/ml/extra_tests/xgboost_external_memory_training_test.py @@ -1,11 +1,13 @@ import numpy as np +import pandas as pd from absl.testing.absltest import TestCase, main +from importlib_resources import files from sklearn.metrics import accuracy_score as sk_accuracy_score from xgboost import XGBClassifier as NativeXGBClassifier from snowflake.ml.modeling.xgboost import XGBClassifier from snowflake.ml.utils.connection_params import SnowflakeLoginOptions -from snowflake.snowpark import Session, functions as F +from snowflake.snowpark import Session categorical_columns = [ "AGE", @@ -39,28 +41,29 @@ class XGBoostExternalMemoryTrainingTest(TestCase): def setUp(self): """Creates Snowpark and Snowflake environments for testing.""" self._session = Session.builder.configs(SnowflakeLoginOptions()).create() + data_file = files("tests.integ.snowflake.ml.test_data").joinpath("UCI_BANK_MARKETING_20COLUMNS.csv") + self._test_data = pd.read_csv(data_file, index_col=0) def tearDown(self): self._session.close() def test_fit_and_compare_results(self) -> None: - input_df = ( - self._session.sql( - """SELECT *, IFF(Y = 'yes', 1.0, 0.0) as LABEL - FROM ML_DATASETS.PUBLIC.UCI_BANK_MARKETING_20COLUMNS""" - ) - .drop("Y") - .withColumn("ROW_INDEX", F.monotonically_increasing_id()) - ) - pd_df = input_df.to_pandas().sort_values(by=["ROW_INDEX"])[numerical_columns + ["ROW_INDEX", "LABEL"]] - sp_df = self._session.create_dataframe(pd_df) + pd_data = self._test_data + pd_data["ROW_INDEX"] = pd_data.reset_index().index + + # Create the Snowpark DataFrame from pandas DataFrame + sp_df = self._session.create_dataframe(pd_data) + # Prepare the data for oss + pd_df = pd_data[numerical_columns + ["ROW_INDEX", "LABEL"]].sort_values(by=["ROW_INDEX"]) + + # Train oss model sk_reg = NativeXGBClassifier(random_state=0) sk_reg.fit(pd_df[numerical_columns], pd_df["LABEL"]) sk_result = sk_reg.predict(pd_df[numerical_columns]) - sk_accuracy = sk_accuracy_score(pd_df["LABEL"], sk_result) + # Train Snowflake model reg = XGBClassifier( random_state=0, input_cols=numerical_columns, diff --git a/tests/integ/snowflake/ml/feature_store/common_utils.py b/tests/integ/snowflake/ml/feature_store/common_utils.py index f7892638..debe8901 100644 --- a/tests/integ/snowflake/ml/feature_store/common_utils.py +++ b/tests/integ/snowflake/ml/feature_store/common_utils.py @@ -119,12 +119,14 @@ def is_object_expired(row: Row) -> bool: result = session.sql(f"SHOW SCHEMAS IN DATABASE {db}").collect() permanent_schemas = to_sql_identifiers(["INFORMATION_SCHEMA", "PUBLIC", FS_INTEG_TEST_DATASET_SCHEMA]) for row in result: - if SqlIdentifier(row["name"]) not in permanent_schemas and is_object_expired(row): - session.sql(f"DROP SCHEMA IF EXISTS {db}.{row['name']}").collect() + schema = SqlIdentifier(row["name"], case_sensitive=True) + if schema.resolved() not in permanent_schemas and is_object_expired(row): + session.sql(f"DROP SCHEMA IF EXISTS {db}.{schema.identifier()}").collect() full_schema_path = f"{FS_INTEG_TEST_DB}.{FS_INTEG_TEST_DATASET_SCHEMA}" result = session.sql(f"SHOW TABLES IN {full_schema_path}").collect() permanent_tables = to_sql_identifiers([FS_INTEG_TEST_YELLOW_TRIP_DATA, FS_INTEG_TEST_WINE_QUALITY_DATA]) for row in result: - if SqlIdentifier(row["name"]) not in permanent_tables and is_object_expired(row): - session.sql(f"DROP TABLE IF EXISTS {full_schema_path}.{row['name']}").collect() + table = SqlIdentifier(row["name"], case_sensitive=True) + if table.resolved() not in permanent_tables and is_object_expired(row): + session.sql(f"DROP TABLE IF EXISTS {full_schema_path}.{table.identifier()}").collect() diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_test.py index fc55f1a6..0927ea69 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_test.py @@ -1987,7 +1987,7 @@ def test_update_static_feature_view(self) -> None: ): fs.update_feature_view("fv1", "v1", warehouse=self._session.get_current_warehouse()) - updated_fv = fs.update_feature_view("fv1", "v1", desc="") + updated_fv = fs.update_feature_view(fv, desc="") self.assertEqual(updated_fv.desc, "") def test_update_managed_feature_view(self) -> None: diff --git a/tests/integ/snowflake/ml/model/warehouse_huggingface_pipeline_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_huggingface_pipeline_model_integ_test.py index 942d3379..2242bd59 100644 --- a/tests/integ/snowflake/ml/model/warehouse_huggingface_pipeline_model_integ_test.py +++ b/tests/integ/snowflake/ml/model/warehouse_huggingface_pipeline_model_integ_test.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd from absl.testing import absltest, parameterized -from packaging import requirements +from packaging import requirements, version from snowflake.ml._internal import env_utils from snowflake.ml.model import type_hints as model_types @@ -92,6 +92,9 @@ def test_conversational_pipeline( # Only by doing so can we make the cache dir setting effective. import transformers + if version.parse(transformers.__version__) >= version.parse("4.42.0"): + self.skipTest("This test is not compatible with transformers>=4.42.0") + model = transformers.pipeline(task="conversational", model="ToddGoldfarb/Cadet-Tiny") x_df = pd.DataFrame( diff --git a/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel b/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel index e1954e23..107898dc 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel +++ b/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel @@ -146,6 +146,7 @@ py_test( srcs = ["mean_absolute_error_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:regression", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -158,6 +159,7 @@ py_test( srcs = ["mean_absolute_percentage_error_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:regression", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -170,6 +172,7 @@ py_test( srcs = ["mean_squared_error_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:regression", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -194,6 +197,7 @@ py_test( srcs = ["precision_recall_curve_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:ranking", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", diff --git a/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_error_test.py b/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_error_test.py index 209ee2ee..0d7f9018 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_error_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_error_test.py @@ -1,7 +1,8 @@ -from typing import Any, Dict +from typing import Optional, Union from unittest import mock import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import metrics as sklearn_metrics @@ -10,35 +11,29 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] _Y_PRED_COLS = [_SF_SCHEMA[3], _SF_SCHEMA[4]] _SAMPLE_WEIGHT_COL = _SF_SCHEMA[5] -_MULTILABEL_DATA = [ - [1, 0, 1, 0.8, 0.3, 0.6], - [0, 1, 0, 0.2, 0.7, 0.4], - [1, 1, 0, 0.9, 0.6, 0.2], - [0, 0, 1, 0.1, 0.4, 0.8], -] -_MULTILABEL_SCHEMA = ["Y_0", "Y_1", "Y_2", "S_0", "S_1", "S_2"] -_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[0], _MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2]] -_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[3], _MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5]] + +_MULTILABEL_TYPES = [utils.DataType.INTEGER] * 3 + [utils.DataType.FLOAT] * 3 +_MULTILABEL_LOW, _MULTILABEL_HIGH = 0, [2, 2, 2, 1, 1, 1] +_MULTILABEL_DATA_LIST, _MULTILABEL_SCHEMA = generator.gen_test_cases( + _MULTILABEL_TYPES, _MULTILABEL_LOW, _MULTILABEL_HIGH +) +_REGULAR_MULTILABEL_DATA_LIST, _LARGE_MULTILABEL_DATA = _MULTILABEL_DATA_LIST[:-1], _MULTILABEL_DATA_LIST[-1] +_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2], _MULTILABEL_SCHEMA[3]] +_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5], _MULTILABEL_SCHEMA[6]] class MeanAbsoluteErrorTest(parameterized.TestCase): @@ -51,61 +46,52 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.mean_absolute_error( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.mean_absolute_error( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + multioutput=["raw_values", "uniform_average", [0.2, 1.0, 1.66]], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_loss = snowml_metrics.mean_absolute_error( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_loss = sklearn_metrics.mean_absolute_error( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"multioutput": ["raw_values", "uniform_average", [0.2, 1.0, 1.66]]}}, + def test_multioutput(self, data_index: int, multioutput: Union[str, npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) + + actual_loss = snowml_metrics.mean_absolute_error( + df=input_df, + y_true_col_names=_MULTILABEL_Y_TRUE_COLS, + y_pred_col_names=_MULTILABEL_Y_PRED_COLS, + multioutput=multioutput, + ) + sklearn_loss = sklearn_metrics.mean_absolute_error( + pandas_df[_MULTILABEL_Y_TRUE_COLS], + pandas_df[_MULTILABEL_Y_PRED_COLS], + multioutput=multioutput, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), ) - def test_multioutput(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) - - for multioutput in params["multioutput"]: - actual_loss = snowml_metrics.mean_absolute_error( - df=input_df, - y_true_col_names=_MULTILABEL_Y_TRUE_COLS, - y_pred_col_names=_MULTILABEL_Y_PRED_COLS, - multioutput=multioutput, - ) - sklearn_loss = sklearn_metrics.mean_absolute_error( - pandas_df[_MULTILABEL_Y_TRUE_COLS], - pandas_df[_MULTILABEL_Y_PRED_COLS], - multioutput=multioutput, - ) - np.testing.assert_allclose(actual_loss, sklearn_loss) - - def test_multilabel(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) + def test_multilabel(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) actual_loss = snowml_metrics.mean_absolute_error( df=input_df, @@ -116,11 +102,14 @@ def test_multilabel(self) -> None: pandas_df[_MULTILABEL_Y_TRUE_COLS], pandas_df[_MULTILABEL_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + ) @mock.patch("snowflake.ml.modeling.metrics.regression.result._RESULT_SIZE_THRESHOLD", 0) - def test_metric_size_threshold(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_metric_size_threshold(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) actual_loss = snowml_metrics.mean_absolute_error( df=input_df, @@ -131,7 +120,7 @@ def test_metric_size_threshold(self) -> None: pandas_df[_Y_TRUE_COLS], pandas_df[_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_percentage_error_test.py b/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_percentage_error_test.py index bc608c63..3c954c3e 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_percentage_error_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/mean_absolute_percentage_error_test.py @@ -1,7 +1,8 @@ -from typing import Any, Dict +from typing import Optional, Union from unittest import mock import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import metrics as sklearn_metrics @@ -10,35 +11,29 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] _Y_PRED_COLS = [_SF_SCHEMA[3], _SF_SCHEMA[4]] _SAMPLE_WEIGHT_COL = _SF_SCHEMA[5] -_MULTILABEL_DATA = [ - [1, 0, 1, 0.8, 0.3, 0.6], - [0, 1, 0, 0.2, 0.7, 0.4], - [1, 1, 0, 0.9, 0.6, 0.2], - [0, 0, 1, 0.1, 0.4, 0.8], -] -_MULTILABEL_SCHEMA = ["Y_0", "Y_1", "Y_2", "S_0", "S_1", "S_2"] -_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[0], _MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2]] -_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[3], _MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5]] + +_MULTILABEL_TYPES = [utils.DataType.INTEGER] * 3 + [utils.DataType.FLOAT] * 3 +_MULTILABEL_LOW, _MULTILABEL_HIGH = 0, [2, 2, 2, 1, 1, 1] +_MULTILABEL_DATA_LIST, _MULTILABEL_SCHEMA = generator.gen_test_cases( + _MULTILABEL_TYPES, _MULTILABEL_LOW, _MULTILABEL_HIGH +) +_REGULAR_MULTILABEL_DATA_LIST, _LARGE_MULTILABEL_DATA = _MULTILABEL_DATA_LIST[:-1], _MULTILABEL_DATA_LIST[-1] +_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2], _MULTILABEL_SCHEMA[3]] +_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5], _MULTILABEL_SCHEMA[6]] class MeanAbsolutePercentageErrorTest(parameterized.TestCase): @@ -51,61 +46,52 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.mean_absolute_percentage_error( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.mean_absolute_percentage_error( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + multioutput=["raw_values", "uniform_average", [0.2, 1.0, 1.66]], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_loss = snowml_metrics.mean_absolute_percentage_error( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_loss = sklearn_metrics.mean_absolute_percentage_error( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - ) - np.testing.assert_approx_equal(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"multioutput": ["raw_values", "uniform_average", [0.2, 1.0, 1.66]]}}, + def test_multioutput(self, data_index: int, multioutput: Union[str, npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) + + actual_loss = snowml_metrics.mean_absolute_percentage_error( + df=input_df, + y_true_col_names=_MULTILABEL_Y_TRUE_COLS, + y_pred_col_names=_MULTILABEL_Y_PRED_COLS, + multioutput=multioutput, + ) + sklearn_loss = sklearn_metrics.mean_absolute_percentage_error( + pandas_df[_MULTILABEL_Y_TRUE_COLS], + pandas_df[_MULTILABEL_Y_PRED_COLS], + multioutput=multioutput, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), ) - def test_multioutput(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) - - for multioutput in params["multioutput"]: - actual_loss = snowml_metrics.mean_absolute_percentage_error( - df=input_df, - y_true_col_names=_MULTILABEL_Y_TRUE_COLS, - y_pred_col_names=_MULTILABEL_Y_PRED_COLS, - multioutput=multioutput, - ) - sklearn_loss = sklearn_metrics.mean_absolute_percentage_error( - pandas_df[_MULTILABEL_Y_TRUE_COLS], - pandas_df[_MULTILABEL_Y_PRED_COLS], - multioutput=multioutput, - ) - np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=0.000001) - - def test_multilabel(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) + def test_multilabel(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) actual_loss = snowml_metrics.mean_absolute_percentage_error( df=input_df, @@ -116,11 +102,14 @@ def test_multilabel(self) -> None: pandas_df[_MULTILABEL_Y_TRUE_COLS], pandas_df[_MULTILABEL_Y_PRED_COLS], ) - np.testing.assert_approx_equal(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + ) @mock.patch("snowflake.ml.modeling.metrics.regression.result._RESULT_SIZE_THRESHOLD", 0) - def test_metric_size_threshold(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_metric_size_threshold(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) actual_loss = snowml_metrics.mean_absolute_percentage_error( df=input_df, @@ -131,7 +120,7 @@ def test_metric_size_threshold(self) -> None: pandas_df[_Y_TRUE_COLS], pandas_df[_Y_PRED_COLS], ) - np.testing.assert_approx_equal(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/modeling/metrics/mean_squared_error_test.py b/tests/integ/snowflake/ml/modeling/metrics/mean_squared_error_test.py index a8fbd615..d77ab353 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/mean_squared_error_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/mean_squared_error_test.py @@ -1,7 +1,8 @@ -from typing import Any, Dict +from typing import Optional, Union from unittest import mock import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import metrics as sklearn_metrics @@ -10,35 +11,29 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] _Y_PRED_COLS = [_SF_SCHEMA[3], _SF_SCHEMA[4]] _SAMPLE_WEIGHT_COL = _SF_SCHEMA[5] -_MULTILABEL_DATA = [ - [1, 0, 1, 0.8, 0.3, 0.6], - [0, 1, 0, 0.2, 0.7, 0.4], - [1, 1, 0, 0.9, 0.6, 0.2], - [0, 0, 1, 0.1, 0.4, 0.8], -] -_MULTILABEL_SCHEMA = ["Y_0", "Y_1", "Y_2", "S_0", "S_1", "S_2"] -_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[0], _MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2]] -_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[3], _MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5]] + +_MULTILABEL_TYPES = [utils.DataType.INTEGER] * 3 + [utils.DataType.FLOAT] * 3 +_MULTILABEL_LOW, _MULTILABEL_HIGH = 0, [2, 2, 2, 1, 1, 1] +_MULTILABEL_DATA_LIST, _MULTILABEL_SCHEMA = generator.gen_test_cases( + _MULTILABEL_TYPES, _MULTILABEL_LOW, _MULTILABEL_HIGH +) +_REGULAR_MULTILABEL_DATA_LIST, _LARGE_MULTILABEL_DATA = _MULTILABEL_DATA_LIST[:-1], _MULTILABEL_DATA_LIST[-1] +_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2], _MULTILABEL_SCHEMA[3]] +_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5], _MULTILABEL_SCHEMA[6]] class MeanSquaredErrorTest(parameterized.TestCase): @@ -51,93 +46,72 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_loss = snowml_metrics.mean_squared_error( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_loss = sklearn_metrics.mean_squared_error( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"multioutput": ["raw_values", "uniform_average", [0.2, 1.0, 1.66]]}}, + def test_sample_weight(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.mean_squared_error( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.mean_squared_error( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + multioutput=["raw_values", "uniform_average", [0.2, 1.0, 1.66]], ) - def test_multioutput(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) - - for multioutput in params["multioutput"]: - actual_loss = snowml_metrics.mean_squared_error( - df=input_df, - y_true_col_names=_MULTILABEL_Y_TRUE_COLS, - y_pred_col_names=_MULTILABEL_Y_PRED_COLS, - multioutput=multioutput, - ) - sklearn_loss = sklearn_metrics.mean_squared_error( - pandas_df[_MULTILABEL_Y_TRUE_COLS], - pandas_df[_MULTILABEL_Y_PRED_COLS], - multioutput=multioutput, - ) - np.testing.assert_allclose(actual_loss, sklearn_loss) - - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "squared": [True, False], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + def test_multioutput(self, data_index: int, multioutput: Union[str, npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) + + actual_loss = snowml_metrics.mean_squared_error( + df=input_df, + y_true_col_names=_MULTILABEL_Y_TRUE_COLS, + y_pred_col_names=_MULTILABEL_Y_PRED_COLS, + multioutput=multioutput, + ) + sklearn_loss = sklearn_metrics.mean_squared_error( + pandas_df[_MULTILABEL_Y_TRUE_COLS], + pandas_df[_MULTILABEL_Y_PRED_COLS], + multioutput=multioutput, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + squared=[True, False], ) - def test_squared(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for squared in params["squared"]: - actual_loss = snowml_metrics.mean_squared_error( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - squared=squared, - ) - sklearn_loss = sklearn_metrics.mean_squared_error( - pandas_df[y_true], - pandas_df[y_pred], - squared=squared, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - def test_multilabel(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) + def test_squared(self, data_index: int, squared: bool) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.mean_squared_error( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + squared=squared, + ) + sklearn_loss = sklearn_metrics.mean_squared_error( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + squared=squared, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + ) + def test_multilabel(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) actual_loss = snowml_metrics.mean_squared_error( df=input_df, @@ -148,22 +122,25 @@ def test_multilabel(self) -> None: pandas_df[_MULTILABEL_Y_TRUE_COLS], pandas_df[_MULTILABEL_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + ) @mock.patch("snowflake.ml.modeling.metrics.regression.result._RESULT_SIZE_THRESHOLD", 0) - def test_metric_size_threshold(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_metric_size_threshold(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) actual_loss = snowml_metrics.mean_squared_error( df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, ) sklearn_loss = sklearn_metrics.mean_squared_error( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/modeling/metrics/precision_recall_curve_test.py b/tests/integ/snowflake/ml/modeling/metrics/precision_recall_curve_test.py index 500d1e0b..1b30151c 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/precision_recall_curve_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/precision_recall_curve_test.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Optional, Union from unittest import mock import numpy as np @@ -10,18 +10,15 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 -_TYPES = [utils.DataType.INTEGER] + [utils.DataType.FLOAT] * 2 -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=[2, 1, 1], -) +_TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _PROBAS_PRED_COL = _SF_SCHEMA[2] -_SAMPLE_WEIGHT_COL = _SF_SCHEMA[3] +_SAMPLE_WEIGHT_COL = _SF_SCHEMA[5] class PrecisionRecallCurveTest(parameterized.TestCase): @@ -34,54 +31,57 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - {"params": {"pos_label": [0, 2, 4]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_BINARY_DATA_LIST))), + pos_label=[0, 2, 4], ) - def test_pos_label(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_pos_label(self, data_index: int, pos_label: Union[str, int]) -> None: + pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA_LIST[data_index], _SF_SCHEMA) - for pos_label in params["pos_label"]: - actual_precision, actual_recall, actual_thresholds = snowml_metrics.precision_recall_curve( - df=input_df, - y_true_col_name=_Y_TRUE_COL, - probas_pred_col_name=_PROBAS_PRED_COL, - pos_label=pos_label, - ) - sklearn_precision, sklearn_recall, sklearn_thresholds = sklearn_metrics.precision_recall_curve( - pandas_df[_Y_TRUE_COL], - pandas_df[_PROBAS_PRED_COL], - pos_label=pos_label, - ) - np.testing.assert_allclose(actual_precision, sklearn_precision) - np.testing.assert_allclose(actual_recall, sklearn_recall) - np.testing.assert_allclose(actual_thresholds, sklearn_thresholds) + actual_precision, actual_recall, actual_thresholds = snowml_metrics.precision_recall_curve( + df=input_df, + y_true_col_name=_Y_TRUE_COL, + probas_pred_col_name=_PROBAS_PRED_COL, + pos_label=pos_label, + ) + sklearn_precision, sklearn_recall, sklearn_thresholds = sklearn_metrics.precision_recall_curve( + pandas_df[_Y_TRUE_COL], + pandas_df[_PROBAS_PRED_COL], + pos_label=pos_label, + ) + np.testing.assert_allclose(actual_precision, sklearn_precision) + np.testing.assert_allclose(actual_recall, sklearn_recall) + np.testing.assert_allclose(actual_thresholds, sklearn_thresholds) - @parameterized.parameters( # type: ignore[misc] - {"params": {"sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_sample_weight(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_precision, actual_recall, actual_thresholds = snowml_metrics.precision_recall_curve( - df=input_df, - y_true_col_name=_Y_TRUE_COL, - probas_pred_col_name=_PROBAS_PRED_COL, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_precision, sklearn_recall, sklearn_thresholds = sklearn_metrics.precision_recall_curve( - pandas_df[_Y_TRUE_COL], - pandas_df[_PROBAS_PRED_COL], - sample_weight=sample_weight, - ) - np.testing.assert_allclose(actual_precision, sklearn_precision) - np.testing.assert_allclose(actual_recall, sklearn_recall) - np.testing.assert_allclose(actual_thresholds, sklearn_thresholds) + actual_precision, actual_recall, actual_thresholds = snowml_metrics.precision_recall_curve( + df=input_df, + y_true_col_name=_Y_TRUE_COL, + probas_pred_col_name=_PROBAS_PRED_COL, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_precision, sklearn_recall, sklearn_thresholds = sklearn_metrics.precision_recall_curve( + pandas_df[_Y_TRUE_COL], + pandas_df[_PROBAS_PRED_COL], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_precision, sklearn_precision) + np.testing.assert_allclose(actual_recall, sklearn_recall) + np.testing.assert_allclose(actual_thresholds, sklearn_thresholds) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + ) @mock.patch("snowflake.ml.modeling.metrics.ranking.result._RESULT_SIZE_THRESHOLD", 0) - def test_metric_size_threshold(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_metric_size_threshold(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) actual_precision, actual_recall, actual_thresholds = snowml_metrics.precision_recall_curve( df=input_df, diff --git a/tests/integ/snowflake/ml/observability/BUILD.bazel b/tests/integ/snowflake/ml/observability/BUILD.bazel new file mode 100644 index 00000000..f72c76ec --- /dev/null +++ b/tests/integ/snowflake/ml/observability/BUILD.bazel @@ -0,0 +1,16 @@ +load("//bazel:py_rules.bzl", "py_test") + +py_test( + name = "model_monitor_integ_test", + timeout = "long", + srcs = ["model_monitor_integ_test.py"], + deps = [ + "//snowflake/ml/beta/observability:observability_lib", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/registry:registry_impl", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:db_manager", + "//tests/integ/snowflake/ml/test_utils:model_factory", + "//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) diff --git a/tests/integ/snowflake/ml/observability/model_monitor_integ_test.py b/tests/integ/snowflake/ml/observability/model_monitor_integ_test.py new file mode 100644 index 00000000..17436eae --- /dev/null +++ b/tests/integ/snowflake/ml/observability/model_monitor_integ_test.py @@ -0,0 +1,263 @@ +import uuid + +import pandas as pd +from absl.testing import absltest, parameterized + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.beta.observability import ( + model_monitor, + model_monitor_config, + model_monitor_registry, + monitor_sql_client, +) +from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.registry import registry +from snowflake.ml.utils import connection_params +from snowflake.snowpark import Session +from tests.integ.snowflake.ml.test_utils import db_manager, model_factory + + +class ModelMonitorRegistryIntegrationTest(parameterized.TestCase): + def _create_test_table(self, fully_qualified_table_name: str): + self._session.sql( + f"""CREATE OR REPLACE TABLE {fully_qualified_table_name} + (label FLOAT, prediction FLOAT, F1 FLOAT, id STRING, timestamp TIMESTAMP)""" + ).collect() + + @classmethod + def setUpClass(cls) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + cls._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() + + def setUp(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + self.run_id = uuid.uuid4().hex + self._db_manager = db_manager.DBManager(self._session) + self._schema_name = "PUBLIC" + # TODO(jfishbein): Investigate whether conversion to sql identifier requires uppercase. + self._db_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( + self.run_id, "monitor_registry" + ).upper() + self._session.sql(f"CREATE DATABASE IF NOT EXISTS {self._db_name}").collect() + self.perm_stage = "@" + self._db_manager.create_stage( + stage_name="model_registry_test_stage", + schema_name=self._schema_name, + db_name=self._db_name, + sse_encrypted=True, + ) + model_monitor_registry.ModelMonitorRegistry.setup( + session=self._session, database_name=self._db_name, schema_name=self._schema_name + ) + self._warehouse_name = "REGTEST_ML_SMALL" + self._db_manager.set_warehouse(self._warehouse_name) + + self._db_manager.cleanup_databases(expire_hours=6) + self.registry = registry.Registry(self._session, database_name=self._db_name, schema_name=self._schema_name) + + def tearDown(self) -> None: + self._db_manager.drop_database(self._db_name) + super().tearDown() + + @classmethod + def tearDownClass(cls) -> None: + cls._session.close() + + def _add_sample_model_version_and_monitor( + self, + monitor_registry: model_monitor_registry.ModelMonitorRegistry, + source_table: str, + model_name: str, + version_name: str, + monitor_name: str, + ) -> model_monitor.ModelMonitor: + model, features, _ = model_factory.ModelFactory.prepare_sklearn_model() + model_version: model_version_impl.ModelVersion = self.registry.log_model( + model=model, + model_name=model_name, + version_name=version_name, + sample_input_data=features, + ) + + return monitor_registry.add_monitor( + name=monitor_name, + source_table_name=source_table, + model_monitor_config=model_monitor_config.ModelMonitorConfig( + model_version=model_version, + model_function_name="predict", + prediction_columns=["prediction"], + label_columns=["label"], + id_columns=["id"], + timestamp_column="timestamp", + background_compute_warehouse_name=self._warehouse_name, + ), + ) + + def test_add_model_monitor(self) -> None: + # Create an instance of the ModelMonitorRegistry class + _monitor_registry = model_monitor_registry.ModelMonitorRegistry( + session=self._session, database_name=self._db_name, schema_name=self._schema_name + ) + + source_table_name = "TEST_TABLE" + self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_name}") + + model_name = "TEST_MODEL" + version_name = "TEST_VERSION" + monitor_name = f"TEST_MONITOR_{model_name}_{version_name}_{self.run_id}" + monitor = self._add_sample_model_version_and_monitor( + _monitor_registry, source_table_name, model_name, version_name, monitor_name + ) + + self.assertEqual( + self._session.sql( + f"""SELECT * + FROM {self._db_name}.{self._schema_name}.{monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} + WHERE FULLY_QUALIFIED_MODEL_NAME = '{self._db_name}.{self._schema_name}.{model_name}' AND + MODEL_VERSION_NAME = '{version_name}'""" + ).count(), + 1, + ) + + self.assertEqual( + self._session.sql( + f"""SELECT * + FROM {self._db_name}.{self._schema_name}. + _SNOWML_OBS_BASELINE_{model_name}_{version_name}""" + ).count(), + 0, + ) + + table_columns = self._session.sql( + f"""DESCRIBE TABLE + {self._db_name}.{self._schema_name}._SNOWML_OBS_BASELINE_{model_name}_{version_name}""" + ).collect() + + for col in table_columns: + self.assertTrue(col["name"].upper() in ["PREDICTION", "LABEL", "F1", "ID", "TIMESTAMP"]) + + df = self._session.create_dataframe( + [ + (1.0, 1.0, 1.0, "1", "2021-01-01 00:00:00"), + (2.0, 2.0, 2.0, "2", "2021-01-01 00:00:00"), + ], + ["LABEL", "PREDICTION", "F1", "ID", "TIMESTAMP"], + ) + monitor.set_baseline(df) + self.assertEqual( + self._session.sql( + f"""SELECT * + FROM {self._db_name}.{self._schema_name}. + _SNOWML_OBS_BASELINE_{model_name}_{version_name}""" + ).count(), + 2, + ) + + pandas_df = pd.DataFrame( + { + "LABEL": [1.0, 2.0, 3.0], + "PREDICTION": [1.0, 2.0, 3.0], + "F1": [1.0, 2.0, 3.0], + "ID": ["1", "2", "3"], + "TIMESTAMP": ["2021-01-01 00:00:00", "2021-01-01 00:00:00", "2021-01-01 00:00:00"], + } + ) + monitor.set_baseline(pandas_df) + self.assertEqual( + self._session.sql( + f"""SELECT * + FROM {self._db_name}.{self._schema_name}. + _SNOWML_OBS_BASELINE_{model_name}_{version_name}""" + ).count(), + 3, + ) + + # create a snowpark dataframe that does not conform to the existing schema + df = self._session.create_dataframe( + [ + (1.0, "bad", 1.0, "1", "2021-01-01 00:00:00"), + (2.0, "very_bad", 2.0, "2", "2021-01-01 00:00:00"), + ], + ["LABEL", "PREDICTION", "F1", "ID", "TIMESTAMP"], + ) + with self.assertRaises(ValueError) as e: + monitor.set_baseline(df) + + expected_msg = "Ensure that the baseline dataframe columns match those provided in your monitored table" + self.assertTrue(expected_msg in str(e.exception)) + expected_msg = "Numeric value 'bad' is not recognized" + self.assertTrue(expected_msg in str(e.exception)) + + # Delete monitor + _monitor_registry.delete_monitor(monitor_name) + + # Validate that metadata entry is deleted + self.assertEqual( + self._session.sql( + f"""SELECT * + FROM {self._db_name}.{self._schema_name}.{monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} + WHERE MONITOR_NAME = '{monitor.name}'""" + ).count(), + 0, + ) + + # Validate that baseline table is deleted + self.assertEqual( + self._session.sql( + f"""SHOW TABLES LIKE '%{self._db_name}.{self._schema_name}. + _SNOWML_OBS_BASELINE_{model_name}_{version_name}%'""" + ).count(), + 0, + ) + + # Validate that dynamic tables are deleted + self.assertEqual( + self._session.sql( + f"""SHOW TABLES LIKE '%{self._db_name}.{self._schema_name}. + _SNOWML_OBS_MONITORING_{model_name}_{version_name}%'""" + ).count(), + 0, + ) + self.assertEqual( + self._session.sql( + f"""SHOW TABLES LIKE '%{self._db_name}.{self._schema_name}. + _SNOWML_OBS_ACCURACY_{model_name}_{version_name}%'""" + ).count(), + 0, + ) + + def test_show_model_monitors(self) -> None: + _monitor_registry = model_monitor_registry.ModelMonitorRegistry( + session=self._session, database_name=self._db_name, schema_name=self._schema_name + ) + source_table_1 = "TEST_TABLE_1" + self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_1}") + + source_table_2 = "TEST_TABLE_2" + self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_2}") + + model_1 = "TEST_MODEL_1" + version_1 = "TEST_VERSION_1" + monitor_1 = f"TEST_MONITOR_{model_1}_{version_1}_{self.run_id}" + self._add_sample_model_version_and_monitor(_monitor_registry, source_table_1, model_1, version_1, monitor_1) + + model_2 = "TEST_MODEL_2" + version_2 = "TEST_VERSION_2" + monitor_2 = f"TEST_MONITOR_{model_2}_{version_2}_{self.run_id}" + self._add_sample_model_version_and_monitor(_monitor_registry, source_table_2, model_2, version_2, monitor_2) + + stored_monitors = sorted(_monitor_registry.show_model_monitors(), key=lambda x: x["MONITOR_NAME"]) + self.assertEqual(len(stored_monitors), 2) + row_1 = stored_monitors[0] + self.assertEqual(row_1["MONITOR_NAME"], sql_identifier.SqlIdentifier(monitor_1)) + self.assertEqual(row_1["SOURCE_TABLE_NAME"], source_table_1) + self.assertEqual(row_1["MODEL_VERSION_NAME"], version_1) + self.assertEqual(row_1["IS_ENABLED"], True) + row_2 = stored_monitors[1] + self.assertEqual(row_2["MONITOR_NAME"], sql_identifier.SqlIdentifier(monitor_2)) + self.assertEqual(row_2["SOURCE_TABLE_NAME"], source_table_2) + self.assertEqual(row_2["MODEL_VERSION_NAME"], version_2) + self.assertEqual(row_2["IS_ENABLED"], True) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/BUILD.bazel b/tests/integ/snowflake/ml/registry/model/BUILD.bazel index b82c51a5..9b21f080 100644 --- a/tests/integ/snowflake/ml/registry/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/model/BUILD.bazel @@ -171,7 +171,7 @@ py_test( name = "registry_huggingface_pipeline_model_test", timeout = "long", srcs = ["registry_huggingface_pipeline_model_test.py"], - shard_count = 6, + shard_count = 8, deps = [ ":registry_model_test_base", "//snowflake/ml/_internal:env_utils", diff --git a/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py index 30586159..62892226 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd from absl.testing import absltest, parameterized -from packaging import requirements +from packaging import requirements, version from snowflake.ml._internal import env_utils from tests.integ.snowflake.ml.registry.model import registry_model_test_base @@ -35,6 +35,9 @@ def test_conversational_pipeline( # Only by doing so can we make the cache dir setting effective. import transformers + if version.parse(transformers.__version__) >= version.parse("4.42.0"): + self.skipTest("This test is not compatible with transformers>=4.42.0") + model = transformers.pipeline(task="conversational", model="ToddGoldfarb/Cadet-Tiny") x_df = pd.DataFrame( @@ -65,6 +68,7 @@ def check_res(res: pd.DataFrame) -> None: check_res, ), }, + options={"relax_version": False}, ) @parameterized.product( # type: ignore[misc] diff --git a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py index 48fa3ad0..4da12e1d 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py @@ -2,6 +2,7 @@ import posixpath import numpy as np +import shap import yaml from absl.testing import absltest, parameterized from sklearn import datasets @@ -74,6 +75,126 @@ def test_snowml_model_deploy_xgboost( ), ), }, + options={"enable_explainability": False}, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_xgboost_explain_default( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_xgboost())(test_features[INPUT_COLUMNS]).values + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + ), + ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 + ), + ), + }, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_xgboost_explain_enabled( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_xgboost())(test_features[INPUT_COLUMNS]).values + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + ), + ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 + ), + ), + }, + options={"enable_explainability": True}, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_xgboost_explain( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_xgboost())(test_features[INPUT_COLUMNS]).values + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + ), + ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 + ), + ), + }, + options={"enable_explainability": True}, ) @parameterized.product( # type: ignore[misc] @@ -103,6 +224,130 @@ def test_snowml_model_deploy_lightgbm( ), ), }, + options={"enable_explainability": False}, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_lightgbm_explain_default( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + ), + ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + res[EXPLAIN_OUTPUT_COLUMNS].values, + expected_explanations, + rtol=1e-5, + ), + ), + }, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_lightgbm_explain_enabled( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + ), + ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + res[EXPLAIN_OUTPUT_COLUMNS].values, + expected_explanations, + rtol=1e-5, + ), + ), + }, + options={"enable_explainability": True}, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_lightgbm_explain( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values + print(expected_explanations) + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + ), + ), + "explain": ( + test_features, + lambda res: np.testing.assert_allclose( + res[EXPLAIN_OUTPUT_COLUMNS].values, + expected_explanations, + rtol=1e-5, + ), + ), + }, + options={"enable_explainability": True}, ) @parameterized.product( # type: ignore[misc] diff --git a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py index 957b64cf..ae8eb46c 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py @@ -11,6 +11,31 @@ class TestRegistryXGBoostModelInteg(registry_model_test_base.RegistryModelTestBase): + @parameterized.product( + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_xgb_manual_shap_override(self, registry_test_fn: str) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) + regressor.fit(cal_X_train, cal_y_train) + expected_explanations = shap.Explainer(regressor)(cal_X_test).values + getattr(self, registry_test_fn)( + model=regressor, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "explain": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-3), + ), + }, + # pin version of shap for tests + additional_dependencies=[f"shap=={shap.__version__}"], + ) + @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) diff --git a/tests/integ/snowflake/ml/snowpark_pandas/BUILD.bazel b/tests/integ/snowflake/ml/snowpark_pandas/BUILD.bazel index 1135ac64..288b17a6 100644 --- a/tests/integ/snowflake/ml/snowpark_pandas/BUILD.bazel +++ b/tests/integ/snowflake/ml/snowpark_pandas/BUILD.bazel @@ -8,7 +8,6 @@ py_test( name = "snowpark_pandas_test", timeout = "long", srcs = ["snowpark_pandas_test.py"], - compatible_with_snowpark = False, shard_count = 5, deps = [ "//snowflake/ml/_internal:env",