From 38d24978d529804ed355c993126e05c4e8a42fd1 Mon Sep 17 00:00:00 2001 From: Angel Antonio Avalos Cisneros Date: Tue, 5 Nov 2024 10:31:13 -0800 Subject: [PATCH] Project import generated by Copybara. (#127) GitOrigin-RevId: 50dc082133b4d28c56c1ecc0cad18e5751f6bc19 Co-authored-by: Snowflake Authors --- CHANGELOG.md | 27 +- bazel/environments/conda-env-build.yml | 4 +- bazel/environments/conda-env-snowflake.yml | 6 +- bazel/environments/conda-env.yml | 6 +- bazel/environments/conda-gpu-env.yml | 6 +- .../parse_and_generate_requirements.py | 30 +- ci/conda_recipe/meta.yaml | 7 +- ci/targets/slow.txt | 3 + docs/source/cortex.rst | 10 +- docs/source/data.rst | 21 + docs/source/distributors.rst | 294 ++++ docs/source/index.rst | 1 + docs/source/model.rst | 2 +- packages.bzl | 2 + requirements.txt | 4 +- requirements.yml | 27 +- snowflake/cortex/BUILD.bazel | 32 +- snowflake/cortex/__init__.py | 4 + snowflake/cortex/_complete.py | 171 ++- snowflake/cortex/_finetune.py | 273 ++++ snowflake/cortex/_sse_client.py | 119 +- snowflake/cortex/_util.py | 31 +- snowflake/cortex/complete_test.py | 94 +- snowflake/cortex/finetune_test.py | 197 +++ snowflake/cortex/package_visibility_test.py | 5 + snowflake/cortex/sse_test.py | 15 +- snowflake/ml/_internal/type_utils.py | 6 +- .../model/_client/model/model_version_impl.py | 14 +- .../_client/model/model_version_impl_test.py | 36 +- snowflake/ml/model/_client/ops/model_ops.py | 81 +- .../ml/model/_client/ops/model_ops_test.py | 279 +++- snowflake/ml/model/_client/ops/service_ops.py | 15 +- .../ml/model/_client/ops/service_ops_test.py | 190 ++- snowflake/ml/model/_client/sql/model.py | 14 - snowflake/ml/model/_client/sql/service.py | 26 +- .../ml/model/_client/sql/service_test.py | 40 + .../model_manifest/model_manifest_test.py | 52 +- .../model_method/fixtures/function_1.py | 3 +- .../model_method/fixtures/function_2.py | 3 +- .../model_method/infer_function.py_template | 3 +- .../ml/model/_packager/model_env/model_env.py | 12 + .../_packager/model_env/model_env_test.py | 50 + .../model/_packager/model_handlers/_utils.py | 2 +- .../_packager/model_handlers/catboost.py | 2 +- .../model/_packager/model_handlers/custom.py | 4 +- .../_packager/model_handlers/lightgbm.py | 3 +- .../model/_packager/model_handlers/sklearn.py | 49 +- .../_packager/model_handlers/snowmlmodel.py | 2 +- .../_packager/model_handlers/tensorflow.py | 29 +- .../_packager/model_handlers/torchscript.py | 28 +- .../_packager/model_handlers_test/BUILD.bazel | 1 + .../model_handlers_test/catboost_test.py | 1 + .../model_handlers_test/lightgbm_test.py | 1 + .../model_handlers_test/pytorch_test.py | 4 +- .../model_handlers_test/sklearn_test.py | 18 + .../model_handlers_test/snowmlmodel_test.py | 8 +- .../model_handlers_test/torchscript_test.py | 6 +- .../model_handlers_test/xgboost_test.py | 17 +- .../ml/model/_packager/model_meta/BUILD.bazel | 2 +- .../_packager/model_meta/model_meta_schema.py | 5 + .../model/_packager/model_runtime/BUILD.bazel | 2 +- .../_packager/model_runtime/model_runtime.py | 13 +- .../model_runtime/model_runtime_test.py | 12 +- .../_packager/model_task/model_task_utils.py | 2 +- snowflake/ml/model/_signatures/core.py | 79 +- snowflake/ml/model/_signatures/core_test.py | 38 +- .../ml/model/_signatures/pandas_handler.py | 98 +- snowflake/ml/model/_signatures/pandas_test.py | 178 ++- .../ml/model/_signatures/pytorch_handler.py | 4 +- .../ml/model/_signatures/pytorch_test.py | 38 +- .../ml/model/_signatures/snowpark_handler.py | 3 +- .../model/_signatures/tensorflow_handler.py | 
4 +- .../ml/model/_signatures/tensorflow_test.py | 88 +- snowflake/ml/model/_signatures/utils.py | 4 + snowflake/ml/model/model_signature.py | 47 +- snowflake/ml/model/model_signature_test.py | 43 + snowflake/ml/model/type_hints.py | 2 +- snowflake/ml/monitoring/_client/BUILD.bazel | 11 + .../_client/model_monitor_sql_client.py | 1203 +++-------------- .../model_monitor_sql_client_server_test.py | 215 +++ .../_client/model_monitor_sql_client_test.py | 1166 ++-------------- .../_manager/model_monitor_manager.py | 336 ++--- .../_manager/model_monitor_manager_test.py | 365 +++-- snowflake/ml/monitoring/entities/BUILD.bazel | 12 - .../entities/model_monitor_config.py | 20 +- .../entities/model_monitor_interval.py | 46 - .../entities/model_monitor_interval_test.py | 41 - snowflake/ml/monitoring/model_monitor.py | 103 +- snowflake/ml/monitoring/model_monitor_test.py | 131 +- snowflake/ml/registry/registry.py | 46 +- snowflake/ml/registry/registry_test.py | 196 +-- snowflake/ml/version.bzl | 2 +- tests/integ/snowflake/cortex/BUILD.bazel | 30 + tests/integ/snowflake/cortex/complete_test.py | 6 + .../integ/snowflake/cortex/embed_text_test.py | 16 +- .../pipeline_with_ohe_and_xgbr_test.py | 17 +- .../model/input_validation_integ_test.py | 10 +- .../ml/monitoring/model_monitor_integ_test.py | 365 ++--- .../model/random_version_name_test.py | 8 +- .../model/registry_catboost_model_test.py | 174 ++- .../model/registry_custom_model_test.py | 185 ++- ...egistry_huggingface_pipeline_model_test.py | 20 +- .../model/registry_lightgbm_model_test.py | 177 ++- .../model/registry_mlflow_model_test.py | 14 +- .../model/registry_modeling_model_test.py | 265 +++- .../model/registry_pytorch_model_test.py | 6 +- .../model/registry_sklearn_model_test.py | 99 +- .../model/registry_tensorflow_model_test.py | 76 +- .../model/registry_xgboost_model_test.py | 202 ++- .../ml/registry/services/BUILD.bazel | 38 +- .../registry_custom_model_deployment_test.py | 9 +- ...face_pipeline_model_deployment_gpu_test.py | 72 + ...gingface_pipeline_model_deployment_test.py | 5 +- .../registry_model_deployment_test.py | 26 +- .../registry_model_deployment_test_base.py | 4 +- ..._transformers_model_deployment_gpu_test.py | 79 ++ ...ence_transformers_model_deployment_test.py | 7 +- .../registry_sklearn_model_deployment_test.py | 10 +- ...stry_xgboost_model_deployment_pip_test.py} | 21 +- .../integ/snowflake/ml/test_utils/BUILD.bazel | 1 + .../ml/test_utils/common_test_base.py | 2 +- .../snowflake/ml/test_utils/db_manager.py | 3 +- 122 files changed, 4840 insertions(+), 4052 deletions(-) create mode 100644 docs/source/data.rst create mode 100644 docs/source/distributors.rst create mode 100644 snowflake/cortex/_finetune.py create mode 100644 snowflake/cortex/finetune_test.py create mode 100644 snowflake/ml/monitoring/_client/model_monitor_sql_client_server_test.py delete mode 100644 snowflake/ml/monitoring/entities/model_monitor_interval.py delete mode 100644 snowflake/ml/monitoring/entities/model_monitor_interval_test.py create mode 100644 tests/integ/snowflake/cortex/BUILD.bazel create mode 100644 tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_gpu_test.py create mode 100644 tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_gpu_test.py rename tests/integ/snowflake/ml/registry/services/{registry_xgboost_model_deployment_test.py => registry_xgboost_model_deployment_pip_test.py} (71%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8aef346a..f880471a 100644 --- 
a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,31 @@ # Release History -## 1.7.0 +## 1.7.1 + +### Bug Fixes + +- Registry: Null value is now allowed in the dataframe used in model signature inference. Null values will be ignored + and others will be used to infer the signature. +- Registry: Pandas Extension DTypes (`pandas.StringDType()`, `pandas.BooleanDType()`, etc.) are now supported in model +signature inference. +- Registry: Null value is now allowed in the dataframe used to predict. +- Data: Fix missing `snowflake.ml.data.*` module exports in wheel +- Dataset: Fix missing `snowflake.ml.dataset.*` module exports in wheel. +- Registry: Fix the issue that `tf_keras.Model` is not recognized as keras model when logging. + +### Behavior Changes + +### New Features + +- Registry: Option to `enable_monitoring` set to False by default. This will gate access to preview features of Model Monitoring. +- Model Monitoring: `show_model_monitors` Registry method. This feature is still in Private Preview. +- Registry: Support `pd.Series` in input and output data. +- Model Monitoring: `add_monitor` Registry method. This feature is still in Private Preview. +- Model Monitoring: `resume` and `suspend` ModelMonitor. This feature is still in Private Preview. +- Model Monitoring: `get_monitor` Registry method. This feature is still in Private Preview. +- Model Monitoring: `delete_monitor` Registry method. This feature is still in Private Preview. + +## 1.7.0 (10-22-2024) ### Behavior Change diff --git a/bazel/environments/conda-env-build.yml b/bazel/environments/conda-env-build.yml index b26fb48e..4c15ad31 100644 --- a/bazel/environments/conda-env-build.yml +++ b/bazel/environments/conda-env-build.yml @@ -11,7 +11,7 @@ dependencies: - conda-libmamba-solver==23.7.0 - inflection==0.5.1 - jsonschema==3.2.0 - - lightgbm==3.3.5 + - lightgbm==4.1.0 - numpy==1.23.5 - packaging==23.0 - ruamel.yaml==0.17.21 @@ -19,4 +19,4 @@ dependencies: - sphinx==5.0.2 - toml==0.10.2 - types-toml==0.10.8.6 - - xgboost==1.7.3 + - xgboost==1.7.6 diff --git a/bazel/environments/conda-env-snowflake.yml b/bazel/environments/conda-env-snowflake.yml index 47685af6..ae87b83b 100644 --- a/bazel/environments/conda-env-snowflake.yml +++ b/bazel/environments/conda-env-snowflake.yml @@ -25,7 +25,7 @@ dependencies: - inflection==0.5.1 - joblib==1.4.2 - jsonschema==3.2.0 - - lightgbm==3.3.5 + - lightgbm==4.1.0 - mlflow==2.3.1 - moto==4.0.11 - mypy==1.10.0 @@ -40,7 +40,7 @@ dependencies: - pytest-xdist==3.5.0 - pytest==7.4.0 - pytimeparse==1.1.8 - - pytorch==2.0.1 + - pytorch==2.1.0 - pyyaml==6.0 - requests==2.29.0 - retrying==1.3.3 @@ -67,4 +67,4 @@ dependencies: - types-toml==0.10.8.6 - typing-extensions==4.6.3 - werkzeug==2.2.2 - - xgboost==1.7.3 + - xgboost==1.7.6 diff --git a/bazel/environments/conda-env.yml b/bazel/environments/conda-env.yml index d2acdd48..6e102884 100644 --- a/bazel/environments/conda-env.yml +++ b/bazel/environments/conda-env.yml @@ -25,7 +25,7 @@ dependencies: - inflection==0.5.1 - joblib==1.4.2 - jsonschema==3.2.0 - - lightgbm==3.3.5 + - lightgbm==4.1.0 - mlflow==2.3.1 - moto==4.0.11 - mypy==1.10.0 @@ -40,7 +40,7 @@ dependencies: - pytest-xdist==3.5.0 - pytest==7.4.0 - pytimeparse==1.1.8 - - pytorch==2.0.1 + - pytorch==2.1.0 - pyyaml==6.0 - requests==2.29.0 - retrying==1.3.3 @@ -67,7 +67,7 @@ dependencies: - types-toml==0.10.8.6 - typing-extensions==4.6.3 - werkzeug==2.2.2 - - xgboost==1.7.3 + - xgboost==1.7.6 - pip - pip: - --extra-index-url https://pypi.org/simple diff --git a/bazel/environments/conda-gpu-env.yml 
b/bazel/environments/conda-gpu-env.yml index d652f787..00d8958e 100755 --- a/bazel/environments/conda-gpu-env.yml +++ b/bazel/environments/conda-gpu-env.yml @@ -25,7 +25,7 @@ dependencies: - inflection==0.5.1 - joblib==1.4.2 - jsonschema==3.2.0 - - lightgbm==3.3.5 + - lightgbm==4.1.0 - mlflow==2.3.1 - moto==4.0.11 - mypy==1.10.0 @@ -42,7 +42,7 @@ dependencies: - pytest==7.4.0 - pytimeparse==1.1.8 - pytorch::pytorch-cuda==11.7.* - - pytorch::pytorch==2.0.1 + - pytorch::pytorch==2.1.0 - pyyaml==6.0 - requests==2.29.0 - retrying==1.3.3 @@ -69,7 +69,7 @@ dependencies: - types-toml==0.10.8.6 - typing-extensions==4.6.3 - werkzeug==2.2.2 - - xgboost==1.7.3 + - xgboost==1.7.6 - pip - pip: - --extra-index-url https://pypi.org/simple diff --git a/bazel/requirements/parse_and_generate_requirements.py b/bazel/requirements/parse_and_generate_requirements.py index 0195a0f9..e4d2061e 100644 --- a/bazel/requirements/parse_and_generate_requirements.py +++ b/bazel/requirements/parse_and_generate_requirements.py @@ -418,7 +418,7 @@ def generate_requirements( ) sys.stdout.writelines(results) elif (mode, format) == ("version_requirements", "python"): - results = list( + reqs = list( sorted( filter( None, @@ -427,13 +427,28 @@ def generate_requirements( filter( lambda req_info: req_info.get("from_channel", SNOWFLAKE_CONDA_CHANNEL) == SNOWFLAKE_CONDA_CHANNEL, - requirements, + filter(lambda req_info: filter_by_extras(req_info, False, True), requirements), + ), + ), + ), + ) + ) + all_reqs = list( + sorted( + filter( + None, + map( + lambda req_info: generate_user_requirements_string(req_info, "conda"), + filter( + lambda req_info: req_info.get("from_channel", SNOWFLAKE_CONDA_CHANNEL) + == SNOWFLAKE_CONDA_CHANNEL, + filter(lambda req_info: filter_by_extras(req_info, False, False), requirements), ), ), ), ) ) - sys.stdout.write(f"REQUIREMENTS = {json.dumps(results, indent=4)}\n") + sys.stdout.write(f"REQUIREMENTS = {repr(reqs)}\nALL_REQUIREMENTS={repr(all_reqs)}\n") elif (mode, format) == ("version_requirements", "toml"): extras_requirements = list(filter(lambda req_info: filter_by_extras(req_info, True, False), requirements)) extras_results: MutableMapping[str, Sequence[str]] = {} @@ -478,7 +493,13 @@ def generate_requirements( elif (mode, format) == ("version_requirements", "python"): results = list( sorted( - filter(None, map(lambda req_info: generate_user_requirements_string(req_info, "conda"), requirements)), + filter( + None, + map( + lambda req_info: generate_user_requirements_string(req_info, "conda"), + filter(lambda req_info: filter_by_extras(req_info, False, True), requirements), + ), + ) ) ) sys.stdout.writelines(f"REQUIREMENTS = {repr(results)}\n") @@ -555,7 +576,6 @@ def main() -> None: ("dev_version", "text", False), # requirements.txt ("version_requirements", "python", True), # sproc test dependencies list ("version_requirements", "toml", False), # wheel rule requirements - ("version_requirements", "python", False), # model deployment core dependencies list ("dev_version", "conda_env", False), # dev conda-env.yml file ("dev_gpu_version", "conda_env", False), # dev conda-gpu-env.yml file ("dev_version", "conda_env", True), # dev conda-env-snowflake.yml file diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 7a5dd99b..a8dba137 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.7.0 + version: 1.7.1 requirements: build: - python @@ -28,10 +28,11 @@ requirements: - 
anyio>=3.5.0,<4 - cachetools>=3.1.1,<6 - cloudpickle>=2.0.0 + - cryptography - fsspec>=2022.11,<2024 - importlib_resources>=6.1.1, <7 - numpy>=1.23,<2 - - packaging>=20.9,<24 + - packaging>=20.9,<25 - pandas>=1.0.0,<3 - pyarrow - pytimeparse>=1.1.8,<2 @@ -49,7 +50,7 @@ requirements: - python>=3.9,<3.12 run_constrained: - catboost>=1.2.0, <2 - - lightgbm>=3.3.5,<5 + - lightgbm>=4.1.0, <5 - mlflow>=2.1.0,<2.4 - pytorch>=2.0.1,<2.3.0 - sentence-transformers>=2.2.2,<3 diff --git a/ci/targets/slow.txt b/ci/targets/slow.txt index e69de29b..41cdf2a7 100644 --- a/ci/targets/slow.txt +++ b/ci/targets/slow.txt @@ -0,0 +1,3 @@ +//tests/integ/snowflake/ml/registry/services:registry_huggingface_pipeline_model_deployment_gpu_test +//tests/integ/snowflake/ml/registry/services:registry_sentence_transformers_model_deployment_gpu_test +//tests/integ/snowflake/ml/registry/services:registry_xgboost_model_deployment_pip_test diff --git a/docs/source/cortex.rst b/docs/source/cortex.rst index 2c3470e0..2faa730f 100644 --- a/docs/source/cortex.rst +++ b/docs/source/cortex.rst @@ -10,16 +10,22 @@ snowflake.ml.cortex .. rubric:: Classes .. autosummary:: - :toctree: api/model + :toctree: api/cortex CompleteOptions + Finetune + FinetuneJob + FinetuneStatus .. rubric:: Functions .. autosummary:: - :toctree: api/model + :toctree: api/cortex + ClassifyText Complete + EmbedText768 + EmbedText1024 ExtractAnswer Sentiment Summarize diff --git a/docs/source/data.rst b/docs/source/data.rst new file mode 100644 index 00000000..4ab70ea7 --- /dev/null +++ b/docs/source/data.rst @@ -0,0 +1,21 @@ +:orphan: + +=========================== +snowflake.ml.data +=========================== + +.. automodule:: snowflake.ml.data + :noindex: + +.. currentmodule:: snowflake.ml.data + +.. rubric:: Classes + +.. autosummary:: + :toctree: api/data + + data_connector.DataConnector + data_ingestor.DataIngestor + data_source.DataSource + data_source.DataFrameInfo + data_source.DatasetInfo diff --git a/docs/source/distributors.rst b/docs/source/distributors.rst new file mode 100644 index 00000000..7b8a5e30 --- /dev/null +++ b/docs/source/distributors.rst @@ -0,0 +1,294 @@ +:orphan: + +.. # + + This file is temporary until the snowflake.ml.modeling.distributors subpackage makes it into snowflake-ml-python + +**************************** +Distributed Modeling Classes +**************************** + +When using `Container Runtime for ML `_ +in a `Snowflake Notebook `_, a set of distributed +modeling classes is available to train selected types of models on large datasets using the full resources of a +Snowpark Container Services (SPCS) `compute pool `_. + +The following model types are supported: + +- :ref:`XGBoost ` +- :ref:`LightGBM ` +- :ref:`PyTorch ` + +.. _label-distributors_xgboost: + +:code:`snowflake.ml.modeling.distributors.xgboost.XGBEstimator` +=============================================================== + +Xgboost Estimator that supports distributed training. + +Args: + + n_estimators (int): + Number of estimators. Default is 100. + + objective (str): + The objective function used for training. 'reg:squarederror'[Default] for regression, + 'binary:logistic' for binary classification, 'multi:softmax' for multi-class classification. + + params (Optional[dict]): + Additional parameters for the XGBoost Estimator. + + Some key parameters are: + * booster: Specify which booster to use: gbtree[Default], gblinear or dart. + * max_depth: Maximum depth of a tree. Default is 6. + * max_leaves: Maximum number of nodes to be added. Default is 0. 
+ * max_bin: Maximum number of bins that continuous feature values will be bucketed in. + Default is 256. + * eval_metric: Evaluation metrics for validation data. + + The full list of supported parameters can be found at https://xgboost.readthedocs.io/en/stable/parameter.html + If the params dict contains the keys 'n_estimators' or 'objective', they override the values provided + by the n_estimators and objective arguments. + + scaling_config (Optional[XGBScalingConfig]): + Scaling config for XGBoost Estimator. Defaults to None. If None, the estimator will use all available + resources. + +Related classes +--------------- + +``XGBScalingConfig(BaseScalingConfig)`` + + Scaling config for XGBoost Estimator. + + Attributes: + num_workers (int): + Number of workers to use for distributed training. Default is -1, meaning the estimator will + use all available workers. + + num_cpu_per_worker (int): + Number of CPU cores to use per worker. Default is -1, meaning the estimator will use + all available CPU cores. + + use_gpu (Optional[bool]): + Whether to use GPU for training. If None, the estimator will choose to use GPU or not + based on the environment. + +.. _label-distributors_lightgbm: + +:code:`snowflake.ml.modeling.distributors.lightgbm.LightGBMEstimator` +===================================================================== + +LightGBM Estimator for distributed training and inference. + +Args: + + n_estimators (int, optional): + Number of boosting iterations. Defaults to 100. + + objective (str, optional): + The learning task and corresponding objective. Defaults to "regression". + + "regression"[Default] for regression tasks, "binary" for binary classification, "multiclass" for + multi-class classification. + + params (Optional[Dict[str, Any]], optional): + + Additional parameters for LightGBM. Defaults to None. + + Some key params are: + + * boosting: The type of boosting to use. "gbdt"[Default] for Gradient Boosting Decision Tree, "dart" for + Dropouts meet Multiple Additive Regression Trees. + * num_leaves: The maximum number of leaves in one tree. Default is 31. + * max_depth: The maximum depth of the tree. Default is -1, which means no limit. + * early_stopping_rounds: Activates early stopping. The model will train until the validation score + stops improving. Default is 0, meaning no early stopping. + + The full list of supported parameters can be found at https://lightgbm.readthedocs.io/en/latest/Parameters.html. + + scaling_config (Optional[LightGBMScalingConfig], optional): + Configuration for scaling. Defaults to None. If None, the estimator will use all available resources. + +Related classes +--------------- + +``LightGBMScalingConfig(BaseScalingConfig)`` + + Scaling config for LightGBM Estimator. + + Attributes: + + num_workers (int): + The number of worker processes to use. Default is -1, which utilizes all available resources. + + num_cpu_per_worker (int): + Number of CPUs allocated per worker. Default is -1, which means all available resources. + + use_gpu (Optional[bool]): + Whether to use GPU for training. Default is None, allowing the estimator to choose + automatically based on the environment. + +.. _label-distributors_pytorch: + +:code:`snowflake.ml.modeling.distributors.pytorch.PyTorchDistributor` +===================================================================== + +Enables users to run distributed training with PyTorch on a Container Runtime cluster.
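A minimal construction sketch for the XGBoost and LightGBM estimators described above. It assumes the scaling-config classes are importable from the same modules as their estimators (not stated explicitly in this file); the keyword arguments mirror the Args and Attributes lists above, and no training call is shown because the fit interface is not documented here::

    from snowflake.ml.modeling.distributors.xgboost import XGBEstimator, XGBScalingConfig
    from snowflake.ml.modeling.distributors.lightgbm import LightGBMEstimator, LightGBMScalingConfig

    # XGBoost: -1 means "use all available workers / CPU cores".
    xgb = XGBEstimator(
        n_estimators=200,
        objective="binary:logistic",
        params={"max_depth": 8, "eval_metric": "auc"},
        scaling_config=XGBScalingConfig(num_workers=-1, num_cpu_per_worker=-1, use_gpu=None),
    )

    # LightGBM: passing scaling_config=None would likewise use all available resources.
    lgbm = LightGBMEstimator(
        n_estimators=100,
        objective="binary",
        params={"num_leaves": 31, "early_stopping_rounds": 10},
        scaling_config=LightGBMScalingConfig(num_workers=-1, num_cpu_per_worker=-1, use_gpu=False),
    )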
+ +PyTorchDistributor is responsible for setting up the environment, scheduling the training processes, +managing the communication between the processes, and collecting the results. + +Args: + + train_func (Callable): + A callable object that defines the training logic to be executed. + + scaling_config (PyTorchScalingConfig): + Configuration for scaling and other settings related to the training job. + +Related classes +--------------- + +``snowflake.ml.modeling.distributors.pytorch.PyTorchScalingConfig`` + + Scaling configuration for training PyTorch models. + + This class defines the scaling configuration for a PyTorch training job, + including the number of nodes, the number of workers per node, and the + resource requirements for each worker. + + Attributes: + + num_nodes (int): The number of nodes to use for training. + + num_workers_per_node (int): The number of workers to use per node. + + resource_requirements_per_worker (WorkerResourceConfig): The resource requirements + for each worker, such as the number of CPUs and GPUs. + +``snowflake.ml.modeling.distributors.pytorch.WorkerResourceConfig`` + + Resources requirements per worker. + + This class defines the resource requirements for each worker in a distributed + training job, specifying the number of CPU and GPU resources to allocate. + + Attributes: + + num_cpus (int): The number of CPU cores to reserve for each worker. + + num_gpus (int): The number of GPUs to reserve for each worker. + Default is 0, indicating no GPUs are reserved. + +``snowflake.ml.modeling.distributors.pytorch.Context`` + + Context for setting up the PyTorch distributed environment for training scripts. + + Context defines the necessary methods to manage and retrieve information + about the distributed training environment, including worker and node ranks, + world size, and backend configurations. + + Definitions: + + Node: A physical instance or a container. + + Worker: A worker process in the context of distributed training. + + WorkerGroup: The set of workers that execute the same function (e.g., trainers). + + LocalWorkerGroup: A subset of the workers in the worker group running on the same node. + + RANK: The rank of the worker within a worker group. + + WORLD_SIZE: The total number of workers in a worker group. + + LOCAL_RANK: The rank of the worker within a local worker group. + + LOCAL_WORLD_SIZE: The size of the local worker group. + rdzv_id: An ID that uniquely identifies the worker group for a job. This ID is used by each node to join as + a member of a particular worker group. + + rdzv_backend: The backend of the rendezvous (e.g., c10d). This is typically a strongly consistent + key-value store. + + rdzv_endpoint: The rendezvous backend endpoint; usually in the form :. + + Methods: + + ``get_world_size(self) -> int`` + Return the number of workers (or processes) participating in the job. + + For example, if training is running on 2 nodes (servers) each with 4 GPUs, + then the world size is 8 (2 nodes * 4 GPUs per node). Usually, each GPU corresponds + to a training process. + + ``get_rank(self) -> int`` + Return the rank of the current process across all processes. + + Rank is the unique ID given to a process to identify it uniquely across the world. + It should be a number between 0 and world_size - 1. + + Some frameworks also call it world_rank, to distinguish it from local_rank. + For example, if training is running on 2 nodes (servers) each with 4 GPUs, + then the ranks will be [0, 1, 2, 3, 4, 5, 6, 7], i.e., from 0 to world_size - 1. 
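For concreteness, a short sketch of how a training function passed to PyTorchDistributor might read these values through ``get_context`` (only ``get_context``, ``get_rank`` and ``get_world_size`` are taken from this reference; the rest is illustrative)::

    from snowflake.ml.modeling.distributors.pytorch import get_context

    def train_func() -> None:
        ctx = get_context()
        rank = ctx.get_rank()              # unique ID in [0, world_size - 1]
        world_size = ctx.get_world_size()  # e.g. 8 for 2 nodes with 4 GPUs each
        print(f"worker {rank} of {world_size} starting")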
+ + ``get_local_rank(self) -> int`` + Return the local rank for the current worker. + + Local rank is a unique local ID for a worker (or process) running on the current node. + + For example, if training is running on 2 nodes (servers) each with 4 GPUs, then + local rank for workers(or processes) running on node 0 will be [0, 1, 2, 3] and + similarly four workers(or processes) running on node 1 will have local_rank [0, 1, 2, 3]. + + ``get_local_world_size(self) -> int`` + Return the number of workers running in the current node. + + For example, if training is running on 2 nodes (servers) each with 4 GPUs, + then local_world_size will be 4 for all processes on both nodes. + + ``get_node_rank(self)`` + Return the rank of the current node across all nodes. + + Node rank is a unique ID given to each node to identify it uniquely across all nodes + in the world. + + For example, if training is running on 2 nodes (servers) each with 4 GPUs, + then node ranks will be [0, 1] respectively. + + ``get_master_addr(self) -> str`` + Return IP address of the master node. + + This is typically the address of the node with node_rank 0. + + ``def get_master_port(self) -> int`` + Return port on master_addr that hosts the rendezvous server. + + ``get_default_backend(self) -> str`` + Return default backend selected by MCE. + + ``get_supported_backends(self) -> List[str]`` + Return list of supported backends by MCE. + + ``get_hyper_params(self) -> Optional[Dict[str, str]]`` + Return hyperparameter map provided to trainer.run(...) method. + + ``get_dataset_map(self) -> Optional[Dict[str, Type[DataConnector]]]`` + Return dataset map provided to trainer.run(...) method. + +Related functions +----------------- + +``snowflake.ml.modeling.distributors.pytorch.get_context`` + + Fetches the context object that contains the worker specific runtime information. + + Returns: + + Context: An instance of the Context interface that provides methods for + managing the distributed training environment. + + Raises: + + RuntimeError: If the PyTorch context is not available. diff --git a/docs/source/index.rst b/docs/source/index.rst index dccd7d44..fc3bf458 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -26,6 +26,7 @@ Table of Contents :maxdepth: 3 cortex + data dataset feature_store fileset diff --git a/docs/source/model.rst b/docs/source/model.rst index bb0b5e8a..698657dc 100644 --- a/docs/source/model.rst +++ b/docs/source/model.rst @@ -32,7 +32,7 @@ snowflake.ml.model.custom_model CustomModel snowflake.ml.model.model_signature ---------------------------------- +---------------------------------- .. 
currentmodule:: snowflake.ml.model.model_signature diff --git a/packages.bzl b/packages.bzl index fefa8d60..301aa6a2 100644 --- a/packages.bzl +++ b/packages.bzl @@ -13,6 +13,8 @@ PACKAGES = [ "//snowflake/ml/utils:utils_pkg", "//snowflake/ml/fileset:fileset_pkg", "//snowflake/ml/registry:model_registry_pkg", + "//snowflake/ml/data:data_pkg", + "//snowflake/ml/dataset:dataset_pkg", # Auotgen packages "//snowflake/ml/modeling/linear_model:sklearn_linear_model_pkg", "//snowflake/ml/modeling/ensemble:sklearn_ensemble_pkg", diff --git a/requirements.txt b/requirements.txt index e8309a4d..452f792a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ importlib_resources==6.1.1 inflection==0.5.1 joblib==1.4.2 jsonschema==3.2.0 -lightgbm==3.3.5 +lightgbm==4.1.0 mlflow==2.3.1 moto==4.0.11 mypy==1.10.0 @@ -63,4 +63,4 @@ types-requests==2.30.0.0 types-toml==0.10.8.6 typing-extensions==4.6.3 werkzeug==2.2.2 -xgboost==1.7.3 +xgboost==1.7.6 diff --git a/requirements.yml b/requirements.yml index 91797688..6b335ce6 100644 --- a/requirements.yml +++ b/requirements.yml @@ -63,8 +63,6 @@ # `tags`: Set tags to filter some of the requirements in specific cases. The current valid tags include: # - `model_packaging`: Used by model packaging and deployment to indicate the core requirements to save and load the # model. -# - `snowml_inference_alternative`: Used by model packaging and deployment to indicate a subset of requirements to run -# inference as alternative of installing all dependencies of snowflake-ml-python. # - `build_essential`: Used to indicate the packages composing the build environment. # - `build_test_env`: Used to indicate the package is required in build and test environment to run the tests. @@ -73,7 +71,6 @@ version_requirements: '>=0.15,<2' tags: - build_essential - - snowml_inference_alternative # For fsspec[http] in conda - name_conda: aiohttp dev_version_conda: 3.8.3 @@ -81,8 +78,6 @@ - name: anyio dev_version: 3.5.0 version_requirements: '>=3.5.0,<4' - tags: - - snowml_inference_alternative - name: build dev_version: 0.10.0 tags: @@ -105,6 +100,7 @@ - model_packaging - name: cryptography dev_version: 39.0.1 + version_requirements: '' # Skipping version requirements as it should come as part of connector. # Only used in connection_params.py, which is an util library anyways. 
- name: coverage @@ -134,8 +130,8 @@ - name: joblib dev_version: 1.4.2 - name: lightgbm - dev_version: 3.3.5 - version_requirements: '>=3.3.5,<5' + dev_version: 4.1.0 + version_requirements: '>=4.1.0, <5' requirements_extra_tags: - lightgbm tags: @@ -161,18 +157,14 @@ version_requirements: '>=1.23,<2' tags: - build_essential - - snowml_inference_alternative - name: packaging dev_version: '23.0' - version_requirements: '>=20.9,<24' + version_requirements: '>=20.9,<25' tags: - build_essential - - snowml_inference_alternative - name: pandas dev_version: 1.5.3 version_requirements: '>=1.0.0,<3' - tags: - - snowml_inference_alternative - name: protobuf dev_version: 3.20.3 - name: psutil @@ -196,15 +188,14 @@ - build_test_env - name_pypi: torch name_conda: pytorch - dev_version: 2.0.1 + dev_version_conda: 2.1.0 + dev_version_pypi: 2.0.1 version_requirements: '>=2.0.1,<2.3.0' requirements_extra_tags: - torch - name: pyyaml dev_version: '6.0' version_requirements: '>=6.0,<7' - tags: - - snowml_inference_alternative - name: retrying dev_version: 1.3.3 version_requirements: '>=1.3.3,<2' @@ -244,8 +235,6 @@ - name: snowflake-snowpark-python dev_version: 1.17.0 version_requirements: '>=1.17.0,<2' - tags: - - snowml_inference_alternative - name: sphinx dev_version: 5.0.2 tags: @@ -293,10 +282,8 @@ - name: typing-extensions dev_version: 4.6.3 version_requirements: '>=4.1.0,<5' - tags: - - snowml_inference_alternative - name: xgboost - dev_version: 1.7.3 + dev_version: 1.7.6 version_requirements: '>=1.7.3,<3' tags: - build_essential diff --git a/snowflake/cortex/BUILD.bazel b/snowflake/cortex/BUILD.bazel index dd6039f0..380e9f83 100644 --- a/snowflake/cortex/BUILD.bazel +++ b/snowflake/cortex/BUILD.bazel @@ -8,15 +8,15 @@ package_group( ], ) -package(default_visibility = [ - ":cortex", - "//bazel:snowml_public_common", -]) +package(default_visibility = ["//visibility:public"]) py_library( name = "util", srcs = ["_util.py"], - deps = [":sse_client"], + deps = [ + ":sse_client", + "//snowflake/ml/_internal/utils:formatting", + ], ) py_library( @@ -191,6 +191,27 @@ py_test( ], ) +py_library( + name = "finetune", + srcs = ["_finetune.py"], + deps = [ + "//snowflake/cortex:util", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", + ], +) + +py_test( + name = "finetune_test", + srcs = ["finetune_test.py"], + deps = [ + "//snowflake/cortex:cortex_pkg", + "//snowflake/cortex:test_util", + "//snowflake/ml/test_utils:mock_session", + "//snowflake/ml/utils:connection_params", + ], +) + py_library( name = "init", srcs = [ @@ -202,6 +223,7 @@ py_library( ":embed_text_1024", ":embed_text_768", ":extract_answer", + ":finetune", ":sentiment", ":summarize", ":translate", diff --git a/snowflake/cortex/__init__.py b/snowflake/cortex/__init__.py index 92ab345f..00957943 100644 --- a/snowflake/cortex/__init__.py +++ b/snowflake/cortex/__init__.py @@ -3,6 +3,7 @@ from snowflake.cortex._embed_text_768 import EmbedText768 from snowflake.cortex._embed_text_1024 import EmbedText1024 from snowflake.cortex._extract_answer import ExtractAnswer +from snowflake.cortex._finetune import Finetune, FinetuneJob, FinetuneStatus from snowflake.cortex._sentiment import Sentiment from snowflake.cortex._summarize import Summarize from snowflake.cortex._translate import Translate @@ -14,6 +15,9 @@ "EmbedText768", "EmbedText1024", "ExtractAnswer", + "Finetune", + "FinetuneJob", + "FinetuneStatus", "Sentiment", "Summarize", "Translate", diff --git a/snowflake/cortex/_complete.py 
b/snowflake/cortex/_complete.py index bc9f6d82..e8d7095f 100644 --- a/snowflake/cortex/_complete.py +++ b/snowflake/cortex/_complete.py @@ -1,7 +1,8 @@ import json import logging import time -from typing import Any, Callable, Iterator, List, Optional, TypedDict, Union, cast +from io import BytesIO +from typing import Any, Callable, Dict, Iterator, List, Optional, TypedDict, Union, cast from urllib.parse import urlunparse import requests @@ -16,8 +17,10 @@ ) from snowflake.ml._internal import telemetry from snowflake.snowpark import context, functions +from snowflake.snowpark._internal.utils import is_in_stored_procedure logger = logging.getLogger(__name__) +_REST_COMPLETE_URL = "/api/v2/cortex/inference:complete" class ConversationMessage(TypedDict): @@ -84,6 +87,76 @@ def inner(*args: Any, **kwargs: Any) -> requests.Response: return inner +def _make_common_request_headers() -> Dict[str, str]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + return headers + + +def _make_request_body( + model: str, + prompt: Union[str, List[ConversationMessage]], + options: Optional[CompleteOptions] = None, +) -> Dict[str, Any]: + data = { + "model": model, + "stream": True, + } + if isinstance(prompt, List): + data["messages"] = prompt + else: + data["messages"] = [{"content": prompt}] + + if options: + if "max_tokens" in options: + data["max_tokens"] = options["max_tokens"] + data["max_output_tokens"] = options["max_tokens"] + if "temperature" in options: + data["temperature"] = options["temperature"] + if "top_p" in options: + data["top_p"] = options["top_p"] + return data + + +# XP endpoint returns a dict response which needs to be converted to a format which can +# be consumed by the SSEClient. This method does that. 
+def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response: + response = requests.Response() + response.status_code = int(raw_resp["status"]) + response.headers = raw_resp["headers"] + + data = raw_resp["content"] + data = json.loads(data) + # Convert the dictionary to a string format that resembles the SSE event format + # For example, if the dict is {'event': 'message', 'data': 'your data'}, it should be formatted like this: + sse_format_data = "" + for event in data: + event_type = event.get("event", "message") + event_data = event.get("data", "") + event_data = json.dumps(event_data) + sse_format_data += f"event: {event_type}\ndata: {event_data}\n\n" # Add each event with new lines + + response.raw = BytesIO(sse_format_data.encode("utf-8")) + return response + + +@retry +def _call_complete_xp( + model: str, + prompt: Union[str, List[ConversationMessage]], + options: Optional[CompleteOptions] = None, + deadline: Optional[float] = None, +) -> requests.Response: + headers = _make_common_request_headers() + body = _make_request_body(model, prompt, options) + import _snowflake + + raw_resp = _snowflake.send_snow_api_request("POST", _REST_COMPLETE_URL, {}, headers, body, {}, deadline) + return _xp_dict_to_response(raw_resp) + + @retry def _call_complete_rest( model: str, @@ -110,36 +183,16 @@ def _call_complete_rest( scheme = "https" if hasattr(session.connection, "scheme"): scheme = session.connection.scheme - url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference:complete", "", "", "")) + url = urlunparse((scheme, session.connection.host, _REST_COMPLETE_URL, "", "", "")) - headers = { - "Content-Type": "application/json", - "Authorization": f'Snowflake Token="{session.connection.rest.token}"', - "Accept": "application/json, text/event-stream", - } - - data = { - "model": model, - "stream": True, - } - if isinstance(prompt, List): - data["messages"] = prompt - else: - data["messages"] = [{"content": prompt}] - - if options: - if "max_tokens" in options: - data["max_tokens"] = options["max_tokens"] - data["max_output_tokens"] = options["max_tokens"] - if "temperature" in options: - data["temperature"] = options["temperature"] - if "top_p" in options: - data["top_p"] = options["top_p"] + headers = _make_common_request_headers() + headers["Authorization"] = f'Snowflake Token="{session.connection.rest.token}"' + body = _make_request_body(model, prompt, options) logger.debug(f"making POST request to {url} (model={model})") return requests.post( url, - json=data, + json=body, headers=headers, stream=True, ) @@ -164,49 +217,24 @@ def _complete_call_sql_function_snowpark( return cast(snowpark.Column, functions.builtin(function)(*args)) -def _complete_call_sql_function_immediate( - function: str, +def _complete_non_streaming_immediate( model: str, prompt: Union[str, List[ConversationMessage]], options: Optional[CompleteOptions], - session: Optional[snowpark.Session], + session: Optional[snowpark.Session] = None, + deadline: Optional[float] = None, ) -> str: - session = session or context.get_active_session() - if session is None: - raise SnowflakeAuthenticationException( - """Session required. Provide the session through a session=... 
argument or ensure an active session is - available in your environment.""" - ) + response = _complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline) + return "".join(response) - # https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex - if options is not None or not isinstance(prompt, str): - if isinstance(prompt, List): - prompt_arg = prompt - else: - prompt_arg = [{"role": "user", "content": prompt}] - options = options or {} - lit_args = [ - functions.lit(model), - functions.lit(prompt_arg), - functions.lit(options), - ] - else: - lit_args = [ - functions.lit(model), - functions.lit(prompt), - ] - - empty_df = session.create_dataframe([snowpark.Row()]) - df = empty_df.select(functions.builtin(function)(*lit_args)) - return cast(str, df.collect()[0][0]) - -def _complete_sql_impl( +def _complete_non_streaming_impl( function: str, model: Union[str, snowpark.Column], prompt: Union[str, List[ConversationMessage], snowpark.Column], options: Optional[Union[CompleteOptions, snowpark.Column]], - session: Optional[snowpark.Session], + session: Optional[snowpark.Session] = None, + deadline: Optional[float] = None, ) -> Union[str, snowpark.Column]: if isinstance(prompt, snowpark.Column): if options is not None: @@ -217,7 +245,24 @@ def _complete_sql_impl( raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.") if isinstance(options, snowpark.Column): raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.") - return _complete_call_sql_function_immediate(function, model, prompt, options, session) + return _complete_non_streaming_immediate( + model=model, prompt=prompt, options=options, session=session, deadline=deadline + ) + + +def _complete_rest( + model: str, + prompt: Union[str, List[ConversationMessage]], + options: Optional[CompleteOptions] = None, + session: Optional[snowpark.Session] = None, + deadline: Optional[float] = None, +) -> Iterator[str]: + if is_in_stored_procedure(): # type: ignore[no-untyped-call] + response = _call_complete_xp(model=model, prompt=prompt, options=options, deadline=deadline) + else: + response = _call_complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline) + assert response.status_code >= 200 and response.status_code < 300 + return _return_stream_response(response, deadline) def _complete_impl( @@ -239,10 +284,8 @@ def _complete_impl( raise ValueError("in REST mode, 'model' must be a string") if not isinstance(prompt, str) and not isinstance(prompt, List): raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage") - response = _call_complete_rest(model, prompt, options, session=session, deadline=deadline) - assert response.status_code >= 200 and response.status_code < 300 - return _return_stream_response(response, deadline) - return _complete_sql_impl(function, model, prompt, options, session) + return _complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline) + return _complete_non_streaming_impl(function, model, prompt, options, session, deadline) @telemetry.send_api_usage_telemetry( diff --git a/snowflake/cortex/_finetune.py b/snowflake/cortex/_finetune.py new file mode 100644 index 00000000..d1ba21a8 --- /dev/null +++ b/snowflake/cortex/_finetune.py @@ -0,0 +1,273 @@ +import json +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union, cast + +from snowflake import snowpark +from snowflake.cortex._util 
import ( + CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + call_sql_function_literals, +) +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import snowpark_dataframe_utils + +_CORTEX_FINETUNE_SYSTEM_FUNCTION_NAME = "SNOWFLAKE.CORTEX.FINETUNE" +CORTEX_FINETUNE_TELEMETRY_SUBPROJECT = "FINETUNE" +CORTEX_FINETUNE_FIRST_VERSION = "1.7.0" +CORTEX_FINETUNE_DOCUMENTATION_URL = "https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-finetuning" + + +class FinetuneError(Exception): + def __init__(self, message: str, original_exception: Optional[Exception] = None) -> None: + """Finetuning Exception Class. + + Args: + message: Error message to be reported. + original_exception: Original exception. This is the exception raised to users by telemetry. + + Attributes: + original_exception: Original exception with an error code in its message. + """ + self.original_exception = original_exception + self._pretty_msg = message + repr(self.original_exception) if self.original_exception is not None else "" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._pretty_msg!r})" + + def __str__(self) -> str: + return self._pretty_msg + + +@dataclass +class FinetuneStatus: + """Fine-tuning job status.""" + + id: Optional[str] = None + """Workflow ID for the fine-tuning run.""" + + status: Optional[str] = None + """Status string, e.g. PENDING, RUNNING, SUCCESS, ERROR, CANCELLED.""" + + base_model: Optional[str] = None + """Name of the base model that is being fine-tuned.""" + + created_on: Optional[int] = None + """Creation timestamp of the Fine-tuning job in milliseconds.""" + + error: Optional[Dict[str, Any]] = None + """Error message propagated from the job.""" + + finished_on: Optional[int] = None + """Completion timestamp of the Fine-tuning job in milliseconds.""" + + progress: Optional[float] = None + """Progress made as a fraction of total [0.0,1.0].""" + + training_result: Optional[List[Dict[str, Any]]] = None + """Detailed metrics report for a completed training.""" + + trained_tokens: Optional[int] = None + """Number of tokens trained on. If multiple epochs are run, this can be larger than number of tokens in the + training data.""" + + training_data: Optional[str] = None + """Training data query.""" + + validation_data: Optional[str] = None + """Validation data query.""" + + model: Optional[str] = None + """Location of the fine-tuned model.""" + + +class FinetuneJob: + def __init__(self, session: Optional[snowpark.Session], status: FinetuneStatus) -> None: + """Fine-tuning Job. + + Args: + session: Snowpark session to use to communicate with Snowflake. + status: FinetuneStatus for this job. + """ + self._session = session + self.status = status + + def __repr__(self) -> str: + return self.status.__repr__() + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FinetuneJob): + raise NotImplementedError( + f"Equality comparison of FinetuneJob with objects of type {type(other)} is not implemented." + ) + return self.status == other.status + + @snowpark._internal.utils.experimental(version=CORTEX_FINETUNE_FIRST_VERSION) + @telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT, + ) + def cancel(self) -> bool: + """Cancel a fine-tuning run. + + No confirmation will be required. + + [Documentation](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-finetuning) + + Args: + + Returns: + True if the cancellation was successful, False otherwise. 
+ """ + result = _finetune_impl(operation="CANCEL", session=self._session, function_args=[self.status.id]) + return result is not None and isinstance(result, str) and result.startswith("Canceled Cortex Fine-tuning job: ") + + @snowpark._internal.utils.experimental(version=CORTEX_FINETUNE_FIRST_VERSION) + @telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT, + ) + def describe(self) -> FinetuneStatus: + """Describe a fine-tuning run. + + Args: + + Returns: + FinetuneStatus containing of attributes of the fine-tuning run. + """ + result_string = _finetune_impl(operation="DESCRIBE", session=self._session, function_args=[self.status.id]) + + result = FinetuneStatus(**cast(Dict[str, Any], _try_load_json(result_string))) + return result + + +class Finetune: + @snowpark._internal.utils.experimental(version=CORTEX_FINETUNE_FIRST_VERSION) + @telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT, + ) + def __init__(self, session: Optional[snowpark.Session] = None) -> None: + """Cortex Fine-Tuning API. + + [Documentation](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-finetuning) + + Args: + session: Snowpark session to be used. If none is given, we will attempt to + use the currently active session. + """ + self._session = session + + @snowpark._internal.utils.experimental(version=CORTEX_FINETUNE_FIRST_VERSION) + @telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT, + ) + def create( + self, + name: str, + base_model: str, + training_data: Union[str, snowpark.DataFrame], + validation_data: Optional[Union[str, snowpark.DataFrame]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> FinetuneJob: + """Create a new fine-tuning runs. + + The expected format of training and validation data is two fields or columns: + "prompt": the input to the model and + "completion": the output that the model is expected to generate. + + Both data parameters "training_data" and "validation_data" expect to be one of + (1) stage path to JSONL-formatted data, + (2) select-query string resulting in a table, + (3) Snowpark DataFrame containing the data + + [Documentation](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-finetuning) + + Args: + name: Name of the resulting fine-tuned model. + base_model: The name of the base model to start fine-tuning from. + training_data: Data used for fine-tuning the model. + validation_data: Data used for validating the fine-tuned model (not used in training) + options: Dictionary of additional options to be passed to the training procedure. + Please refer to the official documentation for a list of available options. + + Returns: + The identifier of the fine-tuning run. + + Raises: + ValueError: If the Snowpark DataFrame used is incompatible with this API. + This can happen if the DataFrame contains multiple queries. + """ + + # Handle data provided as snowpark dataframes + if isinstance(training_data, snowpark.DataFrame): + if snowpark_dataframe_utils.is_single_query_snowpark_dataframe(training_data): + training_string = str(training_data.queries["queries"][0]) + else: + raise ValueError( + "Snowpark DataFrame given in 'training_data' contains " + + f'{training_data.queries["queries"]} queries and ' + + f'{training_data.queries["post_actions"]} post_actions. 
It needs ' + "to contain exactly one query and no post_actions." + ) + else: + training_string = training_data + + validation_string: Optional[str] = None + if isinstance(validation_data, snowpark.DataFrame): + if snowpark_dataframe_utils.is_single_query_snowpark_dataframe(validation_data): + validation_string = str(validation_data.queries["queries"][0]) + else: + raise ValueError( + "Snowpark DataFrame given in 'validation_data' contains " + + f'{len(validation_data.queries["queries"])} queries and ' + + f'{len(validation_data.queries["post_actions"])} post_actions. It needs ' + "to contain exactly one query and no post_actions." + ) + else: + validation_string = validation_data + + result = _finetune_impl( + operation="CREATE", + session=self._session, + function_args=[name, base_model, training_string, validation_string, options], + ) + finetune_status = FinetuneStatus(id=result) + finetune_run = FinetuneJob(self._session, finetune_status) + return finetune_run + + @snowpark._internal.utils.experimental(version=CORTEX_FINETUNE_FIRST_VERSION) + @telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT, + ) + def list_jobs(self) -> List["FinetuneJob"]: + """Show current and past fine-tuning runs. + + Returns: + List of FinetuneJob objects for current and past fine-tuning runs. Please refer to the official + documentation for a list of expected status fields. + """ + result_string = _finetune_impl(operation="SHOW", session=self._session, function_args=[]) + result = _try_load_json(result_string) + + return [FinetuneJob(session=self._session, status=FinetuneStatus(**run_status)) for run_status in result] + + +def _try_load_json(json_string: str) -> Union[Dict[Any, Any], List[Any]]: + try: + result = json.loads(str(json_string)) + except json.JSONDecodeError as e: + message = f"""Unable to parse JSON from: "{json_string}". """ + raise FinetuneError(message=message, original_exception=e) + except Exception as e: + message = f"""Unable to parse JSON from: "{json_string}". """ + raise FinetuneError(message=message, original_exception=e) + else: + if not isinstance(result, dict) and not isinstance(result, list): + message = f"""Unable to parse JSON from: "{json_string}". 
Result was not a dictionary or a list.""" + raise FinetuneError(message=message) + return result + + +def _finetune_impl(operation: str, session: Optional[snowpark.Session], function_args: List[Any]) -> str: + return call_sql_function_literals(_CORTEX_FINETUNE_SYSTEM_FUNCTION_NAME, session, operation, *function_args) diff --git a/snowflake/cortex/_sse_client.py b/snowflake/cortex/_sse_client.py index a6fcf202..cd697d32 100644 --- a/snowflake/cortex/_sse_client.py +++ b/snowflake/cortex/_sse_client.py @@ -1,73 +1,125 @@ -from typing import Iterator, cast +import json +from typing import Any, Iterator, Optional -import requests +_FIELD_SEPARATOR = ":" class Event: - def __init__(self, event: str = "message", data: str = "") -> None: + """Representation of an event from the event stream.""" + + def __init__( + self, + id: Optional[str] = None, + event: str = "message", + data: str = "", + comment: Optional[str] = None, + retry: Optional[int] = None, + ) -> None: + self.id = id self.event = event self.data = data + self.comment = comment + self.retry = retry def __str__(self) -> str: s = f"{self.event} event" + if self.id: + s += f" #{self.id}" if self.data: - s += f", {len(self.data)} bytes" + s += ", {} byte{}".format(len(self.data), "s" if len(self.data) else "") else: s += ", no data" + if self.comment: + s += f", comment: {self.comment}" + if self.retry: + s += f", retry in {self.retry}ms" return s +# This is copied from the snowpy library: +# https://github.com/snowflakedb/snowpy/blob/main/libs/snowflake.core/src/snowflake/core/rest.py#L39 +# TODO(SNOW-1750723) - Currently there is code duplication across snowflake-ml-python +# and the snowpy library for the Cortex REST API, which was done to meet our GA timelines. +# Once snowpy has a release with https://github.com/snowflakedb/snowpy/pull/679, we should +# remove the class here and refer to the snowflake.core package directly class SSEClient: - def __init__(self, response: requests.Response) -> None: + def __init__(self, event_source: Any, char_enc: str = "utf-8") -> None: + self._event_source = event_source + self._char_enc = char_enc - self.response = response - - def _read(self) -> Iterator[str]: - - lines = b"" - for chunk in self.response: + def _read(self) -> Iterator[bytes]: + data = b"" + for chunk in self._event_source: for line in chunk.splitlines(True): - lines += line - if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): - yield cast(str, lines) - lines = b"" - if lines: - yield cast(str, lines) + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data def events(self) -> Iterator[Event]: - for raw_event in self._read(): + content_type = self._event_source.headers.get("Content-Type") + # The check for empty content-type is present because it's being populated after + # the change in https://github.com/snowflakedb/snowflake/pull/217654. + # This can be removed once the above change makes it to prod or we move to snowpy + # for SSEClient implementation. + if content_type == "text/event-stream" or not content_type: + return self._handle_sse() + elif content_type == "application/json": + return self._handle_json() + else: + raise ValueError(f"Unknown Content-Type: {content_type}") + + def _handle_sse(self) -> Iterator[Event]: + for chunk in self._read(): event = Event() - # splitlines() only uses \r and \n - for line in raw_event.splitlines(): + # Split before decoding so splitlines() only uses \r and \n + for line_bytes in chunk.splitlines(): + # Decode the line. 
+ line = line_bytes.decode(self._char_enc) - line = cast(bytes, line).decode("utf-8") + # Lines starting with a separator are comments and are to be + # ignored. + if not line.strip() or line.startswith(_FIELD_SEPARATOR): + continue - data = line.split(":", 1) + data = line.split(_FIELD_SEPARATOR, 1) field = data[0] + # Ignore unknown fields. + if not hasattr(event, field): + continue + if len(data) > 1: + # From the spec: # "If value starts with a single U+0020 SPACE character, - # remove it from value. .strip() would remove all white spaces" + # remove it from value." if data[1].startswith(" "): value = data[1][1:] else: value = data[1] else: + # If no value is present after the separator, + # assume an empty value. value = "" # The data field may come over multiple lines and their values # are concatenated with each other. + current_value = getattr(event, field, "") if field == "data": - event.data += value + "\n" - elif field == "event": - event.event = value + new_value = current_value + value + "\n" + else: + new_value = value + setattr(event, field, new_value) + # Events with no data are not dispatched. if not event.data: continue # If the data field ends with a newline, remove it. if event.data.endswith("\n"): - event.data = event.data[0:-1] # Replace trailing newline - rstrip would remove multiple. + event.data = event.data[0:-1] # Empty event names default to 'message' event.event = event.event or "message" @@ -77,5 +129,16 @@ def events(self) -> Iterator[Event]: yield event + def _handle_json(self) -> Iterator[Event]: + data_list = json.loads(self._event_source.data.decode(self._char_enc)) + for data in data_list: + yield Event( + id=data.get("id"), + event=data.get("event"), + data=data.get("data"), + comment=data.get("comment"), + retry=data.get("retry"), + ) + def close(self) -> None: - self.response.close() + self._event_source.close() diff --git a/snowflake/cortex/_util.py b/snowflake/cortex/_util.py index bdafe23f..ceba07d0 100644 --- a/snowflake/cortex/_util.py +++ b/snowflake/cortex/_util.py @@ -1,6 +1,8 @@ -from typing import Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union, cast from snowflake import snowpark +from snowflake.ml._internal.exceptions import error_codes, exceptions +from snowflake.ml._internal.utils import formatting from snowflake.snowpark import context, functions CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions" @@ -64,3 +66,30 @@ def _call_sql_function_immediate( empty_df = session.create_dataframe([snowpark.Row()]) df = empty_df.select(functions.builtin(function)(*lit_args)) return cast(str, df.collect()[0][0]) + + +def call_sql_function_literals(function: str, session: Optional[snowpark.Session], *args: Any) -> str: + r"""Call a SQL function with only literal arguments. + + This is useful for calling system functions. + + Args: + function: The name of the function to be called. + session: The Snowpark session to use. + *args: The list of arguments + + Returns: + String value that corresponds the the first cell in the dataframe. + + Raises: + SnowflakeMLException: If no session is given and no active session exists. 
+ """ + if session is None: + session = context.get_active_session() + if session is None: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_SNOWPARK_SESSION, + ) + + function_arguments = ",".join(["NULL" if arg is None else formatting.format_value_for_select(arg) for arg in args]) + return cast(str, session.sql(f"SELECT {function}({function_arguments})").collect()[0][0]) diff --git a/snowflake/cortex/complete_test.py b/snowflake/cortex/complete_test.py index 84d40430..14543e14 100644 --- a/snowflake/cortex/complete_test.py +++ b/snowflake/cortex/complete_test.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from io import BytesIO from types import GeneratorType -from typing import Dict, Iterable, Iterator, List, cast +from typing import Dict, Iterable, Iterator, cast import _test_util from absl.testing import absltest @@ -15,10 +15,9 @@ from snowflake import snowpark from snowflake.cortex import _complete -from snowflake.cortex._complete import CompleteOptions, ConversationMessage from snowflake.snowpark import functions, types -_OPTIONS = CompleteOptions( # random params +_OPTIONS = _complete.CompleteOptions( # random params max_tokens=10, temperature=0.7, top_p=1, @@ -58,8 +57,10 @@ class FakeSession: class FakeResponse: # needed for testing, imitates some of requests.Response behaviors - def __init__(self, content: bytes) -> None: + def __init__(self, content: bytes, headers: Dict[str, str], data: bytes) -> None: self.content = BytesIO(content) + self.headers = headers + self.data = data def iter_content(self, chunk_size: int = 1) -> Iterator[bytes]: while True: @@ -74,6 +75,9 @@ def __iter__(self) -> Iterator[bytes]: class CompleteSQLBackendTest(absltest.TestCase): model = "|model|" + custom_model_stage = "@my.custom.model/stage" + custom_model_entity = "my.custom.model_entity" + all_models = [model, custom_model_stage, custom_model_entity] prompt = "|prompt|" @staticmethod @@ -98,74 +102,18 @@ def tearDown(self) -> None: self._session.sql("drop function complete(string,string)").collect() self._session.close() - def test_complete_sql_mode(self) -> None: - res = _complete._complete_impl(self.model, self.prompt, session=self._session, function="complete") - self.assertEqual(self.complete_for_test(self.model, self.prompt), res) - def test_complete_snowpark_mode(self) -> None: - df_in = self._session.create_dataframe([snowpark.Row(model=self.model, prompt=self.prompt)]) + """Test complete call with a single dataframe argument with columns for model + and prompt.""" + df_in = self._session.create_dataframe( + [snowpark.Row(model=model, prompt=self.prompt) for model in self.all_models] + ) df_out = df_in.select( _complete._complete_impl(functions.col("model"), functions.col("prompt"), function="complete") ) - res = df_out.collect()[0][0] - self.assertEqual(self.complete_for_test(self.model, self.prompt), res) - - -class CompleteOptionsSQLBackendTest(absltest.TestCase): - model = "|model|" - - @staticmethod - def format_as_complete(model: str, prompt: List[ConversationMessage], options: CompleteOptions) -> str: - prompt_str = "" - for d in prompt: - prompt_str += f"({d['role']}, {d['content']}), " - return f"model: {model}, prompt: {prompt_str}, options: {options}" - - @staticmethod - def complete_for_test(model: str, prompt: List[Dict[str, str]], options: Dict[str, float]) -> str: - prompt_str = "" - for d in prompt: - prompt_str += f"({d['role']}, {d['content']}), " - return f"model: {model}, prompt: {prompt_str}, options: {options}" - - def setUp(self) -> 
None: - self._session = _test_util.create_test_session() - functions.udf( - self.complete_for_test, - name="complete", - return_type=types.StringType(), - input_types=[types.StringType(), types.ArrayType(), types.MapType()], - session=self._session, - is_permanent=False, - ) - - def tearDown(self) -> None: - self._session.sql("drop function complete(string,array,object)").collect() - self._session.close() - - def test_conversation_history_immediate_mode(self) -> None: - conversation_history_prompt = [ - ConversationMessage({"role": "system", "content": "content for system"}), - ConversationMessage({"role": "user", "content": "content for user"}), - ] - res = _complete._complete_impl( - self.model, conversation_history_prompt, session=self._session, function="complete" - ) - self.assertEqual(self.format_as_complete(self.model, conversation_history_prompt, {}), res) - - def test_populated_options(self) -> None: - prompt = "|prompt|" - equivalent_prompt_for_sql = [ConversationMessage({"role": "user", "content": "|prompt|"})] - res = _complete._complete_impl(self.model, prompt, options=_OPTIONS, session=self._session, function="complete") - self.assertEqual(self.format_as_complete(self.model, equivalent_prompt_for_sql, _OPTIONS), res) - - def test_empty_options(self) -> None: - prompt = "|prompt|" - equivalent_prompt_for_sql = [ConversationMessage({"role": "user", "content": "|prompt|"})] - res = _complete._complete_impl( - self.model, prompt, options=_complete.CompleteOptions(), session=self._session, function="complete" - ) - self.assertEqual(self.format_as_complete(self.model, equivalent_prompt_for_sql, CompleteOptions()), res) + for row_index in range(len(self.all_models)): + res = df_out.collect()[row_index][0] + self.assertEqual(self.complete_for_test(self.all_models[row_index], self.prompt), res) class MockIpifyHTTPRequestHandler(http.server.BaseHTTPRequestHandler): @@ -328,6 +276,16 @@ def test_streaming_timeout(self) -> None: ), ) + def test_complete_non_streaming_mode(self) -> None: + result = _complete._complete_impl( + model="my_models", + prompt="test_prompt", + options=_complete.CompleteOptions(), + session=self.session, + ) + self.assertIsInstance(result, str) + self.assertEqual("This is a streaming response", result) + def test_deadline(self) -> None: self.assertRaises( TimeoutError, diff --git a/snowflake/cortex/finetune_test.py b/snowflake/cortex/finetune_test.py new file mode 100644 index 00000000..b297d547 --- /dev/null +++ b/snowflake/cortex/finetune_test.py @@ -0,0 +1,197 @@ +import json +from typing import Any, Dict, List +from unittest import mock + +from absl.testing import absltest + +from snowflake.cortex import Finetune, FinetuneJob, FinetuneStatus +from snowflake.ml.test_utils import mock_data_frame + + +class FinetuneTest(absltest.TestCase): + system_function_name = "SNOWFLAKE.CORTEX.FINETUNE" + + def setUp(self) -> None: + self.list_jobs_return_value: List[Dict[str, Any]] = [ + {"id": "1", "status": "SUCCESS"}, + {"id": "2", "status": "ERROR"}, + ] + self.list_jobs_expected_result = [ + FinetuneJob(session=None, status=FinetuneStatus(**status)) for status in self.list_jobs_return_value + ] + + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_create(self, mock_call_sql_function: mock.Mock) -> None: + """Test call of finetune operation CREATE.""" + mock_call_sql_function.return_value = "workflow_id" + cft = Finetune() + cft.create("test_model", "base_model", "SELECT * FROM TRAINING", "SELECT * FROM VALIDATION") + 
mock_call_sql_function.assert_called_with( + self.system_function_name, + None, + "CREATE", + "test_model", + "base_model", + "SELECT * FROM TRAINING", + "SELECT * FROM VALIDATION", + None, + ) + + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_create_with_snowpark_dataframe(self, mock_call_sql_function: mock.Mock) -> None: + """Test call of finetune operation CREATE.""" + mock_call_sql_function.return_value = "workflow_id" + training_df = mock_data_frame.MockDataFrame() + training_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM TRAINING") + validation_df = mock_data_frame.MockDataFrame() + validation_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM VALIDATION") + + cft = Finetune() + cft.create("test_model", "base_model", training_df, validation_df) + mock_call_sql_function.assert_called_with( + self.system_function_name, + None, + "CREATE", + "test_model", + "base_model", + "SELECT PROMPT, COMPLETION FROM TRAINING", + "SELECT PROMPT, COMPLETION FROM VALIDATION", + None, + ) + + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_create_with_snowpark_dataframe_two_training_queries( + self, mock_call_sql_function: mock.Mock + ) -> None: + """Test call of finetune operation CREATE with an incompatible training DataFrame.""" + training_df = mock_data_frame.MockDataFrame() + training_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM TRAINING") + training_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM VALIDATION") + validation_df = mock_data_frame.MockDataFrame() + validation_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM VALIDATION") + + cft = Finetune() + with self.assertRaisesRegex(ValueError, r".*training_data.*queries.*"): + cft.create("test_model", "base_model", training_df, validation_df) + + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_create_with_snowpark_dataframe_two_validation_queries( + self, mock_call_sql_function: mock.Mock + ) -> None: + """Test call of finetune operation CREATE with an incompatible validation DataFrame.""" + training_df = mock_data_frame.MockDataFrame() + training_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM TRAINING") + validation_df = mock_data_frame.MockDataFrame() + validation_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM VALIDATION") + validation_df.add_query("queries", "SELECT PROMPT, COMPLETION FROM TRAINING") + + cft = Finetune() + with self.assertRaisesRegex(ValueError, r"validation_data.*queries"): + cft.create("test_model", "base_model", training_df, validation_df) + + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_create_with_options(self, mock_call_sql_function: mock.Mock) -> None: + """Test call of finetune operation CREATE with options.""" + mock_call_sql_function.return_value = "workflow_id" + cft = Finetune() + cft.create("test_model", "base_model", "SELECT * FROM TRAINING", "SELECT * FROM VALIDATION", {"awesome": True}) + mock_call_sql_function.assert_called_with( + self.system_function_name, + None, + "CREATE", + "test_model", + "base_model", + "SELECT * FROM TRAINING", + "SELECT * FROM VALIDATION", + {"awesome": True}, + ) + + @mock.patch("snowflake.cortex._finetune.Finetune.list_jobs") + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_cancel(self, mock_call_sql_function: mock.Mock, mock_finetune_list_jobs: mock.Mock) -> None: + """Test call of finetune operation CANCEL.""" + 
mock_call_sql_function.return_value = "job 2 cancelled" + mock_finetune_list_jobs.return_value = self.list_jobs_expected_result + cft = Finetune() + run = cft.list_jobs()[1] + run.cancel() + mock_call_sql_function.assert_called_with(self.system_function_name, None, "CANCEL", "2") + + @mock.patch("snowflake.cortex._finetune.Finetune.list_jobs") + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_describe(self, mock_call_sql_function: mock.Mock, mock_finetune_list_jobs: mock.Mock) -> None: + """Test call of finetune operation DESCRIBE.""" + sql_return_value: Dict[str, Any] = { + "base_model": "llama3-8b", + "created_on": 1728688216077, + "finished_on": 1728688392137, + "id": "CortexFineTuningWorkflow_4dbc8970-65d8-44b8-9054-da818c1593dd", + "model": "CFT_DB.TASTY_EMAIL.test_api_1", + "progress": 1.0, + "status": "SUCCESS", + "training_data": ( + "select MODIFIED_BODY as PROMPT, GOLDEN_JSON as COMPLETION " + "from EMAIL_MODIFIED_WITH_RESPONSE_GOLDEN_10K_JSON where id % 10 = 0" + ), + "trained_tokens": 377100, + "training_result": {"validation_loss": 0.8828646540641785, "training_loss": 0.8691850564418695}, + "validation_data": "", + } + expected_result = FinetuneStatus(**sql_return_value) + mock_call_sql_function.return_value = json.dumps(sql_return_value) + mock_finetune_list_jobs.return_value = self.list_jobs_expected_result + run = Finetune().list_jobs()[0] + self.assertEqual( + run.describe(), + expected_result, + ) + mock_call_sql_function.assert_called_with(self.system_function_name, None, "DESCRIBE", "1") + + @mock.patch("snowflake.cortex._finetune.Finetune.list_jobs") + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_describe_error( + self, mock_call_sql_function: mock.Mock, mock_finetune_list_jobs: mock.Mock + ) -> None: + """Test call of finetune operation DESCRIBE with error message.""" + mock_finetune_list_jobs.return_value = self.list_jobs_expected_result + sql_return_value: Dict[str, Any] = { + "base_model": "llama3-8b", + "created_on": 1728670992861, + "error": { + "code": "INVALID_PARAMETER_VALUE", + "message": ( + 'Failed to query input data for Fine-tuning. Failed to execute query: SELECT "PROMPT", ' + '"COMPLETION" FROM (select MODIFIED_BODY as PROMPT, GOLDEN_JSON as COMPLETION from ' + "EMAIL_MODIFIED_WITH_RESPONSE_GOLDEN_10K_JSON where id % 10 = 0;). 001003 (42000): " + "01b79faf-0003-dbfd-0022-b7876037dabe: SQL compilation error:\nsyntax error line 1 at position " + "161 unexpected ';'." 
+ ), + }, + "finished_on": 1728670994015, + "id": "CortexFineTuningWorkflow_54cfb2bb-ff69-4d5a-8513-320ba3fdb258", + "progress": 0.0, + "status": "ERROR", + "training_data": ( + "select MODIFIED_BODY as PROMPT, GOLDEN_JSON as COMPLETION from " + "EMAIL_MODIFIED_WITH_RESPONSE_GOLDEN_10K_JSON where id % 10 = 0;" + ), + "validation_data": "", + } + + mock_call_sql_function.return_value = json.dumps(sql_return_value) + run = Finetune().list_jobs()[0] + self.assertEqual(run.describe(), FinetuneStatus(**sql_return_value)) + mock_call_sql_function.assert_called_with(self.system_function_name, None, "DESCRIBE", "1") + + @mock.patch("snowflake.cortex._finetune.call_sql_function_literals") + def test_finetune_list_jobs(self, mock_call_sql_function: mock.Mock) -> None: + """Test call of finetune operation list_jobs.""" + mock_call_sql_function.return_value = json.dumps(self.list_jobs_return_value) + run_list = Finetune().list_jobs() + self.assertTrue(isinstance(run_list, list)) + self.assertEqual(run_list, self.list_jobs_expected_result) + mock_call_sql_function.assert_called_with(self.system_function_name, None, "SHOW") + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/cortex/package_visibility_test.py b/snowflake/cortex/package_visibility_test.py index 98655da8..fd18ccfc 100644 --- a/snowflake/cortex/package_visibility_test.py +++ b/snowflake/cortex/package_visibility_test.py @@ -31,6 +31,11 @@ def test_summarize_visible(self) -> None: def test_translate_visible(self) -> None: self.assertTrue(callable(cortex.Translate)) + def test_finetune_visible(self) -> None: + self.assertTrue(callable(cortex.Finetune)) + self.assertTrue(callable(cortex.FinetuneJob)) + self.assertTrue(callable(cortex.FinetuneStatus)) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/cortex/sse_test.py b/snowflake/cortex/sse_test.py index b0ecebf4..d752a53a 100644 --- a/snowflake/cortex/sse_test.py +++ b/snowflake/cortex/sse_test.py @@ -1,14 +1,14 @@ -from typing import List, cast +import json +from typing import List -import requests from absl.testing import absltest from snowflake.cortex._sse_client import SSEClient from snowflake.cortex.complete_test import FakeResponse -def _streaming_messages(response_data: bytes) -> List[str]: - client = SSEClient(cast(requests.Response, FakeResponse(response_data))) +def _streaming_messages(response_data: bytes, content_type: str = "text/event-stream", data: bytes = b"") -> List[str]: + client = SSEClient(FakeResponse(response_data, {"Content-Type": content_type}, data=data)) out = [] for event in client.events(): out.append(event.data) @@ -114,7 +114,6 @@ def test_ignore_other_event_types(self) -> None: # fmt: on result_parsed = _streaming_messages(response_sth_else) - assert result_parsed == [] # ignore anything that is not message def test_empty_data_json(self) -> None: @@ -124,6 +123,12 @@ def test_empty_data_json(self) -> None: assert result_parsed == ["{}"] + def test_json_response(self) -> None: + d = {} + d["data"] = "random" + result_parsed = _streaming_messages(b"", "application/json", data=json.dumps([d]).encode("utf-8")) + assert result_parsed == ["random"] + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/_internal/type_utils.py b/snowflake/ml/_internal/type_utils.py index cbab70b4..b6fe2268 100644 --- a/snowflake/ml/_internal/type_utils.py +++ b/snowflake/ml/_internal/type_utils.py @@ -1,4 +1,4 @@ -import sys +import importlib from typing import Any, Generic, Type, TypeVar, Union, cast import numpy as np @@ -51,8 
+51,8 @@ def __repr__(self) -> str: def get_class(self) -> Type[T]: if self._runtime_class is None: try: - m = sys.modules[self.module] - except KeyError: + m = importlib.import_module(self.module) + except ModuleNotFoundError: raise ValueError(f"Module {self.module} not imported.") self._runtime_class = cast("Type[T]", getattr(m, self.qualname)) diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index 9635cde7..7ea9c0ab 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -851,17 +851,13 @@ def list_services( ) return pd.DataFrame( - self._model_ops.list_inference_services( + self._model_ops.show_services( database_name=None, schema_name=None, model_name=self._model_name, version_name=self._version_name, statement_params=statement_params, - ), - columns=[ - self._model_ops.INFERENCE_SERVICE_NAME_COL_NAME, - self._model_ops.INFERENCE_SERVICE_ENDPOINT_COL_NAME, - ], + ) ) @telemetry.send_api_usage_telemetry( @@ -889,12 +885,16 @@ def delete_service( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) + + database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name) self._model_ops.delete_service( database_name=None, schema_name=None, model_name=self._model_name, version_name=self._version_name, - service_name=service_name, + service_database_name=database_name_id, + service_schema_name=schema_name_id, + service_name=service_name_id, statement_params=statement_params, ) diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index b8129297..30eef9a5 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -845,19 +845,15 @@ def test_create_service_no_eai(self) -> None: ) def test_list_services(self) -> None: - m_df = pd.DataFrame( - { - "service_name": ["a.b.c", "a.b.c", "d.e.f"], - "endpoints": ["fooendpoint", "barendpoint", "bazendpoint"], - } - ) + data = [ + {"name": "a.b.c", "inference_endpoint": "fooendpoint"}, + {"name": "d.e.f", "inference_endpoint": "bazendpoint"}, + ] + m_df = pd.DataFrame(data) with mock.patch.object( self.m_mv._model_ops, - attribute="list_inference_services", - return_value={ - "service_name": ["a.b.c", "a.b.c", "d.e.f"], - "endpoints": ["fooendpoint", "barendpoint", "bazendpoint"], - }, + attribute="show_services", + return_value=data, ) as mock_get_functions: pd.testing.assert_frame_equal(m_df, self.m_mv.list_services()) mock_get_functions.assert_called_once_with( @@ -881,7 +877,23 @@ def test_delete_service(self) -> None: schema_name=None, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - service_name="c", + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("c"), + statement_params=mock.ANY, + ) + + with mock.patch.object(self.m_mv._model_ops, attribute="delete_service") as mock_delete_service: + self.m_mv.delete_service("a.b.c") + + mock_delete_service.assert_called_with( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_database_name=sql_identifier.SqlIdentifier("a"), + service_schema_name=sql_identifier.SqlIdentifier("b"), + 
service_name=sql_identifier.SqlIdentifier("c"), statement_params=mock.ANY, ) diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index 7b7a24b8..815e6da0 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -3,7 +3,7 @@ import pathlib import tempfile import warnings -from typing import Any, Dict, List, Literal, Optional, Union, cast, overload +from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload import yaml @@ -31,9 +31,14 @@ from snowflake.snowpark._internal import utils as snowpark_utils +class ServiceInfo(TypedDict): + name: str + inference_endpoint: Optional[str] + + class ModelOperator: - INFERENCE_SERVICE_NAME_COL_NAME = "service_name" - INFERENCE_SERVICE_ENDPOINT_COL_NAME = "endpoints" + INFERENCE_SERVICE_ENDPOINT_NAME = "inference" + INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app" def __init__( self, @@ -517,7 +522,7 @@ def unset_tag( statement_params=statement_params, ) - def list_inference_services( + def show_services( self, *, database_name: Optional[sql_identifier.SqlIdentifier], @@ -525,7 +530,7 @@ def list_inference_services( model_name: sql_identifier.SqlIdentifier, version_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, - ) -> Dict[str, List[str]]: + ) -> List[ServiceInfo]: res = self._model_client.show_versions( database_name=database_name, schema_name=schema_name, @@ -546,21 +551,28 @@ def list_inference_services( json_array = json.loads(res[0][service_col_name]) # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side. - services = [str(service) for service in json_array if "MODEL_BUILD_" not in service] - endpoint_col_name = self._model_client.MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME - - services_col, endpoints_col = [], [] - for service in services: - res = self._model_client.show_endpoints(service_name=service) - endpoints = [endpoint[endpoint_col_name] for endpoint in res] - for endpoint in endpoints: - services_col.append(service) - endpoints_col.append(endpoint) - - return { - self.INFERENCE_SERVICE_NAME_COL_NAME: services_col, - self.INFERENCE_SERVICE_ENDPOINT_COL_NAME: endpoints_col, - } + fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service] + + result = [] + ingress_url: Optional[str] = None + for fully_qualified_service_name in fully_qualified_service_names: + db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name) + for res_row in self._service_client.show_endpoints( + database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params + ): + if ( + res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME] + == self.INFERENCE_SERVICE_ENDPOINT_NAME + and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None + ): + ingress_url = str( + res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] + ) + if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX): + ingress_url = None + result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url)) + + return result def delete_service( self, @@ -569,33 +581,42 @@ def delete_service( schema_name: Optional[sql_identifier.SqlIdentifier], model_name: sql_identifier.SqlIdentifier, version_name: sql_identifier.SqlIdentifier, - service_name: str, + 
service_database_name: Optional[sql_identifier.SqlIdentifier], + service_schema_name: Optional[sql_identifier.SqlIdentifier], + service_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - services = self.list_inference_services( + services = self.show_services( database_name=database_name, schema_name=schema_name, model_name=model_name, version_name=version_name, statement_params=statement_params, ) - db, schema, service_name = sql_identifier.parse_fully_qualified_name(service_name) + + # Fall back to the model's database and schema. + # database_name or schema_name are set if the model is created or get using fully qualified name + # Otherwise, the model's database and schema are same as registry's database and schema, which are set in the + # self._model_client. + + service_database_name = service_database_name or database_name or self._model_client._database_name + service_schema_name = service_schema_name or schema_name or self._model_client._schema_name fully_qualified_service_name = sql_identifier.get_fully_qualified_name( - db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema() + service_database_name, service_schema_name, service_name ) - service_col_name = self.INFERENCE_SERVICE_NAME_COL_NAME - for service in services[service_col_name]: - if service == fully_qualified_service_name: + for service_info in services: + if service_info["name"] == fully_qualified_service_name: self._service_client.drop_service( - database_name=db, - schema_name=schema, + database_name=service_database_name, + schema_name=service_schema_name, service_name=service_name, statement_params=statement_params, ) return raise ValueError( - f"Service '{service_name}' does not exist or unauthorized or not associated with this model version." + f"Service '{fully_qualified_service_name}' does not exist " + "or unauthorized or not associated with this model version." 
) def get_model_version_manifest( diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index e44a17e1..8a2180cf 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -460,29 +460,92 @@ def test_unset_tag(self) -> None: statement_params=self.m_statement_params, ) - def test_list_inference_services(self) -> None: + def test_show_services_1(self) -> None: m_services_list_res = [Row(inference_services='["a.b.c", "d.e.f"]')] - m_endpoints_list_res_0 = [Row(name="fooendpoint"), Row(name="barendpoint")] - m_endpoints_list_res_1 = [Row(name="bazendpoint")] + m_endpoints_list_res_0 = [Row(name="inference", ingress_url="Waiting")] + m_endpoints_list_res_1 = [Row(name="inference", ingress_url="foo.snowflakecomputing.app")] with mock.patch.object( self.m_ops._model_client, "show_versions", return_value=m_services_list_res ) as mock_show_versions, mock.patch.object( - self.m_ops._model_client, "show_endpoints", side_effect=[m_endpoints_list_res_0, m_endpoints_list_res_1] + self.m_ops._service_client, "show_endpoints", side_effect=[m_endpoints_list_res_0, m_endpoints_list_res_1] ): - res = self.m_ops.list_inference_services( + res = self.m_ops.show_services( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=self.m_statement_params, ) - self.assertEqual( + self.assertListEqual( + res, + [ + {"name": "a.b.c", "inference_endpoint": None}, + {"name": "d.e.f", "inference_endpoint": "foo.snowflakecomputing.app"}, + ], + ) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_show_services_2(self) -> None: + m_services_list_res = [Row(inference_services='["a.b.c"]')] + m_endpoints_list_res = [Row(name="inference", ingress_url=None)] + + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_services_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res + ): + res = self.m_ops.show_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + [ + {"name": "a.b.c", "inference_endpoint": None}, + ], + ) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_show_services_3(self) -> None: + m_services_list_res = [Row(inference_services='["a.b.c"]')] + m_endpoints_list_res = [ + Row(name="inference", ingress_url="foo.snowflakecomputing.app"), + Row(name="another", ingress_url="bar.snowflakecomputing.app"), 
+ ] + + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_services_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res + ): + res = self.m_ops.show_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertListEqual( res, - { - "service_name": ["a.b.c", "a.b.c", "d.e.f"], - "endpoints": ["fooendpoint", "barendpoint", "bazendpoint"], - }, + [ + {"name": "a.b.c", "inference_endpoint": "foo.snowflakecomputing.app"}, + ], ) mock_show_versions.assert_called_once_with( database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -492,13 +555,43 @@ def test_list_inference_services(self) -> None: statement_params=self.m_statement_params, ) - def test_list_inference_services_pre_bcr(self) -> None: + def test_show_services_4(self) -> None: + m_services_list_res = [Row(inference_services='["a.b.c"]')] + m_endpoints_list_res = [Row(name="custom", ingress_url="foo.snowflakecomputing.app")] + + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_services_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res + ): + res = self.m_ops.show_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + [ + {"name": "a.b.c", "inference_endpoint": None}, + ], + ) + mock_show_versions.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_show_services_pre_bcr(self) -> None: m_list_res = [Row(comment="mycomment")] with mock.patch.object( self.m_ops._model_client, "show_versions", return_value=m_list_res ) as mock_show_versions: with self.assertRaises(exceptions.SnowflakeMLException) as context: - self.m_ops.list_inference_services( + self.m_ops.show_services( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -518,27 +611,26 @@ def test_list_inference_services_pre_bcr(self) -> None: statement_params=self.m_statement_params, ) - def test_list_inference_services_skip_build(self) -> None: + def test_show_services_skip_build(self) -> None: m_list_res = [Row(inference_services='["A.B.MODEL_BUILD_34d35ew", "A.B.SERVICE"]')] m_endpoints_list_res = [Row(name="fooendpoint"), Row(name="barendpoint")] with mock.patch.object( self.m_ops._model_client, "show_versions", return_value=m_list_res ) as mock_show_versions, mock.patch.object( - self.m_ops._model_client, "show_endpoints", side_effect=[m_endpoints_list_res] + self.m_ops._service_client, "show_endpoints", side_effect=[m_endpoints_list_res] ): - res = self.m_ops.list_inference_services( + res = 
self.m_ops.show_services( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=self.m_statement_params, ) - self.assertEqual( + self.assertListEqual( res, - { - "service_name": ["A.B.SERVICE", "A.B.SERVICE"], - "endpoints": ["fooendpoint", "barendpoint"], - }, + [ + {"name": "A.B.SERVICE", "inference_endpoint": None}, + ], ) mock_show_versions.assert_called_once_with( database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -554,52 +646,79 @@ def test_delete_service_non_existent(self) -> None: with mock.patch.object( self.m_ops._model_client, "show_versions", return_value=m_list_res ) as mock_show_versions, mock.patch.object( - self.m_session, attribute="get_current_database", return_value="a" - ) as mock_get_database, mock.patch.object( - self.m_session, attribute="get_current_schema", return_value="b" - ) as mock_get_schema, mock_show_versions, mock.patch.object( - self.m_ops._model_client, "show_endpoints", return_value=m_endpoints_list_res + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res ): with self.assertRaisesRegex( - ValueError, "Service 'A' does not exist or unauthorized or not associated with this model version." + ValueError, "Service 'A.B.A' does not exist or unauthorized or not associated with this model version." ): self.m_ops.delete_service( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - service_name="a", + service_database_name=sql_identifier.SqlIdentifier("A"), + service_schema_name=sql_identifier.SqlIdentifier("B"), + service_name=sql_identifier.SqlIdentifier("A"), ) + mock_show_versions.assert_called_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res + ): with self.assertRaisesRegex( - ValueError, "Service 'B' does not exist or unauthorized or not associated with this model version." 
+ ValueError, + "Service 'FOO.\"bar\".B' does not exist or unauthorized or not associated with this model version.", ): self.m_ops.delete_service( - database_name=sql_identifier.SqlIdentifier("TEMP"), - schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + database_name=sql_identifier.SqlIdentifier("foo"), + schema_name=sql_identifier.SqlIdentifier("bar", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - service_name="a.b", + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("B"), ) + mock_show_versions.assert_called_with( + database_name=sql_identifier.SqlIdentifier("foo"), + schema_name=sql_identifier.SqlIdentifier("bar", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res + ): with self.assertRaisesRegex( - ValueError, "Service 'D' does not exist or unauthorized or not associated with this model version." + ValueError, + "Service 'TEMP.\"test\".D' does not exist or unauthorized or not associated with this model version.", ): self.m_ops.delete_service( - database_name=sql_identifier.SqlIdentifier("TEMP"), - schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + database_name=None, + schema_name=None, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - service_name="b.c.d", + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("D"), ) - mock_show_versions.assert_called_with( - database_name=sql_identifier.SqlIdentifier("TEMP"), - schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + database_name=None, + schema_name=None, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=mock.ANY, ) - mock_get_database.assert_called() - mock_get_schema.assert_called() def test_delete_service_exists(self) -> None: m_list_res = [Row(inference_services='["A.B.C", "D.E.F"]')] @@ -609,47 +728,89 @@ def test_delete_service_exists(self) -> None: ) as mock_show_versions, mock.patch.object( self.m_ops._service_client, "drop_service" ) as mock_drop_service, mock.patch.object( - self.m_session, attribute="get_current_database", return_value="a" - ) as mock_get_database, mock.patch.object( - self.m_session, attribute="get_current_schema", return_value="b" - ) as mock_get_schema, mock_show_versions, mock.patch.object( - self.m_ops._model_client, "show_endpoints", return_value=m_endpoints_list_res + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res ): self.m_ops.delete_service( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - service_name="c", + service_database_name=sql_identifier.SqlIdentifier("A"), + service_schema_name=sql_identifier.SqlIdentifier("B"), + service_name=sql_identifier.SqlIdentifier("C"), ) - self.m_ops.delete_service( + 
mock_show_versions.assert_called_with( database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - service_name="b.c", + statement_params=mock.ANY, ) + mock_drop_service.assert_called_with( + database_name=sql_identifier.SqlIdentifier("A"), + schema_name=sql_identifier.SqlIdentifier("B"), + service_name=sql_identifier.SqlIdentifier("C"), + statement_params=mock.ANY, + ) + + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "drop_service" + ) as mock_drop_service, mock.patch.object( + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res + ): self.m_ops.delete_service( - database_name=sql_identifier.SqlIdentifier("TEMP"), - schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + database_name=sql_identifier.SqlIdentifier("A"), + schema_name=sql_identifier.SqlIdentifier("B"), model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - service_name="a.b.c", + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("C"), ) - mock_show_versions.assert_called_with( - database_name=sql_identifier.SqlIdentifier("TEMP"), - schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + database_name=sql_identifier.SqlIdentifier("A"), + schema_name=sql_identifier.SqlIdentifier("B"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_drop_service.assert_called_with( + database_name=sql_identifier.SqlIdentifier("A"), + schema_name=sql_identifier.SqlIdentifier("B"), + service_name=sql_identifier.SqlIdentifier("C"), + statement_params=mock.ANY, + ) + with mock.patch.object( + self.m_ops._model_client, + "show_versions", + return_value=[Row(inference_services='["TEMP.\\"test\\".C", "D.E.F"]')], + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, "drop_service" + ) as mock_drop_service, mock.patch.object( + self.m_ops._service_client, "show_endpoints", return_value=m_endpoints_list_res + ): + self.m_ops.delete_service( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("C"), + ) + mock_show_versions.assert_called_with( + database_name=None, + schema_name=None, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=mock.ANY, ) - mock_get_database.assert_called() - mock_get_schema.assert_called() mock_drop_service.assert_called_with( - database_name="A", - schema_name="B", - service_name="C", + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + service_name=sql_identifier.SqlIdentifier("C"), statement_params=mock.ANY, ) diff --git a/snowflake/ml/model/_client/ops/service_ops.py b/snowflake/ml/model/_client/ops/service_ops.py index 4a6092f4..06771725 100644 --- a/snowflake/ml/model/_client/ops/service_ops.py +++ 
b/snowflake/ml/model/_client/ops/service_ops.py @@ -109,6 +109,17 @@ def create_service( build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]], statement_params: Optional[Dict[str, Any]] = None, ) -> str: + + # Fall back to the registry's database and schema if not provided + database_name = database_name or self._database_name + schema_name = schema_name or self._schema_name + + # Fall back to the model's database and schema if not provided then to the registry's database and schema + service_database_name = service_database_name or database_name or self._database_name + service_schema_name = service_schema_name or schema_name or self._schema_name + + image_repo_database_name = image_repo_database_name or database_name or self._database_name + image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name # create a temp stage stage_name = sql_identifier.SqlIdentifier( snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE) @@ -130,8 +141,8 @@ def create_service( raise ValueError("External access integrations are required in Snowflake < 8.40.0.") self._model_deployment_spec.save( - database_name=database_name or self._database_name, - schema_name=schema_name or self._schema_name, + database_name=database_name, + schema_name=schema_name, model_name=model_name, version_name=version_name, service_database_name=service_database_name, diff --git a/snowflake/ml/model/_client/ops/service_ops_test.py b/snowflake/ml/model/_client/ops/service_ops_test.py index 489501d3..a6043456 100644 --- a/snowflake/ml/model/_client/ops/service_ops_test.py +++ b/snowflake/ml/model/_client/ops/service_ops_test.py @@ -116,9 +116,193 @@ def test_create_service(self) -> None: statement_params=self.m_statement_params, ) mock_get_service_status.assert_called_once_with( - database_name="SERVICE_DB", - schema_name="SERVICE_SCHEMA", - service_name="MYSERVICE", + database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + include_message=False, + statement_params=self.m_statement_params, + ) + + def test_create_service_model_db_and_schema(self) -> None: + with mock.patch.object(self.m_ops._stage_client, "create_tmp_stage",) as mock_create_stage, mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_STAGE_ABCDEF0123" + ), mock.patch.object(self.m_ops._model_deployment_spec, "save",) as mock_save, mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ) as mock_upload_directory_to_stage, mock.patch.object( + self.m_ops._service_client, + "deploy_model", + return_value=(str(uuid.uuid4()), mock.MagicMock(spec=snowpark.AsyncJob)), + ) as mock_deploy_model, mock.patch.object( + self.m_ops._service_client, + "get_service_status", + return_value=(service_sql.ServiceStatus.PENDING, None), + ) as mock_get_service_status: + self.m_ops.create_service( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=None, + 
image_repo_schema_name=None, + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=True, + max_instances=1, + cpu_requests="1", + memory_requests="6GiB", + gpu_requests="1", + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + statement_params=self.m_statement_params, + ) + mock_create_stage.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + stage_name=sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + statement_params=self.m_statement_params, + ) + mock_save.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=sql_identifier.SqlIdentifier("DB"), + service_schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=sql_identifier.SqlIdentifier("DB"), + image_repo_schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=True, + max_instances=1, + cpu="1", + memory="6GiB", + gpu="1", + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + ) + mock_upload_directory_to_stage.assert_called_once_with( + self.c_session, + local_path=self.m_ops._model_deployment_spec.workspace_path, + stage_path=pathlib.PurePosixPath( + self.m_ops._stage_client.fully_qualified_object_name( + sql_identifier.SqlIdentifier("DB"), + sql_identifier.SqlIdentifier("SCHEMA"), + sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + ) + ), + statement_params=self.m_statement_params, + ) + mock_deploy_model.assert_called_once_with( + stage_path="DB.SCHEMA.SNOWPARK_TEMP_STAGE_ABCDEF0123", + model_deployment_spec_file_rel_path=self.m_ops._model_deployment_spec.DEPLOY_SPEC_FILE_REL_PATH, + statement_params=self.m_statement_params, + ) + mock_get_service_status.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + include_message=False, + statement_params=self.m_statement_params, + ) + + def test_create_service_default_db_and_schema(self) -> None: + with mock.patch.object(self.m_ops._stage_client, "create_tmp_stage",) as mock_create_stage, mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_STAGE_ABCDEF0123" + ), mock.patch.object(self.m_ops._model_deployment_spec, "save",) as mock_save, mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ) as mock_upload_directory_to_stage, mock.patch.object( + self.m_ops._service_client, + "deploy_model", + return_value=(str(uuid.uuid4()), mock.MagicMock(spec=snowpark.AsyncJob)), + ) as mock_deploy_model, mock.patch.object( + self.m_ops._service_client, + "get_service_status", + return_value=(service_sql.ServiceStatus.PENDING, None), + ) as mock_get_service_status: + self.m_ops.create_service( + database_name=None, + 
schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=None, + image_repo_schema_name=None, + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=True, + max_instances=1, + cpu_requests="1", + memory_requests="6GiB", + gpu_requests="1", + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + statement_params=self.m_statement_params, + ) + mock_create_stage.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + stage_name=sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + statement_params=self.m_statement_params, + ) + mock_save.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=sql_identifier.SqlIdentifier("TEMP"), + service_schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=sql_identifier.SqlIdentifier("TEMP"), + image_repo_schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=True, + max_instances=1, + cpu="1", + memory="6GiB", + gpu="1", + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + ) + mock_upload_directory_to_stage.assert_called_once_with( + self.c_session, + local_path=self.m_ops._model_deployment_spec.workspace_path, + stage_path=pathlib.PurePosixPath( + self.m_ops._stage_client.fully_qualified_object_name( + sql_identifier.SqlIdentifier("TEMP"), + sql_identifier.SqlIdentifier("test", case_sensitive=True), + sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + ) + ), + statement_params=self.m_statement_params, + ) + mock_deploy_model.assert_called_once_with( + stage_path='TEMP."test".SNOWPARK_TEMP_STAGE_ABCDEF0123', + model_deployment_spec_file_rel_path=self.m_ops._model_deployment_spec.DEPLOY_SPEC_FILE_REL_PATH, + statement_params=self.m_statement_params, + ) + mock_get_service_status.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), include_message=False, statement_params=self.m_statement_params, ) diff --git a/snowflake/ml/model/_client/sql/model.py b/snowflake/ml/model/_client/sql/model.py index d630fe44..5646adac 100644 --- a/snowflake/ml/model/_client/sql/model.py +++ b/snowflake/ml/model/_client/sql/model.py @@ -17,8 +17,6 @@ class ModelSQLClient(_base._BaseSQLClient): 
MODEL_VERSION_ALIASES_COL_NAME = "aliases" MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services" - MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME = "name" - def show_models( self, *, @@ -85,18 +83,6 @@ def show_versions( return res.validate() - def show_endpoints( - self, - *, - service_name: str, - ) -> List[row.Row]: - res = query_result_checker.SqlResultValidator( - self._session, - (f"SHOW ENDPOINTS IN SERVICE {service_name}"), - ).has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True) - - return res.validate() - def set_comment( self, *, diff --git a/snowflake/ml/model/_client/sql/service.py b/snowflake/ml/model/_client/sql/service.py index 6033bffc..177af2b0 100644 --- a/snowflake/ml/model/_client/sql/service.py +++ b/snowflake/ml/model/_client/sql/service.py @@ -10,7 +10,7 @@ sql_identifier, ) from snowflake.ml.model._client.sql import _base -from snowflake.snowpark import dataframe, functions as F, types as spt +from snowflake.snowpark import dataframe, functions as F, row, types as spt from snowflake.snowpark._internal import utils as snowpark_utils @@ -26,6 +26,9 @@ class ServiceStatus(enum.Enum): class ServiceSQLClient(_base._BaseSQLClient): + MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name" + MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url" + def build_model_container( self, *, @@ -216,3 +219,24 @@ def drop_service( f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}", statement_params=statement_params, ).has_dimensions(expected_rows=1, expected_cols=1).validate() + + def show_endpoints( + self, + *, + database_name: Optional[sql_identifier.SqlIdentifier], + schema_name: Optional[sql_identifier.SqlIdentifier], + service_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[row.Row]: + fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name) + res = ( + query_result_checker.SqlResultValidator( + self._session, + (f"SHOW ENDPOINTS IN SERVICE {fully_qualified_service_name}"), + statement_params=statement_params, + ) + .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True) + .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True) + ) + + return res.validate() diff --git a/snowflake/ml/model/_client/sql/service_test.py b/snowflake/ml/model/_client/sql/service_test.py index 385cf913..1cddf4c0 100644 --- a/snowflake/ml/model/_client/sql/service_test.py +++ b/snowflake/ml/model/_client/sql/service_test.py @@ -355,6 +355,46 @@ def test_drop_service(self) -> None: statement_params=m_statement_params, ) + def test_show_endpoints(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row(name="inference", ingress_url="foo.snowflakecomputing.app")], + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql( + """SHOW ENDPOINTS IN SERVICE TEMP."test".MYSERVICE""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).show_endpoints( + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + statement_params=m_statement_params, + ) + + self.m_session.add_mock_sql( + """SHOW ENDPOINTS IN SERVICE 
TEMP."test".MYSERVICE""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + + service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("foo"), + schema_name=sql_identifier.SqlIdentifier("bar", case_sensitive=True), + ).show_endpoints( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + statement_params=m_statement_params, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py index dec1fff7..198b9dc2 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py @@ -39,32 +39,44 @@ name="model1", model_type="custom", path="mock_path", handler_version="version_0" ) -_PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML = ( - list( - sorted( - map( - lambda x: str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x))), - model_meta._PACKAGING_REQUIREMENTS, - ) +_PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML = list( + sorted( + map( + lambda x: str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x))), + model_meta._PACKAGING_REQUIREMENTS, ) ) - + model_runtime._SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES +) + list( + sorted( + filter( + lambda x: not any( + dep in x for dep in model_runtime.PACKAGES_NOT_ALLOWED_IN_WAREHOUSE + model_meta._PACKAGING_REQUIREMENTS + ), + model_runtime._SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES, + ), + ) ) -_PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML_RELAXED = ( - list( - sorted( - map( - lambda x: str( - env_utils.relax_requirement_version( - env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) - ) - ), - model_meta._PACKAGING_REQUIREMENTS, - ) +_PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML_RELAXED = list( + sorted( + map( + lambda x: str( + env_utils.relax_requirement_version( + env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) + ) + ), + model_meta._PACKAGING_REQUIREMENTS, ) ) - + model_runtime._SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES +) + list( + sorted( + filter( + lambda x: not any( + dep in x for dep in model_runtime.PACKAGES_NOT_ALLOWED_IN_WAREHOUSE + model_meta._PACKAGING_REQUIREMENTS + ), + model_runtime._SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES, + ), + ) ) _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML = list( diff --git a/snowflake/ml/model/_model_composer/model_method/fixtures/function_1.py b/snowflake/ml/model/_model_composer/model_method/fixtures/function_1.py index 7fdafd2e..a904c202 100644 --- a/snowflake/ml/model/_model_composer/model_method/fixtures/function_1.py +++ b/snowflake/ml/model/_model_composer/model_method/fixtures/function_1.py @@ -5,6 +5,7 @@ import anyio import pandas as pd +import numpy as np from _snowflake import vectorized from snowflake.ml.model._packager import model_packager @@ -47,4 +48,4 @@ def infer(df: pd.DataFrame) -> dict: df.columns = input_cols input_df = df.astype(dtype=dtype_map) predictions_df = runner(input_df[input_cols]) - return predictions_df.to_dict("records") + return predictions_df.replace({pd.NA: None, np.nan: None}).to_dict("records") diff --git a/snowflake/ml/model/_model_composer/model_method/fixtures/function_2.py 
b/snowflake/ml/model/_model_composer/model_method/fixtures/function_2.py index 8977cc56..cf2f711c 100644 --- a/snowflake/ml/model/_model_composer/model_method/fixtures/function_2.py +++ b/snowflake/ml/model/_model_composer/model_method/fixtures/function_2.py @@ -5,6 +5,7 @@ import anyio import pandas as pd +import numpy as np from _snowflake import vectorized from snowflake.ml.model._packager import model_packager @@ -47,4 +48,4 @@ def infer(df: pd.DataFrame) -> dict: df.columns = input_cols input_df = df.astype(dtype=dtype_map) predictions_df = runner(input_df[input_cols]) - return predictions_df.to_dict("records") + return predictions_df.replace({pd.NA: None, np.nan: None}).to_dict("records") diff --git a/snowflake/ml/model/_model_composer/model_method/infer_function.py_template b/snowflake/ml/model/_model_composer/model_method/infer_function.py_template index 6196b504..925cc2cb 100644 --- a/snowflake/ml/model/_model_composer/model_method/infer_function.py_template +++ b/snowflake/ml/model/_model_composer/model_method/infer_function.py_template @@ -5,6 +5,7 @@ import sys import anyio import pandas as pd +import numpy as np from _snowflake import vectorized from snowflake.ml.model._packager import model_packager @@ -47,4 +48,4 @@ def {function_name}(df: pd.DataFrame) -> dict: df.columns = input_cols input_df = df.astype(dtype=dtype_map) predictions_df = runner(input_df[input_cols]) - return predictions_df.to_dict("records") + return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records") diff --git a/snowflake/ml/model/_packager/model_env/model_env.py b/snowflake/ml/model/_packager/model_env/model_env.py index 83fe3479..1faae459 100644 --- a/snowflake/ml/model/_packager/model_env/model_env.py +++ b/snowflake/ml/model/_packager/model_env/model_env.py @@ -174,6 +174,18 @@ def include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = Fal except env_utils.DuplicateDependencyError: pass + def remove_if_present_conda(self, conda_pkgs: List[str]) -> None: + """Remove conda requirements from model env if present. + + Args: + conda_pkgs: A list of package name to be removed from conda requirements. 
+ """ + for pkg_name in conda_pkgs: + spec_conda = env_utils._find_conda_dep_spec(self._conda_dependencies, pkg_name) + if spec_conda: + channel, spec = spec_conda + self._conda_dependencies[channel].remove(spec) + def generate_env_for_cuda(self) -> None: if self.cuda_version is None: return diff --git a/snowflake/ml/model/_packager/model_env/model_env_test.py b/snowflake/ml/model/_packager/model_env/model_env_test.py index 527cdf75..965faff8 100644 --- a/snowflake/ml/model/_packager/model_env/model_env_test.py +++ b/snowflake/ml/model/_packager/model_env/model_env_test.py @@ -469,6 +469,56 @@ def test_include_if_absent_pip_check_local(self) -> None: self.assertListEqual(env.conda_dependencies, []) self.assertListEqual(env.pip_requirements, ["numpy==1.0.1"]) + def test_remove_if_present_conda(self) -> None: + env = model_env.ModelEnv() + env.conda_dependencies = ["some-package==1.0.1"] + + env.remove_if_present_conda(["some-package"]) + self.assertListEqual(env.conda_dependencies, []) + self.assertListEqual(env.pip_requirements, []) + + env = model_env.ModelEnv() + env.conda_dependencies = ["some-package==1.0.1"] + + env.remove_if_present_conda(["some-package"]) + self.assertListEqual(env.conda_dependencies, []) + self.assertListEqual(env.pip_requirements, []) + + env = model_env.ModelEnv() + env.conda_dependencies = ["some-package==1.0.1"] + + env.remove_if_present_conda(["another-package"]) + self.assertListEqual(env.conda_dependencies, ["some-package==1.0.1"]) + self.assertListEqual(env.pip_requirements, []) + + env = model_env.ModelEnv() + env.conda_dependencies = ["another-package<2,>=1.0", "some-package==1.0.1"] + + env.remove_if_present_conda(["some-package", "another-package"]) + self.assertListEqual(env.conda_dependencies, []) + self.assertListEqual(env.pip_requirements, []) + + env = model_env.ModelEnv() + env.conda_dependencies = ["another-package<2,>=1.0", "some-package==1.0.1"] + + env.remove_if_present_conda(["another-package"]) + self.assertListEqual(env.conda_dependencies, ["some-package==1.0.1"]) + self.assertListEqual(env.pip_requirements, []) + + env = model_env.ModelEnv() + env.conda_dependencies = ["channel::some-package==1.0.1"] + + env.remove_if_present_conda(["some-package"]) + self.assertListEqual(env.conda_dependencies, []) + self.assertListEqual(env.pip_requirements, []) + + env = model_env.ModelEnv() + env.pip_requirements = ["some-package==1.0.1"] + + env.remove_if_present_conda(["some-package"]) + self.assertListEqual(env.conda_dependencies, []) + self.assertListEqual(env.pip_requirements, ["some-package==1.0.1"]) + def test_generate_conda_env_for_cuda(self) -> None: env = model_env.ModelEnv() env.conda_dependencies = ["somepackage==1.0.0", "another_channel::another_package==1.0.0"] diff --git a/snowflake/ml/model/_packager/model_handlers/_utils.py b/snowflake/ml/model/_packager/model_handlers/_utils.py index fad26efc..fd06bd01 100644 --- a/snowflake/ml/model/_packager/model_handlers/_utils.py +++ b/snowflake/ml/model/_packager/model_handlers/_utils.py @@ -179,7 +179,7 @@ def convert_explanations_to_2D_df( return pd.DataFrame(explanations) if hasattr(model, "classes_"): - classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr] + classes_list = [str(cl) for cl in model.classes_] len_classes = len(classes_list) if explanations.shape[2] != len_classes: raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}") diff --git a/snowflake/ml/model/_packager/model_handlers/catboost.py 
b/snowflake/ml/model/_packager/model_handlers/catboost.py index 78e73e4c..80f4ce73 100644 --- a/snowflake/ml/model/_packager/model_handlers/catboost.py +++ b/snowflake/ml/model/_packager/model_handlers/catboost.py @@ -95,7 +95,7 @@ def get_prediction( get_prediction_fn=get_prediction, ) model_task_and_output = model_task_utils.get_model_task_and_output_type(model) - model_meta.task = model_task_and_output.task + model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task) if enable_explainability: explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS) model_meta = handlers_utils.add_explain_method_signature( diff --git a/snowflake/ml/model/_packager/model_handlers/custom.py b/snowflake/ml/model/_packager/model_handlers/custom.py index 8ea56f13..747a47cd 100644 --- a/snowflake/ml/model/_packager/model_handlers/custom.py +++ b/snowflake/ml/model/_packager/model_handlers/custom.py @@ -2,7 +2,7 @@ import os import pathlib import sys -from typing import Dict, Optional, Type, final +from typing import Dict, Optional, Type, cast, final import anyio import cloudpickle @@ -108,6 +108,7 @@ def get_prediction( model_meta=model_meta, model_blobs_dir_path=model_blobs_dir_path, is_sub_model=True, + **cast(model_types.BaseModelSaveOption, kwargs), ) # Make sure that the module where the model is defined get pickled by value as well. @@ -175,6 +176,7 @@ def load_model( name=sub_model_name, model_meta=model_meta, model_blobs_dir_path=model_blobs_dir_path, + **cast(model_types.BaseModelLoadOption, kwargs), ) models[sub_model_name] = sub_model reconstructed_context = custom_model.ModelContext(artifacts=artifacts, models=models) diff --git a/snowflake/ml/model/_packager/model_handlers/lightgbm.py b/snowflake/ml/model/_packager/model_handlers/lightgbm.py index faff74dc..5c709703 100644 --- a/snowflake/ml/model/_packager/model_handlers/lightgbm.py +++ b/snowflake/ml/model/_packager/model_handlers/lightgbm.py @@ -196,13 +196,14 @@ def load_model( with open(model_blob_file_path, "rb") as f: model = cloudpickle.load(f) assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type)) + assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel) return model @classmethod def convert_as_custom_model( cls, - raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"], + raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"], model_meta: model_meta_api.ModelMetadata, background_data: Optional[pd.DataFrame] = None, **kwargs: Unpack[model_types.LGBMModelLoadOptions], diff --git a/snowflake/ml/model/_packager/model_handlers/sklearn.py b/snowflake/ml/model/_packager/model_handlers/sklearn.py index 66d3b8a2..9e814b7b 100644 --- a/snowflake/ml/model/_packager/model_handlers/sklearn.py +++ b/snowflake/ml/model/_packager/model_handlers/sklearn.py @@ -19,12 +19,26 @@ ) from snowflake.ml.model._packager.model_task import model_task_utils from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils +from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR if TYPE_CHECKING: import sklearn.base import sklearn.pipeline +def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline": + new_steps = [] + for step_name, step in model.steps: + new_reg = step + if hasattr(step, "_sklearn_estimator") and step._sklearn_estimator is not None: + # Unpack estimator to open source. 
+ new_reg = step._sklearn_estimator + new_steps.append((step_name, new_reg)) + + model.steps = new_steps + return model + + @final class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]): """Handler for scikit-learn based model. @@ -101,6 +115,10 @@ def save_model( if sample_input_data is None: raise ValueError("Sample input data is required to enable explainability.") + # If this is a pipeline and we are in the container runtime, check for distributed estimator. + if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline): + model = _unpack_container_runtime_pipeline(model) + if not is_sub_model: target_methods = handlers_utils.get_target_methods( model=model, @@ -135,7 +153,7 @@ def get_prediction( ) model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model) - model_meta.task = model_task_and_output_type.task + model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task) # if users did not ask then we enable if we have background data if enable_explainability is None: @@ -177,6 +195,35 @@ def get_prediction( model_meta.models[name] = base_meta model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION + # if model instance is a pipeline, check the pipeline steps + if isinstance(model, sklearn.pipeline.Pipeline): + for _, pipeline_step in model.steps: + if type_utils.LazyType("lightgbm.LGBMModel").isinstance(pipeline_step) or type_utils.LazyType( + "lightgbm.Booster" + ).isinstance(pipeline_step): + model_meta.env.include_if_absent( + [ + model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"), + ], + check_local_version=True, + ) + elif type_utils.LazyType("xgboost.XGBModel").isinstance(pipeline_step) or type_utils.LazyType( + "xgboost.Booster" + ).isinstance(pipeline_step): + model_meta.env.include_if_absent( + [ + model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"), + ], + check_local_version=True, + ) + elif type_utils.LazyType("catboost.CatBoost").isinstance(pipeline_step): + model_meta.env.include_if_absent( + [ + model_env.ModelDependency(requirement="catboost", pip_name="catboost"), + ], + check_local_version=True, + ) + if enable_explainability: model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")]) model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP diff --git a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py index 02fbe014..d49ecbe1 100644 --- a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +++ b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py @@ -138,7 +138,7 @@ def save_model( enable_explainability = False else: model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj) - model_meta.task = model_task_and_output_type.task + model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task) explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS) model_meta = handlers_utils.add_explain_method_signature( model_meta=model_meta, diff --git a/snowflake/ml/model/_packager/model_handlers/tensorflow.py b/snowflake/ml/model/_packager/model_handlers/tensorflow.py index 9360da17..459ebd3e 100644 --- a/snowflake/ml/model/_packager/model_handlers/tensorflow.py +++ b/snowflake/ml/model/_packager/model_handlers/tensorflow.py @@ 
-13,6 +13,7 @@ from snowflake.ml.model._packager.model_meta import ( model_blob_meta, model_meta as model_meta_api, + model_meta_schema, ) from snowflake.ml.model._signatures import ( numpy_handler, @@ -76,7 +77,11 @@ def save_model( assert isinstance(model, tensorflow.Module) - if isinstance(model, tensorflow.keras.Model): + is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType( + "tf_keras.Model" + ).isinstance(model) + + if is_keras_model: default_target_methods = ["predict"] else: default_target_methods = cls.DEFAULT_TARGET_METHODS @@ -117,8 +122,14 @@ def get_prediction( model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) - if isinstance(model, tensorflow.keras.Model): + if is_keras_model: tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)) + model_meta.env.include_if_absent( + [ + model_env.ModelDependency(requirement="keras<3", pip_name="keras"), + ], + check_local_version=False, + ) else: tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)) @@ -127,12 +138,16 @@ def get_prediction( model_type=cls.HANDLER_TYPE, handler_version=cls.HANDLER_VERSION, path=cls.MODEL_BLOB_FILE_OR_DIR, + options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model), ) model_meta.models[name] = base_meta model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION model_meta.env.include_if_absent( - [model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow")], check_local_version=True + [ + model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"), + ], + check_local_version=True, ) model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION) @@ -150,9 +165,11 @@ def load_model( model_blobs_metadata = model_meta.models model_blob_metadata = model_blobs_metadata[name] model_blob_filename = model_blob_metadata.path - m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False) - if isinstance(m, tensorflow.keras.Model): - return m + model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options) + if model_blob_options.get("is_keras_model", False): + m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False) + else: + m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename)) return cast(tensorflow.Module, m) @classmethod diff --git a/snowflake/ml/model/_packager/model_handlers/torchscript.py b/snowflake/ml/model/_packager/model_handlers/torchscript.py index 318abe80..0e526e03 100644 --- a/snowflake/ml/model/_packager/model_handlers/torchscript.py +++ b/snowflake/ml/model/_packager/model_handlers/torchscript.py @@ -23,7 +23,7 @@ @final -class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # type:ignore[name-defined] +class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): """Handler for PyTorch JIT based model. Currently torch.jit.ScriptModule based classes are supported. 
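For context, a minimal sketch (not part of this patch) of the kind of object this handler receives, mirroring the torchscript tests later in the patch; TinyModel is a hypothetical module introduced only for illustration:

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

# torch.jit.script returns a torch.jit.ScriptModule, which can_handle() detects
# via type_utils.LazyType("torch.jit.ScriptModule").isinstance(model); the
# type: ignore[name-defined] comments removed below are no longer needed for it.
scripted = torch.jit.script(TinyModel())
assert isinstance(scripted, torch.jit.ScriptModule)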
@@ -41,25 +41,25 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t def can_handle( cls, model: model_types.SupportedModelType, - ) -> TypeGuard["torch.jit.ScriptModule"]: # type:ignore[name-defined] + ) -> TypeGuard["torch.jit.ScriptModule"]: return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model) @classmethod def cast_model( cls, model: model_types.SupportedModelType, - ) -> "torch.jit.ScriptModule": # type:ignore[name-defined] + ) -> "torch.jit.ScriptModule": import torch - assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined] + assert isinstance(model, torch.jit.ScriptModule) - return cast(torch.jit.ScriptModule, model) # type:ignore[name-defined] + return cast(torch.jit.ScriptModule, model) @classmethod def save_model( cls, name: str, - model: "torch.jit.ScriptModule", # type:ignore[name-defined] + model: "torch.jit.ScriptModule", model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str, sample_input_data: Optional[model_types.SupportedDataType] = None, @@ -72,7 +72,7 @@ def save_model( import torch - assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined] + assert isinstance(model, torch.jit.ScriptModule) if not is_sub_model: target_methods = handlers_utils.get_target_methods( @@ -111,7 +111,7 @@ def get_prediction( model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f: - torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined] + torch.jit.save(model, f) # type:ignore[no-untyped-call] base_meta = model_blob_meta.ModelBlobMeta( name=name, model_type=cls.HANDLER_TYPE, @@ -133,7 +133,7 @@ def load_model( model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str, **kwargs: Unpack[model_types.TorchScriptLoadOptions], - ) -> "torch.jit.ScriptModule": # type:ignore[name-defined] + ) -> "torch.jit.ScriptModule": import torch model_blob_path = os.path.join(model_blobs_dir_path, name) @@ -141,10 +141,10 @@ def load_model( model_blob_metadata = model_blobs_metadata[name] model_blob_filename = model_blob_metadata.path with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f: - m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined] + m = torch.jit.load( # type:ignore[no-untyped-call] f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu" ) - assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined] + assert isinstance(m, torch.jit.ScriptModule) if kwargs.get("use_gpu", False): m = m.cuda() @@ -154,7 +154,7 @@ def load_model( @classmethod def convert_as_custom_model( cls, - raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined] + raw_model: "torch.jit.ScriptModule", model_meta: model_meta_api.ModelMetadata, background_data: Optional[pd.DataFrame] = None, **kwargs: Unpack[model_types.TorchScriptLoadOptions], @@ -162,11 +162,11 @@ def convert_as_custom_model( from snowflake.ml.model import custom_model def _create_custom_model( - raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined] + raw_model: "torch.jit.ScriptModule", model_meta: model_meta_api.ModelMetadata, ) -> Type[custom_model.CustomModel]: def fn_factory( - raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined] + raw_model: "torch.jit.ScriptModule", signature: model_signature.ModelSignature, target_method: str, ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]: diff --git 
a/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel b/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel index da25f4db..73fb8b81 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_handlers_test/BUILD.bazel @@ -70,6 +70,7 @@ py_test( "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_packager:model_packager", + "//snowflake/ml/model/_packager/model_handlers:sklearn", ], ) diff --git a/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py b/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py index 94fac2bc..6e2acc1a 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/catboost_test.py @@ -91,6 +91,7 @@ def test_catboost_classifier_explain_disabled(self) -> None: predict_method = getattr(pk.model, "predict_proba", None) assert callable(predict_method) np.testing.assert_allclose(predict_method(cal_X_test), y_pred_proba) + self.assertEqual(pk.meta.task, model_types.Task.TABULAR_BINARY_CLASSIFICATION) def test_catboost_explainablity_enabled(self) -> None: cal_data = datasets.load_breast_cancer() diff --git a/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py b/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py index 808100d1..177e552a 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/lightgbm_test.py @@ -85,6 +85,7 @@ def test_lightgbm_booster_explainability_disabled(self) -> None: predict_method = getattr(pk.model, "predict", None) assert callable(predict_method) np.testing.assert_allclose(predict_method(cal_X_test), np.expand_dims(y_pred, axis=1)) + self.assertEqual(pk.meta.task, model_types.Task.TABULAR_BINARY_CLASSIFICATION) def test_lightgbm_booster_explainablity_enabled(self) -> None: cal_data = datasets.load_breast_cancer() diff --git a/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py b/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py index 6e4d9a1e..67f409d6 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/pytorch_test.py @@ -143,7 +143,7 @@ def test_pytorch(self) -> None: def test_torch_df_sample_input(self) -> None: model, data_x, data_y = _prepare_torch_model(torch.float64) - model_script = torch.jit.script(model) # type:ignore[attr-defined] + model_script = torch.jit.script(model) s = {"forward": model_signature.infer_signature([data_x], [data_y])} with tempfile.TemporaryDirectory() as tmpdir: @@ -192,7 +192,7 @@ def test_torch_df_sample_input(self) -> None: pk.load() assert pk.model assert pk.meta - assert isinstance(pk.model, torch.jit.ScriptModule) # type:ignore[attr-defined] + assert isinstance(pk.model, torch.jit.ScriptModule) torch.testing.assert_close(pk.model.forward(data_x), y_pred) pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig_2")) diff --git a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py index a59f381d..f2f7987e 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py @@ -8,9 +8,13 @@ import shap from absl.testing import absltest from sklearn import datasets, ensemble, linear_model, multioutput +from 
sklearn.pipeline import Pipeline from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager import model_packager +from snowflake.ml.model._packager.model_handlers.sklearn import ( + _unpack_container_runtime_pipeline, +) class SKLearnHandlerTest(absltest.TestCase): @@ -344,6 +348,20 @@ def test_skl_no_default_explain_without_background_data(self) -> None: assert callable(predict_method) self.assertEqual(explain_method, None) + def test_skl_with_cr_estimator(self) -> None: + class SecondMockEstimator: + ... + + class MockEstimator: + @property + def _sklearn_estimator(self) -> SecondMockEstimator: + return SecondMockEstimator() + + skl_pipeline = Pipeline(steps=[("mock", MockEstimator())]) + oss_pipeline = _unpack_container_runtime_pipeline(skl_pipeline) + + assert isinstance(oss_pipeline.steps[0][1], SecondMockEstimator) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py index c0434346..07ae8166 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py @@ -8,7 +8,7 @@ from absl.testing import absltest from sklearn import datasets -from snowflake.ml.model import model_signature +from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._packager import model_packager from snowflake.ml.modeling.linear_model import ( # type:ignore[attr-defined] LinearRegression, @@ -49,6 +49,7 @@ def test_snowml_all_input_no_explain(self) -> None: model=regr, metadata={"author": "halu", "version": "1"}, options={"enable_explainability": False}, + task=model_types.Task.TABULAR_REGRESSION, ) with warnings.catch_warnings(): warnings.simplefilter("error") @@ -67,6 +68,8 @@ def test_snowml_all_input_no_explain(self) -> None: predict_method = getattr(pk.model, "predict", None) assert callable(predict_method) np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) + # correctly set when specified + self.assertEqual(pk.meta.task, model_types.Task.TABULAR_REGRESSION) def test_snowml_signature_partial_input(self) -> None: iris = datasets.load_iris() @@ -171,6 +174,7 @@ def test_snowml_xgboost_explain_default(self) -> None: name="model1", model=regr, metadata={"author": "halu", "version": "1"}, + task=model_types.Task.TABULAR_BINARY_CLASSIFICATION, # incorrect type but should be inferred properly ) with warnings.catch_warnings(): warnings.simplefilter("error") @@ -185,6 +189,8 @@ def test_snowml_xgboost_explain_default(self) -> None: assert callable(explain_method) np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) np.testing.assert_allclose(explanations, explain_method(df[INPUT_COLUMNS]).values) + # correctly set even when incorrect + self.assertEqual(pk.meta.task, model_types.Task.TABULAR_REGRESSION) def test_snowml_all_input_with_explain(self) -> None: iris = datasets.load_iris() diff --git a/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py b/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py index db0702fb..6d3d1f3e 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/torchscript_test.py @@ -52,7 +52,7 @@ def _prepare_torch_model( class TorchScriptHandlerTest(absltest.TestCase): def test_torchscript(self) -> 
None: model, data_x, data_y = _prepare_torch_model() - model_script = torch.jit.script(model) # type:ignore[attr-defined] + model_script = torch.jit.script(model) with tempfile.TemporaryDirectory() as tmpdir: s = {"forward": model_signature.infer_signature([data_x], [data_y])} @@ -94,7 +94,7 @@ def test_torchscript(self) -> None: pk.load() assert pk.model assert pk.meta - assert isinstance(pk.model, torch.jit.ScriptModule) # type:ignore[attr-defined] + assert isinstance(pk.model, torch.jit.ScriptModule) torch.testing.assert_close(pk.model.forward(data_x), y_pred) with self.assertRaisesRegex(RuntimeError, "Attempting to deserialize object on a CUDA device"): @@ -124,7 +124,7 @@ def test_torchscript(self) -> None: pk.load() assert pk.model assert pk.meta - assert isinstance(pk.model, torch.jit.ScriptModule) # type:ignore[attr-defined] + assert isinstance(pk.model, torch.jit.ScriptModule) torch.testing.assert_close(pk.model.forward(data_x), y_pred) self.assertEqual(s["forward"], pk.meta.signatures["forward"]) diff --git a/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py b/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py index fa8560a8..817d4084 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/xgboost_test.py @@ -22,14 +22,14 @@ def test_xgb_booster_explainability_disabled(self) -> None: cal_y = pd.Series(cal_data.target) cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") - regressor = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train)) - y_pred = regressor.predict(xgboost.DMatrix(data=cal_X_test)) + classifier = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train)) + y_pred = classifier.predict(xgboost.DMatrix(data=cal_X_test)) with tempfile.TemporaryDirectory() as tmpdir: s = {"predict": model_signature.infer_signature(cal_X_test, y_pred)} with self.assertRaises(ValueError): model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", - model=regressor, + model=classifier, signatures={**s, "another_predict": s["predict"]}, metadata={"author": "halu", "version": "1"}, options=model_types.XGBModelSaveOptions(enable_explainability=False), @@ -37,7 +37,7 @@ def test_xgb_booster_explainability_disabled(self) -> None: model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", - model=regressor, + model=classifier, signatures=s, metadata={"author": "halu", "version": "1"}, ) @@ -58,10 +58,12 @@ def test_xgb_booster_explainability_disabled(self) -> None: predict_method = getattr(pk.model, "predict", None) assert callable(predict_method) np.testing.assert_allclose(predict_method(cal_X_test), np.expand_dims(y_pred, axis=1)) + # test task is set even without explain + self.assertEqual(pk.meta.task, model_types.Task.TABULAR_BINARY_CLASSIFICATION) model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( name="model1_no_sig", - model=regressor, + model=classifier, sample_input_data=cal_X_test, metadata={"author": "halu", "version": "1"}, options=model_types.XGBModelSaveOptions(enable_explainability=False), @@ -155,6 +157,7 @@ def test_xgb_explainability_disabled(self) -> None: predict_method = getattr(pk.model, "predict_proba", None) assert callable(predict_method) np.testing.assert_allclose(predict_method(cal_X_test), y_pred_proba) + 
self.assertEqual(pk.meta.task, model_types.Task.TABULAR_BINARY_CLASSIFICATION) def test_xgb_explainablity_enabled(self) -> None: cal_data = datasets.load_breast_cancer() @@ -177,6 +180,7 @@ def test_xgb_explainablity_enabled(self) -> None: model=classifier, signatures={"predict": model_signature.infer_signature(cal_X_test, y_pred)}, metadata={"author": "halu", "version": "1"}, + task=model_types.Task.UNKNOWN, ) with warnings.catch_warnings(): @@ -213,6 +217,9 @@ def test_xgb_explainablity_enabled(self) -> None: explain_method = getattr(pk.model, "explain", None) assert callable(explain_method) np.testing.assert_allclose(explain_method(cal_X_test), explanations) + assert pk.meta + # correctly inferred even when unknown + self.assertEqual(pk.meta.task, model_types.Task.TABULAR_BINARY_CLASSIFICATION) def test_xgb_explainablity_multiclass(self) -> None: cal_data = datasets.load_iris() diff --git a/snowflake/ml/model/_packager/model_meta/BUILD.bazel b/snowflake/ml/model/_packager/model_meta/BUILD.bazel index 823c07c4..608583a1 100644 --- a/snowflake/ml/model/_packager/model_meta/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_meta/BUILD.bazel @@ -2,7 +2,7 @@ load("//bazel:py_rules.bzl", "py_genrule", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) -GEN_PACKAGING_REQ_CMD = "$(location //bazel/requirements:parse_and_generate_requirements) $(location //:requirements.yml) --schema $(location //bazel/requirements:requirements.schema.json) --mode version_requirements --format python --filter_by_tag model_packaging > $@" +GEN_PACKAGING_REQ_CMD = "$(location //bazel/requirements:parse_and_generate_requirements) $(location //:requirements.yml) --schema $(location //bazel/requirements:requirements.schema.json) --mode version_requirements --format python --snowflake_channel_only --filter_by_tag model_packaging > $@" py_genrule( name = "gen_packaging_requirements", diff --git a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py index fc534c6e..762cb027 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py @@ -58,11 +58,16 @@ class XgboostModelBlobOptions(BaseModelBlobOptions): xgb_estimator_type: Required[str] +class TensorflowModelBlobOptions(BaseModelBlobOptions): + is_keras_model: Required[bool] + + ModelBlobOptions = Union[ BaseModelBlobOptions, HuggingFacePipelineModelBlobOptions, MLFlowModelBlobOptions, XgboostModelBlobOptions, + TensorflowModelBlobOptions, ] diff --git a/snowflake/ml/model/_packager/model_runtime/BUILD.bazel b/snowflake/ml/model/_packager/model_runtime/BUILD.bazel index a6192a9b..36a5080f 100644 --- a/snowflake/ml/model/_packager/model_runtime/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_runtime/BUILD.bazel @@ -2,7 +2,7 @@ load("//bazel:py_rules.bzl", "py_genrule", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) -GEN_RUNTIME_REQ_CMD = "$(location //bazel/requirements:parse_and_generate_requirements) $(location //:requirements.yml) --schema $(location //bazel/requirements:requirements.schema.json) --mode version_requirements --format python --filter_by_tag snowml_inference_alternative > $@" +GEN_RUNTIME_REQ_CMD = "$(location //bazel/requirements:parse_and_generate_requirements) $(location //:requirements.yml) --schema $(location //bazel/requirements:requirements.schema.json) --mode version_requirements --format python --snowflake_channel_only > $@" 
py_genrule( name = "gen_snowml_inference_alternative_requirements", diff --git a/snowflake/ml/model/_packager/model_runtime/model_runtime.py b/snowflake/ml/model/_packager/model_runtime/model_runtime.py index 98c7cc7e..e9c27446 100644 --- a/snowflake/ml/model/_packager/model_runtime/model_runtime.py +++ b/snowflake/ml/model/_packager/model_runtime/model_runtime.py @@ -17,6 +17,8 @@ for r in _snowml_inference_alternative_requirements.REQUIREMENTS ] +PACKAGES_NOT_ALLOWED_IN_WAREHOUSE = ["snowflake-connector-python", "pyarrow"] + class ModelRuntime: """Class to represent runtime in a model, which controls the runtime and version, imports and dependencies. @@ -61,15 +63,8 @@ def __init__( ], ) - if not is_warehouse and self.embed_local_ml_library: - self.runtime_env.include_if_absent( - [ - model_env.ModelDependency( - requirement="pyarrow", - pip_name="pyarrow", - ) - ], - ) + if is_warehouse and self.embed_local_ml_library: + self.runtime_env.remove_if_present_conda(PACKAGES_NOT_ALLOWED_IN_WAREHOUSE) if is_gpu: self.runtime_env.generate_env_for_cuda() diff --git a/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py b/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py index 2181b571..2fa06525 100644 --- a/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py +++ b/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py @@ -19,6 +19,13 @@ ) ) +_BASIC_DEPENDENCIES_TARGET_RELAXED_WAREHOUSE = list( + filter( + lambda x: not any(dep in x for dep in model_runtime.PACKAGES_NOT_ALLOWED_IN_WAREHOUSE), + _BASIC_DEPENDENCIES_TARGET_RELAXED, + ) +) + class ModelRuntimeTest(absltest.TestCase): def test_model_runtime(self) -> None: @@ -138,7 +145,7 @@ def test_model_runtime_local_snowml(self) -> None: with open(os.path.join(workspace, "runtimes/cpu/env/conda.yml"), encoding="utf-8") as f: dependencies = yaml.safe_load(f) - self.assertContainsSubset(_BASIC_DEPENDENCIES_TARGET_RELAXED + ["pyarrow"], dependencies["dependencies"]) + self.assertContainsSubset(_BASIC_DEPENDENCIES_TARGET_RELAXED, dependencies["dependencies"]) def test_model_runtime_local_snowml_warehouse(self) -> None: with tempfile.TemporaryDirectory() as workspace: @@ -165,8 +172,7 @@ def test_model_runtime_local_snowml_warehouse(self) -> None: with open(os.path.join(workspace, "runtimes/cpu/env/conda.yml"), encoding="utf-8") as f: dependencies = yaml.safe_load(f) - self.assertContainsSubset(_BASIC_DEPENDENCIES_TARGET_RELAXED, dependencies["dependencies"]) - self.assertNotIn("pyarrow", dependencies["dependencies"]) + self.assertContainsSubset(_BASIC_DEPENDENCIES_TARGET_RELAXED_WAREHOUSE, dependencies["dependencies"]) def test_model_runtime_dup_basic_dep(self) -> None: with tempfile.TemporaryDirectory() as workspace: diff --git a/snowflake/ml/model/_packager/model_task/model_task_utils.py b/snowflake/ml/model/_packager/model_task/model_task_utils.py index 851d4fee..3c0f9cea 100644 --- a/snowflake/ml/model/_packager/model_task/model_task_utils.py +++ b/snowflake/ml/model/_packager/model_task/model_task_utils.py @@ -84,7 +84,7 @@ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel if type_utils.LazyType("lightgbm.Booster").isinstance(model): model_task = model.params["objective"] # type: ignore[attr-defined] elif hasattr(model, "objective_"): - model_task = model.objective_ + model_task = model.objective_ # type: ignore[assignment] if model_task in _BINARY_CLASSIFICATION_OBJECTIVES: return type_hints.Task.TABULAR_BINARY_CLASSIFICATION if model_task in 
_MULTI_CLASSIFICATION_OBJECTIVES: diff --git a/snowflake/ml/model/_signatures/core.py b/snowflake/ml/model/_signatures/core.py index 5c4663b4..a03ff68f 100644 --- a/snowflake/ml/model/_signatures/core.py +++ b/snowflake/ml/model/_signatures/core.py @@ -14,10 +14,12 @@ Type, Union, final, + get_args, ) import numpy as np import numpy.typing as npt +import pandas as pd import snowflake.snowpark.types as spt from snowflake.ml._internal.exceptions import ( @@ -29,6 +31,21 @@ import mlflow import torch +PandasExtensionTypes = Union[ + pd.Int8Dtype, + pd.Int16Dtype, + pd.Int32Dtype, + pd.Int64Dtype, + pd.UInt8Dtype, + pd.UInt16Dtype, + pd.UInt32Dtype, + pd.UInt64Dtype, + pd.Float32Dtype, + pd.Float64Dtype, + pd.BooleanDtype, + pd.StringDtype, +] + class DataType(Enum): def __init__(self, value: str, snowpark_type: Type[spt.DataType], numpy_type: npt.DTypeLike) -> None: @@ -67,11 +84,11 @@ def __repr__(self) -> str: return f"DataType.{self.name}" @classmethod - def from_numpy_type(cls, np_type: npt.DTypeLike) -> "DataType": + def from_numpy_type(cls, input_type: Union[npt.DTypeLike, PandasExtensionTypes]) -> "DataType": """Translate numpy dtype to DataType for signature definition. Args: - np_type: The numpy dtype. + input_type: The numpy dtype or Pandas Extension Dtype Raises: SnowflakeMLException: NotImplementedError: Raised when the given numpy type is not supported. @@ -79,6 +96,10 @@ def from_numpy_type(cls, np_type: npt.DTypeLike) -> "DataType": Returns: Corresponding DataType. """ + # To support pandas extension dtype + if isinstance(input_type, get_args(PandasExtensionTypes)): + input_type = input_type.type + np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType} # Add datetime types: @@ -88,12 +109,12 @@ def from_numpy_type(cls, np_type: npt.DTypeLike) -> "DataType": np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ for potential_type in np_to_snowml_type_mapping.keys(): - if np.can_cast(np_type, potential_type, casting="no"): + if np.can_cast(input_type, potential_type, casting="no"): # This is used since the same dtype might represented in different ways. return np_to_snowml_type_mapping[potential_type] raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.NOT_IMPLEMENTED, - original_exception=NotImplementedError(f"Type {np_type} is not supported as a DataType."), + original_exception=NotImplementedError(f"Type {input_type} is not supported as a DataType."), ) @classmethod @@ -212,6 +233,7 @@ def __init__( name: str, dtype: DataType, shape: Optional[Tuple[int, ...]] = None, + nullable: bool = True, ) -> None: """ Initialize a feature. @@ -219,6 +241,7 @@ def __init__( Args: name: Name of the feature. dtype: Type of the elements in the feature. + nullable: Whether the feature is nullable. Defaults to True. shape: Used to represent scalar feature, 1-d feature list, or n-d tensor. Use -1 to represent variable length. Defaults to None. @@ -227,6 +250,7 @@ def __init__( - (2,): 1d list with a fixed length of 2. - (-1,): 1d list with variable length, used for ragged tensor representation. - (d1, d2, d3): 3d tensor. + nullable: Whether the feature is nullable. Defaults to True. Raises: SnowflakeMLException: TypeError: When the dtype input type is incorrect. 
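For context, a minimal sketch (not part of this patch) of how the new nullable flag on FeatureSpec is expected to surface through as_dtype, matching the expectations added to core_test.py later in the patch:

import numpy as np
import pandas as pd

from snowflake.ml.model._signatures import core

# nullable defaults to True, so the local dtype is the pandas nullable
# extension type; with nullable=False the plain numpy dtype is kept.
nullable_ft = core.FeatureSpec(name="feature", dtype=core.DataType.INT64)
strict_ft = core.FeatureSpec(name="feature", dtype=core.DataType.INT64, nullable=False)

assert nullable_ft.as_dtype() == pd.Int64Dtype()
assert strict_ft.as_dtype() == np.int64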
@@ -248,6 +272,8 @@ def __init__( ) self._shape = shape + self._nullable = nullable + def as_snowpark_type(self) -> spt.DataType: result_type = self._dtype.as_snowpark_type() if not self._shape: @@ -256,13 +282,34 @@ def as_snowpark_type(self) -> spt.DataType: result_type = spt.ArrayType(result_type) return result_type - def as_dtype(self) -> Union[npt.DTypeLike, str]: + def as_dtype(self) -> Union[npt.DTypeLike, str, PandasExtensionTypes]: """Convert to corresponding local Type.""" + if not self._shape: # scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64' if "datetime64" in self._dtype._value: return self._dtype._value - return self._dtype._numpy_type + + np_type = self._dtype._numpy_type + if self._nullable: + np_to_pd_dtype_mapping = { + np.int8: pd.Int8Dtype(), + np.int16: pd.Int16Dtype(), + np.int32: pd.Int32Dtype(), + np.int64: pd.Int64Dtype(), + np.uint8: pd.UInt8Dtype(), + np.uint16: pd.UInt16Dtype(), + np.uint32: pd.UInt32Dtype(), + np.uint64: pd.UInt64Dtype(), + np.float32: pd.Float32Dtype(), + np.float64: pd.Float64Dtype(), + np.bool_: pd.BooleanDtype(), + np.str_: pd.StringDtype(), + } + + return np_to_pd_dtype_mapping.get(np_type, np_type) # type: ignore[arg-type] + + return np_type return np.object_ def __eq__(self, other: object) -> bool: @@ -273,7 +320,10 @@ def __eq__(self, other: object) -> bool: def __repr__(self) -> str: shape_str = f", shape={repr(self._shape)}" if self._shape else "" - return f"FeatureSpec(dtype={repr(self._dtype)}, name={repr(self._name)}{shape_str})" + return ( + f"FeatureSpec(dtype={repr(self._dtype)}, " + f"name={repr(self._name)}{shape_str}, nullable={repr(self._nullable)})" + ) def to_dict(self) -> Dict[str, Any]: """Serialize the feature group into a dict. @@ -281,10 +331,7 @@ def to_dict(self) -> Dict[str, Any]: Returns: A dict that serializes the feature group. """ - base_dict: Dict[str, Any] = { - "type": self._dtype.name, - "name": self._name, - } + base_dict: Dict[str, Any] = {"type": self._dtype.name, "name": self._name, "nullable": self._nullable} if self._shape is not None: base_dict["shape"] = self._shape return base_dict @@ -304,7 +351,9 @@ def from_dict(cls, input_dict: Dict[str, Any]) -> "FeatureSpec": if shape: shape = tuple(shape) type = DataType[input_dict["type"]] - return FeatureSpec(name=name, dtype=type, shape=shape) + # If nullable is not provided, default to False for backward compatibility. 
+ nullable = input_dict.get("nullable", False) + return FeatureSpec(name=name, dtype=type, shape=shape, nullable=nullable) @classmethod def from_mlflow_spec( @@ -475,10 +524,8 @@ def from_dict(cls, loaded: Dict[str, Any]) -> "ModelSignature": sig_outs = loaded["outputs"] sig_inputs = loaded["inputs"] - deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = ( - lambda sig_spec: FeatureGroupSpec.from_dict(sig_spec) - if "feature_group" in sig_spec - else FeatureSpec.from_dict(sig_spec) + deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = lambda sig_spec: ( + FeatureGroupSpec.from_dict(sig_spec) if "feature_group" in sig_spec else FeatureSpec.from_dict(sig_spec) ) return ModelSignature( diff --git a/snowflake/ml/model/_signatures/core_test.py b/snowflake/ml/model/_signatures/core_test.py index 897ac23f..93270477 100644 --- a/snowflake/ml/model/_signatures/core_test.py +++ b/snowflake/ml/model/_signatures/core_test.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd from absl.testing import absltest import snowflake.snowpark.types as spt @@ -14,6 +15,12 @@ def test_numpy_type(self) -> None: data = np.array(["a", "b", "c", "d"]) self.assertEqual(core.DataType.STRING, core.DataType.from_numpy_type(data.dtype)) + data = pd.Series([1, 2, 3, 4]).convert_dtypes() + self.assertEqual(core.DataType.INT64, core.DataType.from_numpy_type(data.dtype)) + + data = pd.Series(["a", "b", "c", "d"]).convert_dtypes() + self.assertEqual(core.DataType.STRING, core.DataType.from_numpy_type(data.dtype)) + def test_snowpark_type(self) -> None: self.assertEqual(core.DataType.INT8, core.DataType.from_snowpark_type(spt.ByteType())) self.assertEqual(core.DataType.INT16, core.DataType.from_snowpark_type(spt.ShortType())) @@ -37,11 +44,19 @@ def test_feature_spec(self) -> None: self.assertEqual(ft, eval(repr(ft), core.__dict__)) self.assertEqual(ft, core.FeatureSpec.from_dict(ft.to_dict())) self.assertEqual(ft.as_snowpark_type(), spt.LongType()) + self.assertEqual(ft.as_dtype(), pd.Int64Dtype()) + + ft = core.FeatureSpec(name="feature", dtype=core.DataType.INT64, nullable=False) + self.assertEqual(ft, eval(repr(ft), core.__dict__)) + self.assertEqual(ft, core.FeatureSpec.from_dict(ft.to_dict())) + self.assertEqual(ft.as_snowpark_type(), spt.LongType()) + self.assertEqual(ft.as_dtype(), np.int64) ft = core.FeatureSpec(name="feature", dtype=core.DataType.INT64, shape=(2,)) self.assertEqual(ft, eval(repr(ft), core.__dict__)) self.assertEqual(ft, core.FeatureSpec.from_dict(input_dict=ft.to_dict())) self.assertEqual(ft.as_snowpark_type(), spt.ArrayType(spt.LongType())) + self.assertEqual(ft.as_dtype(), np.object_) class FeatureGroupSpecTest(absltest.TestCase): @@ -102,13 +117,11 @@ def test_1(self) -> None: core.FeatureGroupSpec( name="cg1", specs=[ - core.FeatureSpec( - dtype=core.DataType.FLOAT, - name="cc1", - ), + core.FeatureSpec(dtype=core.DataType.FLOAT, name="cc1", nullable=True), core.FeatureSpec( dtype=core.DataType.FLOAT, name="cc2", + nullable=False, ), ], ), @@ -118,16 +131,19 @@ def test_1(self) -> None: ) target = { "inputs": [ - {"type": "FLOAT", "name": "c1"}, + {"type": "FLOAT", "name": "c1", "nullable": True}, { "feature_group": { "name": "cg1", - "specs": [{"type": "FLOAT", "name": "cc1"}, {"type": "FLOAT", "name": "cc2"}], + "specs": [ + {"type": "FLOAT", "name": "cc1", "nullable": True}, + {"type": "FLOAT", "name": "cc2", "nullable": False}, + ], } }, - {"type": "FLOAT", "name": "c2", "shape": (-1,)}, + {"type": "FLOAT", "name": "c2", "shape": (-1,), "nullable": True}, ], - 
"outputs": [{"type": "FLOAT", "name": "output"}], + "outputs": [{"type": "FLOAT", "name": "output", "nullable": True}], } self.assertDictEqual(s.to_dict(), target) self.assertEqual(s, eval(repr(s), core.__dict__)) @@ -140,13 +156,11 @@ def test_2(self) -> None: core.FeatureGroupSpec( name="cg1", specs=[ - core.FeatureSpec( - dtype=core.DataType.FLOAT, - name="cc1", - ), + core.FeatureSpec(dtype=core.DataType.FLOAT, name="cc1", nullable=True), core.FeatureSpec( dtype=core.DataType.FLOAT, name="cc2", + nullable=False, ), ], ), diff --git a/snowflake/ml/model/_signatures/pandas_handler.py b/snowflake/ml/model/_signatures/pandas_handler.py index 059d4fbc..da761671 100644 --- a/snowflake/ml/model/_signatures/pandas_handler.py +++ b/snowflake/ml/model/_signatures/pandas_handler.py @@ -1,4 +1,5 @@ -from typing import Literal, Sequence +import warnings +from typing import Literal, Sequence, Union import numpy as np import pandas as pd @@ -14,8 +15,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]): @staticmethod - def can_handle(data: model_types.SupportedDataType) -> TypeGuard[pd.DataFrame]: - return isinstance(data, pd.DataFrame) + def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Union[pd.DataFrame, pd.Series]]: + return isinstance(data, pd.DataFrame) or isinstance(data, pd.Series) @staticmethod def count(data: pd.DataFrame) -> int: @@ -26,7 +27,17 @@ def truncate(data: pd.DataFrame) -> pd.DataFrame: return data.head(min(PandasDataFrameHandler.count(data), PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT)) @staticmethod - def validate(data: pd.DataFrame) -> None: + def validate(data: Union[pd.DataFrame, pd.Series]) -> None: + if isinstance(data, pd.Series): + # check if the series is empty and throw error + if data.empty: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_DATA, + original_exception=ValueError("Data Validation Error: Empty data is found."), + ) + # convert the series to a dataframe + data = data.to_frame() + df_cols = data.columns if df_cols.has_duplicates: # Rule out categorical index with duplicates @@ -60,21 +71,44 @@ def validate(data: pd.DataFrame) -> None: df_col_dtypes = [data[col].dtype for col in data.columns] for df_col, df_col_dtype in zip(df_cols, df_col_dtypes): + df_col_data = data[df_col] + if df_col_data.isnull().all(): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_DATA, + original_exception=ValueError( + f"Data Validation Error: There is no non-null data in column {df_col}." + ), + ) + if df_col_data.isnull().any(): + warnings.warn( + ( + f"Null value detected in column {df_col}, model signature inference might not accurate, " + "or your prediction might fail if your model does not support null input. If this is not " + "expected, please check your input dataframe." + ), + category=UserWarning, + stacklevel=2, + ) + + df_col_data = utils.series_dropna(df_col_data) + df_col_dtype = df_col_data.dtype + if df_col_dtype == np.dtype("O"): # Check if all objects have the same type - if not all(isinstance(data_row, type(data[df_col].iloc[0])) for data_row in data[df_col]): + if not all(isinstance(data_row, type(df_col_data.iloc[0])) for data_row in df_col_data): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( - f"Data Validation Error: Inconsistent type of object found in column data {data[df_col]}." 
+ "Data Validation Error: " + + f"Inconsistent type of element in object found in column data {df_col_data}." ), ) - if isinstance(data[df_col].iloc[0], list): - arr = utils.convert_list_to_ndarray(data[df_col].iloc[0]) + if isinstance(df_col_data.iloc[0], list): + arr = utils.convert_list_to_ndarray(df_col_data.iloc[0]) arr_dtype = core.DataType.from_numpy_type(arr.dtype) - converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in data[df_col]] + converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data] if not all( core.DataType.from_numpy_type(converted_data.dtype) == arr_dtype @@ -84,32 +118,37 @@ def validate(data: pd.DataFrame) -> None: error_code=error_codes.INVALID_DATA, original_exception=ValueError( "Data Validation Error: " - + f"Inconsistent type of element in object found in column data {data[df_col]}." + + f"Inconsistent type of element in object found in column data {df_col_data}." ), ) - elif isinstance(data[df_col].iloc[0], np.ndarray): - arr_dtype = core.DataType.from_numpy_type(data[df_col].iloc[0].dtype) + elif isinstance(df_col_data.iloc[0], np.ndarray): + arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype) - if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in data[df_col]): + if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in df_col_data): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( "Data Validation Error: " - + f"Inconsistent type of element in object found in column data {data[df_col]}." + + f"Inconsistent type of element in object found in column data {df_col_data}." ), ) - elif not isinstance(data[df_col].iloc[0], (str, bytes)): + elif not isinstance(df_col_data.iloc[0], (str, bytes)): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( - f"Data Validation Error: Unsupported type confronted in {data[df_col]}" + f"Data Validation Error: Unsupported type confronted in {df_col_data}" ), ) @staticmethod - def infer_signature(data: pd.DataFrame, role: Literal["input", "output"]) -> Sequence[core.BaseFeatureSpec]: + def infer_signature( + data: Union[pd.DataFrame, pd.Series], + role: Literal["input", "output"], + ) -> Sequence[core.BaseFeatureSpec]: feature_prefix = f"{PandasDataFrameHandler.FEATURE_PREFIX}_" + if isinstance(data, pd.Series): + data = data.to_frame() df_cols = data.columns role_prefix = ( PandasDataFrameHandler.INPUT_PREFIX if role == "input" else PandasDataFrameHandler.OUTPUT_PREFIX @@ -123,29 +162,34 @@ def infer_signature(data: pd.DataFrame, role: Literal["input", "output"]) -> Seq specs = [] for df_col, df_col_dtype, ft_name in zip(df_cols, df_col_dtypes, ft_names): + df_col_data = data[df_col] + if df_col_data.isnull().any(): + df_col_data = utils.series_dropna(df_col_data) + df_col_dtype = df_col_data.dtype + if df_col_dtype == np.dtype("O"): - if isinstance(data[df_col].iloc[0], list): - arr = utils.convert_list_to_ndarray(data[df_col].iloc[0]) + if isinstance(df_col_data.iloc[0], list): + arr = utils.convert_list_to_ndarray(df_col_data.iloc[0]) arr_dtype = core.DataType.from_numpy_type(arr.dtype) - ft_shape = np.shape(data[df_col].iloc[0]) + ft_shape = np.shape(df_col_data.iloc[0]) - converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in data[df_col]] + converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data] if not 
all(np.shape(converted_data) == ft_shape for converted_data in converted_data_list): ft_shape = (-1,) specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape)) - elif isinstance(data[df_col].iloc[0], np.ndarray): - arr_dtype = core.DataType.from_numpy_type(data[df_col].iloc[0].dtype) - ft_shape = np.shape(data[df_col].iloc[0]) + elif isinstance(df_col_data.iloc[0], np.ndarray): + arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype) + ft_shape = np.shape(df_col_data.iloc[0]) - if not all(np.shape(data_row) == ft_shape for data_row in data[df_col]): + if not all(np.shape(data_row) == ft_shape for data_row in df_col_data): ft_shape = (-1,) specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape)) - elif isinstance(data[df_col].iloc[0], str): + elif isinstance(df_col_data.iloc[0], str): specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name)) - elif isinstance(data[df_col].iloc[0], bytes): + elif isinstance(df_col_data.iloc[0], bytes): specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name)) elif isinstance(df_col_dtype, pd.CategoricalDtype): category_dtype = df_col_dtype.categories.dtype diff --git a/snowflake/ml/model/_signatures/pandas_test.py b/snowflake/ml/model/_signatures/pandas_test.py index 3fe5b240..4ab55e2c 100644 --- a/snowflake/ml/model/_signatures/pandas_test.py +++ b/snowflake/ml/model/_signatures/pandas_test.py @@ -44,13 +44,17 @@ def test_validate_pd_DataFrame(self) -> None: df = pd.DataFrame([[1, "Hello"], [2, [2, 6]]], columns=["a", "b"]) with exception_utils.assert_snowml_exceptions( - self, expected_original_error_type=ValueError, expected_regex="Inconsistent type of object" + self, + expected_original_error_type=ValueError, + expected_regex="Inconsistent type of element in object found in column data", ): pandas_handler.PandasDataFrameHandler.validate(df) df = pd.DataFrame([[1, 2], [2, [2, 6]]], columns=["a", "b"]) with exception_utils.assert_snowml_exceptions( - self, expected_original_error_type=ValueError, expected_regex="Inconsistent type of object" + self, + expected_original_error_type=ValueError, + expected_regex="Inconsistent type of element in object found in column data", ): pandas_handler.PandasDataFrameHandler.validate(df) @@ -86,10 +90,34 @@ def test_validate_pd_DataFrame(self) -> None: with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, - expected_regex="Inconsistent type of object found in column data", + expected_regex="Inconsistent type of element in object found in column data", + ): + pandas_handler.PandasDataFrameHandler.validate(df) + + df = pd.DataFrame([[None, 2], [None, 6]], columns=["a", "b"]) + with exception_utils.assert_snowml_exceptions( + self, + expected_original_error_type=ValueError, + expected_regex="There is no non-null data in column", ): pandas_handler.PandasDataFrameHandler.validate(df) + df = pd.DataFrame([[1, None], [2, 6]], columns=["a", "b"]) + with self.assertWarnsRegex(UserWarning, "Null value detected in column"): + pandas_handler.PandasDataFrameHandler.validate(df) + + df = pd.DataFrame([[1, np.array([2.5, 6.8])], [2, np.array([2.5, 6.8])], [3, None]], columns=["a", "b"]) + with self.assertWarnsRegex(UserWarning, "Null value detected in column"): + pandas_handler.PandasDataFrameHandler.validate(df) + + df = pd.DataFrame([[1, None], [2, [6]]], columns=["a", "b"]) + with self.assertWarnsRegex(UserWarning, "Null value detected in column"): + pandas_handler.PandasDataFrameHandler.validate(df) + + df = 
pd.DataFrame([[1, None], [2, "a"]], columns=["a", "b"]) + with self.assertWarnsRegex(UserWarning, "Null value detected in column"): + pandas_handler.PandasDataFrameHandler.validate(df) + def test_trunc_pd_DataFrame(self) -> None: df = pd.DataFrame([1] * (pandas_handler.PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT + 1)) @@ -118,18 +146,36 @@ def test_infer_signature_pd_DataFrame(self) -> None: [core.FeatureSpec("a", core.DataType.INT64)], ) + df = pd.DataFrame([1, 2, 3, None], columns=["a"]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [core.FeatureSpec("a", core.DataType.INT64)], + ) + df = pd.DataFrame(["a", "b", "c", "d"], columns=["a"]) self.assertListEqual( pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), [core.FeatureSpec("a", core.DataType.STRING)], ) + df = pd.DataFrame(["a", "b", None, "d"], columns=["a"]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [core.FeatureSpec("a", core.DataType.STRING)], + ) + df = pd.DataFrame([ele.encode() for ele in ["a", "b", "c", "d"]], columns=["a"]) self.assertListEqual( pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), [core.FeatureSpec("a", core.DataType.BYTES)], ) + df = pd.DataFrame([ele.encode() for ele in ["a", "b", "c", "d"]] + [None], columns=["a"]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [core.FeatureSpec("a", core.DataType.BYTES)], + ) + df = pd.DataFrame([[1, 2.0], [2, 4.0]]) self.assertListEqual( pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), @@ -139,6 +185,24 @@ def test_infer_signature_pd_DataFrame(self) -> None: ], ) + df = pd.DataFrame([[1, 2.0], [2, None]]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [ + core.FeatureSpec("input_feature_0", core.DataType.INT64), + core.FeatureSpec("input_feature_1", core.DataType.INT64), + ], + ) + + df = pd.DataFrame([[1, 2.4], [2, None]]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [ + core.FeatureSpec("input_feature_0", core.DataType.INT64), + core.FeatureSpec("input_feature_1", core.DataType.DOUBLE), + ], + ) + df = pd.DataFrame([[1, [2.5, 6.8]], [2, [2.5, 6.8]]], columns=["a", "b"]) self.assertListEqual( pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), @@ -148,6 +212,15 @@ def test_infer_signature_pd_DataFrame(self) -> None: ], ) + df = pd.DataFrame([[1, [2.5, 6.8]], [2, None]], columns=["a", "b"]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [ + core.FeatureSpec("a", core.DataType.INT64), + core.FeatureSpec("b", core.DataType.DOUBLE, shape=(2,)), + ], + ) + df = pd.DataFrame([[1, [2.5, 6.8]], [2, [2.5]]], columns=["a", "b"]) self.assertListEqual( pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), @@ -166,6 +239,15 @@ def test_infer_signature_pd_DataFrame(self) -> None: ], ) + df = pd.DataFrame([[1, [[2.5], [6.8]]], [2, None]], columns=["a", "b"]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [ + core.FeatureSpec("a", core.DataType.INT64), + core.FeatureSpec("b", core.DataType.DOUBLE, shape=(2, 1)), + ], + ) + a = np.array([2.5, 6.8]) df = pd.DataFrame([[1, a], [2, a]], columns=["a", "b"]) self.assertListEqual( @@ -176,6 +258,16 @@ def test_infer_signature_pd_DataFrame(self) -> None: 
], ) + a = np.array([2.5, 6.8]) + df = pd.DataFrame([[1, a], [2, None]], columns=["a", "b"]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), + [ + core.FeatureSpec("a", core.DataType.INT64), + core.FeatureSpec("b", core.DataType.DOUBLE, shape=(2,)), + ], + ) + df = pd.DataFrame([[1, np.array([2.5, 6.8])], [2, np.array([2.5])]], columns=["a", "b"]) self.assertListEqual( pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input"), @@ -314,6 +406,86 @@ def test_infer_signature_pd_DataFrame_with_random_row_labels(self) -> None: df["input"] = df["input"].astype(np.dtype("O")) pandas_handler.PandasDataFrameHandler.validate(df) + def test_validate_pd_Series(self) -> None: + s = pd.Series([], dtype=pd.Int16Dtype()) + with exception_utils.assert_snowml_exceptions( + self, expected_original_error_type=ValueError, expected_regex="Empty data is found." + ): + pandas_handler.PandasDataFrameHandler.validate(s) + + s = pd.Series([1, 2, 3, 4]) + pandas_handler.PandasDataFrameHandler.validate(s) + + s = pd.Series([1, 2, 3, 4], name="a") + pandas_handler.PandasDataFrameHandler.validate(s) + + s = pd.Series(["a", "b", "c", "d"], name="a") + pandas_handler.PandasDataFrameHandler.validate(s) + + s = pd.Series( + [ele.encode() for ele in ["a", "b", "c", "d"]], + name="a", + ) + pandas_handler.PandasDataFrameHandler.validate(s) + + s = pd.Series([1, 2.0]) + pandas_handler.PandasDataFrameHandler.validate(s) + + s = pd.Series([1, [2.5, 6.8]], name="a") + with exception_utils.assert_snowml_exceptions( + self, + expected_original_error_type=ValueError, + expected_regex="Inconsistent type of element in object found in column data", + ): + pandas_handler.PandasDataFrameHandler.validate(s) + + a = np.array([2.5, 6.8]) + s = pd.Series([1, a], name="a") + with exception_utils.assert_snowml_exceptions( + self, + expected_original_error_type=ValueError, + expected_regex="Inconsistent type of element in object found in column data", + ): + pandas_handler.PandasDataFrameHandler.validate(s) + + def test_infer_signature_pd_Series(self) -> None: + s = pd.Series([1, 2, 3, 4]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(s, role="input"), + [core.FeatureSpec("input_feature_0", core.DataType.INT64)], + ) + + s = pd.Series([1, 2, 3, 4], name="a") + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(s, role="input"), + [core.FeatureSpec("a", core.DataType.INT64)], + ) + + s = pd.Series(["a", "b", "c", "d"], name="a") + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(s, role="input"), + [core.FeatureSpec("a", core.DataType.STRING)], + ) + + s = pd.Series([ele.encode() for ele in ["a", "b", "c", "d"]], name="a") + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(s, role="input"), + [core.FeatureSpec("a", core.DataType.BYTES)], + ) + + s = pd.Series([1, 2.0]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(s, role="input"), + [core.FeatureSpec("input_feature_0", core.DataType.DOUBLE)], + ) + + # series with bytes data + s = pd.Series([b"1", b"2", b"3", b"4"]) + self.assertListEqual( + pandas_handler.PandasDataFrameHandler.infer_signature(s, role="input"), + [core.FeatureSpec("input_feature_0", core.DataType.BYTES)], + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_signatures/pytorch_handler.py b/snowflake/ml/model/_signatures/pytorch_handler.py index 30b7d9ce..a55cc216 100644 --- 
a/snowflake/ml/model/_signatures/pytorch_handler.py +++ b/snowflake/ml/model/_signatures/pytorch_handler.py @@ -72,10 +72,10 @@ def infer_signature( dtype = core.DataType.from_torch_type(data_col.dtype) ft_name = f"{role_prefix}{feature_prefix}{i}" if len(data_col.shape) == 1: - features.append(core.FeatureSpec(dtype=dtype, name=ft_name)) + features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False)) else: ft_shape = tuple(data_col.shape[1:]) - features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape)) + features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False)) return features @staticmethod diff --git a/snowflake/ml/model/_signatures/pytorch_test.py b/snowflake/ml/model/_signatures/pytorch_test.py index e875b477..2a5ec117 100644 --- a/snowflake/ml/model/_signatures/pytorch_test.py +++ b/snowflake/ml/model/_signatures/pytorch_test.py @@ -90,77 +90,77 @@ def test_infer_schema_torch_tensor(self) -> None: t1 = [torch.IntTensor([1, 2, 3, 4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t1, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT32)], + [core.FeatureSpec("input_feature_0", core.DataType.INT32, nullable=False)], ) t2 = [torch.LongTensor([1, 2, 3, 4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t2, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT64)], + [core.FeatureSpec("input_feature_0", core.DataType.INT64, nullable=False)], ) t3 = [torch.ShortTensor([1, 2, 3, 4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t3, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT16)], + [core.FeatureSpec("input_feature_0", core.DataType.INT16, nullable=False)], ) t4 = [torch.CharTensor([1, 2, 3, 4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t4, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT8)], + [core.FeatureSpec("input_feature_0", core.DataType.INT8, nullable=False)], ) t5 = [torch.ByteTensor([1, 2, 3, 4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t5, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT8)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT8, nullable=False)], ) t6 = [torch.BoolTensor([False, True])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t6, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.BOOL)], + [core.FeatureSpec("input_feature_0", core.DataType.BOOL, nullable=False)], ) t7 = [torch.FloatTensor([1.2, 3.4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t7, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.FLOAT)], + [core.FeatureSpec("input_feature_0", core.DataType.FLOAT, nullable=False)], ) t8 = [torch.DoubleTensor([1.2, 3.4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t8, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.DOUBLE)], + [core.FeatureSpec("input_feature_0", core.DataType.DOUBLE, nullable=False)], ) t9 = [torch.LongTensor([[1, 2], [3, 4]])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t9, role="input"), [ - core.FeatureSpec("input_feature_0", core.DataType.INT64, shape=(2,)), + core.FeatureSpec("input_feature_0", core.DataType.INT64, shape=(2,), nullable=False), ], ) t10 = [torch.LongTensor([[[1, 
1], [2, 2]], [[3, 3], [4, 4]]])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t10, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT64, shape=(2, 2))], + [core.FeatureSpec("input_feature_0", core.DataType.INT64, shape=(2, 2), nullable=False)], ) t11 = [torch.LongTensor([1, 2, 3, 4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t11, role="output"), - [core.FeatureSpec("output_feature_0", core.DataType.INT64)], + [core.FeatureSpec("output_feature_0", core.DataType.INT64, nullable=False)], ) t12 = [torch.LongTensor([1, 2]), torch.LongTensor([3, 4])] self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t12, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT64), - core.FeatureSpec("output_feature_1", core.DataType.INT64), + core.FeatureSpec("output_feature_0", core.DataType.INT64, nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT64, nullable=False), ], ) @@ -168,8 +168,8 @@ def test_infer_schema_torch_tensor(self) -> None: self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t13, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.FLOAT), - core.FeatureSpec("output_feature_1", core.DataType.INT64), + core.FeatureSpec("output_feature_0", core.DataType.FLOAT, nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT64, nullable=False), ], ) @@ -177,8 +177,8 @@ def test_infer_schema_torch_tensor(self) -> None: self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t14, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT64, shape=(2,)), - core.FeatureSpec("output_feature_1", core.DataType.INT64, shape=(2,)), + core.FeatureSpec("output_feature_0", core.DataType.INT64, shape=(2,), nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT64, shape=(2,), nullable=False), ], ) @@ -186,8 +186,8 @@ def test_infer_schema_torch_tensor(self) -> None: self.assertListEqual( pytorch_handler.SeqOfPyTorchTensorHandler.infer_signature(t15, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT64, shape=(2,)), - core.FeatureSpec("output_feature_1", core.DataType.DOUBLE, shape=(2,)), + core.FeatureSpec("output_feature_0", core.DataType.INT64, shape=(2,), nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.DOUBLE, shape=(2,), nullable=False), ], ) diff --git a/snowflake/ml/model/_signatures/snowpark_handler.py b/snowflake/ml/model/_signatures/snowpark_handler.py index 4d85e6f4..f59c4c28 100644 --- a/snowflake/ml/model/_signatures/snowpark_handler.py +++ b/snowflake/ml/model/_signatures/snowpark_handler.py @@ -82,7 +82,8 @@ def convert_to_df( identifier.get_unescaped_names(field.name) ].map(json.loads) # Only when the feature is not from inference, we are confident to do the type casting. - # Otherwise, dtype_map will be empty + # Otherwise, dtype_map will be empty. 
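# --- Editor's note: illustrative sketch, not part of this patch -----------------
# The comments around this cast describe how a column that still contains nulls
# should survive the dtype conversion. A minimal standalone example of that
# behavior, assuming the cast is made with errors="ignore" as the comments imply
# (the column name and dtype map below are made up):

import pandas as pd

df_example = pd.DataFrame({"A": ["1", None, "3"]})  # object column holding a null
dtype_map_example = {"A": "int64"}

# A plain .astype(dtype=dtype_map_example) would raise, because None cannot be
# converted to int64. With errors="ignore", pandas returns the original data
# unchanged instead of raising, so the null row is preserved.
df_example = df_example.astype(dtype=dtype_map_example, errors="ignore")
print(df_example.dtypes)  # "A" stays object; no exception is raised
# --------------------------------------------------------------------------------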
+ # Errors are ignored to make sure None won't be converted and won't raise Error df_local = df_local.astype(dtype=dtype_map) return df_local diff --git a/snowflake/ml/model/_signatures/tensorflow_handler.py b/snowflake/ml/model/_signatures/tensorflow_handler.py index a78c16e2..fd10da3f 100644 --- a/snowflake/ml/model/_signatures/tensorflow_handler.py +++ b/snowflake/ml/model/_signatures/tensorflow_handler.py @@ -109,10 +109,10 @@ def infer_signature( dtype = core.DataType.from_numpy_type(data_col.dtype.as_numpy_dtype) ft_name = f"{role_prefix}{feature_prefix}{i}" if len(data_col.shape) == 1: - features.append(core.FeatureSpec(dtype=dtype, name=ft_name)) + features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False)) else: ft_shape = tuple(data_col.shape[1:]) - features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape)) + features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False)) return features @staticmethod diff --git a/snowflake/ml/model/_signatures/tensorflow_test.py b/snowflake/ml/model/_signatures/tensorflow_test.py index f990845d..e9d71376 100644 --- a/snowflake/ml/model/_signatures/tensorflow_test.py +++ b/snowflake/ml/model/_signatures/tensorflow_test.py @@ -157,95 +157,95 @@ def test_infer_schema_tf_tensor(self) -> None: t1 = [tf.constant([1, 2, 3, 4], dtype=tf.int32)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t1, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT32)], + [core.FeatureSpec("input_feature_0", core.DataType.INT32, nullable=False)], ) t2 = [tf.constant([1, 2, 3, 4], dtype=tf.int64)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t2, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT64)], + [core.FeatureSpec("input_feature_0", core.DataType.INT64, nullable=False)], ) t3 = [tf.constant([1, 2, 3, 4], dtype=tf.int16)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t3, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT16)], + [core.FeatureSpec("input_feature_0", core.DataType.INT16, nullable=False)], ) t4 = [tf.constant([1, 2, 3, 4], dtype=tf.int8)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t4, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT8)], + [core.FeatureSpec("input_feature_0", core.DataType.INT8, nullable=False)], ) t5 = [tf.constant([1, 2, 3, 4], dtype=tf.uint32)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t5, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT32)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT32, nullable=False)], ) t6 = [tf.constant([1, 2, 3, 4], dtype=tf.uint64)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t6, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT64)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT64, nullable=False)], ) t7 = [tf.constant([1, 2, 3, 4], dtype=tf.uint16)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t7, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT16)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT16, nullable=False)], ) t8 = [tf.constant([1, 2, 3, 4], dtype=tf.uint8)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t8, role="input"), - 
[core.FeatureSpec("input_feature_0", core.DataType.UINT8)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT8, nullable=False)], ) t9 = [tf.constant([False, True])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t9, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.BOOL)], + [core.FeatureSpec("input_feature_0", core.DataType.BOOL, nullable=False)], ) t10 = [tf.constant([1.2, 3.4], dtype=tf.float32)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t10, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.FLOAT)], + [core.FeatureSpec("input_feature_0", core.DataType.FLOAT, nullable=False)], ) t11 = [tf.constant([1.2, 3.4], dtype=tf.float64)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t11, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.DOUBLE)], + [core.FeatureSpec("input_feature_0", core.DataType.DOUBLE, nullable=False)], ) t12 = [tf.constant([[1, 2], [3, 4]])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t12, role="input"), [ - core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2,)), + core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2,), nullable=False), ], ) t13 = [tf.constant([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t13, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2, 2))], + [core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2, 2), nullable=False)], ) t14 = [tf.constant([1, 2, 3, 4])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t14, role="output"), - [core.FeatureSpec("output_feature_0", core.DataType.INT32)], + [core.FeatureSpec("output_feature_0", core.DataType.INT32, nullable=False)], ) t15 = [tf.constant([1, 2]), tf.constant([3, 4])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t15, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT32), - core.FeatureSpec("output_feature_1", core.DataType.INT32), + core.FeatureSpec("output_feature_0", core.DataType.INT32, nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT32, nullable=False), ], ) @@ -253,8 +253,8 @@ def test_infer_schema_tf_tensor(self) -> None: self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t16, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.FLOAT), - core.FeatureSpec("output_feature_1", core.DataType.INT32), + core.FeatureSpec("output_feature_0", core.DataType.FLOAT, nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT32, nullable=False), ], ) @@ -262,8 +262,8 @@ def test_infer_schema_tf_tensor(self) -> None: self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t17, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,)), - core.FeatureSpec("output_feature_1", core.DataType.INT32, shape=(2,)), + core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,), nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT32, shape=(2,), nullable=False), ], ) @@ -271,103 +271,103 @@ def test_infer_schema_tf_tensor(self) -> None: self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t18, role="output"), [ - 
core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,)), - core.FeatureSpec("output_feature_1", core.DataType.FLOAT, shape=(2,)), + core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,), nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.FLOAT, shape=(2,), nullable=False), ], ) t21 = [tf.constant([1, 2, 3, 4], dtype=tf.int32)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t21, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT32)], + [core.FeatureSpec("input_feature_0", core.DataType.INT32, nullable=False)], ) t22 = [tf.constant([1, 2, 3, 4], dtype=tf.int64)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t22, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT64)], + [core.FeatureSpec("input_feature_0", core.DataType.INT64, nullable=False)], ) t23 = [tf.constant([1, 2, 3, 4], dtype=tf.int16)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t23, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT16)], + [core.FeatureSpec("input_feature_0", core.DataType.INT16, nullable=False)], ) t24 = [tf.constant([1, 2, 3, 4], dtype=tf.int8)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t24, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT8)], + [core.FeatureSpec("input_feature_0", core.DataType.INT8, nullable=False)], ) t25 = [tf.constant([1, 2, 3, 4], dtype=tf.uint32)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t25, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT32)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT32, nullable=False)], ) t26 = [tf.constant([1, 2, 3, 4], dtype=tf.uint64)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t26, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT64)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT64, nullable=False)], ) t27 = [tf.constant([1, 2, 3, 4], dtype=tf.uint16)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t27, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT16)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT16, nullable=False)], ) t28 = [tf.constant([1, 2, 3, 4], dtype=tf.uint8)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t28, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.UINT8)], + [core.FeatureSpec("input_feature_0", core.DataType.UINT8, nullable=False)], ) t29 = [tf.constant([False, True])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t29, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.BOOL)], + [core.FeatureSpec("input_feature_0", core.DataType.BOOL, nullable=False)], ) t30 = [tf.constant([1.2, 3.4], dtype=tf.float32)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t30, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.FLOAT)], + [core.FeatureSpec("input_feature_0", core.DataType.FLOAT, nullable=False)], ) t31 = [tf.constant([1.2, 3.4], dtype=tf.float64)] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t31, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.DOUBLE)], + [core.FeatureSpec("input_feature_0", 
core.DataType.DOUBLE, nullable=False)], ) t32 = [tf.constant([[1, 2], [3, 4]])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t32, role="input"), [ - core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2,)), + core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2,), nullable=False), ], ) t33 = [tf.constant([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t33, role="input"), - [core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2, 2))], + [core.FeatureSpec("input_feature_0", core.DataType.INT32, shape=(2, 2), nullable=False)], ) t34 = [tf.constant([1, 2, 3, 4])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t34, role="output"), - [core.FeatureSpec("output_feature_0", core.DataType.INT32)], + [core.FeatureSpec("output_feature_0", core.DataType.INT32, nullable=False)], ) t35 = [tf.constant([1, 2]), tf.constant([3, 4])] self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t35, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT32), - core.FeatureSpec("output_feature_1", core.DataType.INT32), + core.FeatureSpec("output_feature_0", core.DataType.INT32, nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT32, nullable=False), ], ) @@ -375,8 +375,8 @@ def test_infer_schema_tf_tensor(self) -> None: self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t36, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.FLOAT), - core.FeatureSpec("output_feature_1", core.DataType.INT32), + core.FeatureSpec("output_feature_0", core.DataType.FLOAT, nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT32, nullable=False), ], ) @@ -384,8 +384,8 @@ def test_infer_schema_tf_tensor(self) -> None: self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t37, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,)), - core.FeatureSpec("output_feature_1", core.DataType.INT32, shape=(2,)), + core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,), nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.INT32, shape=(2,), nullable=False), ], ) @@ -393,8 +393,8 @@ def test_infer_schema_tf_tensor(self) -> None: self.assertListEqual( tensorflow_handler.SeqOfTensorflowTensorHandler.infer_signature(t38, role="output"), [ - core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,)), - core.FeatureSpec("output_feature_1", core.DataType.FLOAT, shape=(2,)), + core.FeatureSpec("output_feature_0", core.DataType.INT32, shape=(2,), nullable=False), + core.FeatureSpec("output_feature_1", core.DataType.FLOAT, shape=(2,), nullable=False), ], ) diff --git a/snowflake/ml/model/_signatures/utils.py b/snowflake/ml/model/_signatures/utils.py index 354df78f..8ff11308 100644 --- a/snowflake/ml/model/_signatures/utils.py +++ b/snowflake/ml/model/_signatures/utils.py @@ -297,3 +297,7 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any]) ) return None + + +def series_dropna(series: pd.Series) -> pd.Series: + return series.dropna(inplace=False).reset_index(drop=True).convert_dtypes() diff --git a/snowflake/ml/model/model_signature.py b/snowflake/ml/model/model_signature.py index 93c21738..627a1e57 100644 --- a/snowflake/ml/model/model_signature.py +++ 
b/snowflake/ml/model/model_signature.py @@ -139,9 +139,32 @@ def _rename_signature_with_snowflake_identifiers( return signature -def _validate_numpy_array( - arr: model_types._SupportedNumpyArray, feature_type: core.DataType, strict: bool = False +def _validate_array_or_series_type( + arr: Union[model_types._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False ) -> bool: + original_dtype = arr.dtype + dtype = arr.dtype + if isinstance( + dtype, + ( + pd.Int8Dtype, + pd.Int16Dtype, + pd.Int32Dtype, + pd.Int64Dtype, + pd.UInt8Dtype, + pd.UInt16Dtype, + pd.UInt32Dtype, + pd.UInt64Dtype, + pd.Float32Dtype, + pd.Float64Dtype, + pd.BooleanDtype, + ), + ): + dtype = dtype.type + elif isinstance(dtype, pd.CategoricalDtype): + dtype = dtype.categories.dtype + elif isinstance(dtype, pd.StringDtype): + dtype = np.str_ if feature_type in [ core.DataType.INT8, core.DataType.INT16, @@ -152,14 +175,17 @@ def _validate_numpy_array( core.DataType.UINT32, core.DataType.UINT64, ]: - if not (np.issubdtype(arr.dtype, np.integer)): + if not (np.issubdtype(dtype, np.integer)): return False if not strict: return True - min_v, max_v = arr.min(), arr.max() + if isinstance(original_dtype, pd.CategoricalDtype): + min_v, max_v = arr.cat.as_ordered().min(), arr.cat.as_ordered().max() # type: ignore[union-attr] + else: + min_v, max_v = arr.min(), arr.max() return bool(max_v <= np.iinfo(feature_type._numpy_type).max and min_v >= np.iinfo(feature_type._numpy_type).min) elif feature_type in [core.DataType.FLOAT, core.DataType.DOUBLE]: - if not (np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating)): + if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating)): return False if not strict: return True @@ -171,7 +197,7 @@ def _validate_numpy_array( elif feature_type in [core.DataType.TIMESTAMP_NTZ]: return np.issubdtype(arr.dtype, np.datetime64) else: - return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no") + return np.can_cast(dtype, feature_type._numpy_type, casting="no") def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec], strict: bool = False) -> None: @@ -204,7 +230,10 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS original_exception=ValueError(f"Data Validation Error: feature {ft_name} does not exist in data."), ) + if data_col.isnull().any(): + data_col = utils.series_dropna(data_col) df_col_dtype = data_col.dtype + if isinstance(feature, core.FeatureGroupSpec): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.NOT_IMPLEMENTED, @@ -217,7 +246,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS if isinstance(df_col_dtype, pd.CategoricalDtype): df_col_dtype = df_col_dtype.categories.dtype if df_col_dtype != np.dtype("O"): - if not _validate_numpy_array(data_col.to_numpy(), ft_type, strict=strict): + if not _validate_array_or_series_type(data_col, ft_type, strict=strict): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( @@ -247,7 +276,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in data_col] if not all( - _validate_numpy_array(converted_data, ft_type, strict=strict) + _validate_array_or_series_type(converted_data, ft_type, strict=strict) for converted_data in converted_data_list ): raise snowml_exceptions.SnowflakeMLException( @@
-278,7 +307,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS ), ) - if not all(_validate_numpy_array(data_row, ft_type, strict=strict) for data_row in data_col): + if not all(_validate_array_or_series_type(data_row, ft_type, strict=strict) for data_row in data_col): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( diff --git a/snowflake/ml/model/model_signature_test.py b/snowflake/ml/model/model_signature_test.py index e2762f72..ebdecd7c 100644 --- a/snowflake/ml/model/model_signature_test.py +++ b/snowflake/ml/model/model_signature_test.py @@ -18,6 +18,12 @@ def test_infer_signature(self) -> None: [model_signature.FeatureSpec("input_feature_0", model_signature.DataType.INT64)], ) + df = pd.DataFrame([1, 2, None, 4]) + self.assertListEqual( + model_signature._infer_signature(df, role="input"), + [model_signature.FeatureSpec("input_feature_0", model_signature.DataType.INT64)], + ) + arr = np.array([1, 2, 3, 4]) self.assertListEqual( model_signature._infer_signature(arr, role="input"), @@ -120,6 +126,26 @@ def test_infer_signature(self) -> None: ], ) + # categorical column + df = pd.DataFrame({"column_0": ["a", "b", "c", "d"], "column_1": [1, 2, 3, 4]}) + df["column_0"] = df["column_0"].astype("category") + df["column_1"] = df["column_1"].astype("category") + + self.assertListEqual( + model_signature._infer_signature(df, role="input"), + [ + model_signature.FeatureSpec("column_0", model_signature.DataType.STRING), + model_signature.FeatureSpec("column_1", model_signature.DataType.INT64), + ], + ) + + series = pd.Series(["a", "b", "c", "d"], name="column_0") + series = series.astype("category") + self.assertListEqual( + model_signature._infer_signature(series, role="input"), + [model_signature.FeatureSpec("column_0", model_signature.DataType.STRING)], + ) + df = pd.DataFrame([1, 2, 3, 4]) lt = [df, arr] with exception_utils.assert_snowml_exceptions( @@ -187,6 +213,10 @@ def test_validate_pandas_df(self) -> None: model_signature._validate_pandas_df(pd.DataFrame([[2, 5], [6, 8]], columns=["a", "b"]), fts) + model_signature._validate_pandas_df(pd.DataFrame([[2, None], [6, 8]], columns=["a", "b"]), fts) + + model_signature._validate_pandas_df(pd.DataFrame([[2, None], [6, 8.0]], columns=["a", "b"]), fts) + with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, @@ -258,6 +288,8 @@ def test_validate_pandas_df(self) -> None: model_signature._validate_pandas_df(pd.DataFrame([[1, [2.5, 6.8]], [2, [2.5, 6.8]]], columns=["a", "b"]), fts) + model_signature._validate_pandas_df(pd.DataFrame([[1, [2.5, 6.8]], [2, None]], columns=["a", "b"]), fts) + with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, @@ -352,6 +384,10 @@ def test_validate_pandas_df(self) -> None: pd.DataFrame([[1, np.array([2.5, 6.8])], [2, np.array([2.5, 6.8])]], columns=["a", "b"]), fts ) + model_signature._validate_pandas_df( + pd.DataFrame([[1, np.array([2.5, 6.8])], [2, None]], columns=["a", "b"]), fts + ) + model_signature._validate_pandas_df( pd.DataFrame([[1, np.array([2.5, 6.8, 6.8])], [2, np.array([2.5, 6.8, 6.8])]], columns=["a", "b"]), fts ) @@ -373,6 +409,8 @@ def test_validate_pandas_df(self) -> None: pd.DataFrame([[1, [[2.5], [6.8]]], [2, [[2.5], [6.8]]]], columns=["a", "b"]), fts ) + model_signature._validate_pandas_df(pd.DataFrame([[1, [[2.5], [6.8]]], [2, None]], columns=["a", "b"]), fts) + with exception_utils.assert_snowml_exceptions( self, 
expected_original_error_type=ValueError, @@ -424,6 +462,8 @@ def test_validate_pandas_df(self) -> None: model_signature._validate_pandas_df(pd.DataFrame(["a", "b", "c", "d"], columns=["a"]), fts) model_signature._validate_pandas_df(pd.DataFrame(["a", "b", "c", "d"], columns=["a"], index=[2, 5, 6, 8]), fts) + model_signature._validate_pandas_df(pd.DataFrame(["a", "b", None, "d"], columns=["a"]), fts) + with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, @@ -455,6 +495,9 @@ def test_validate_pandas_df(self) -> None: model_signature._validate_pandas_df( pd.DataFrame([ele.encode() for ele in ["a", "b", "c", "d"]], columns=["a"], index=[2, 5, 6, 8]), fts ) + model_signature._validate_pandas_df( + pd.DataFrame([ele.encode() for ele in ["a", "b", "c", "d"]] + [None], columns=["a"]), fts + ) with exception_utils.assert_snowml_exceptions( self, diff --git a/snowflake/ml/model/type_hints.py b/snowflake/ml/model/type_hints.py index 0a1b980b..8fb7b0e1 100644 --- a/snowflake/ml/model/type_hints.py +++ b/snowflake/ml/model/type_hints.py @@ -66,7 +66,7 @@ "xgboost.XGBModel", "xgboost.Booster", "torch.nn.Module", - "torch.jit.ScriptModule", # type:ignore[name-defined] + "torch.jit.ScriptModule", "tensorflow.Module", ] diff --git a/snowflake/ml/monitoring/_client/BUILD.bazel b/snowflake/ml/monitoring/_client/BUILD.bazel index f60c4411..20844242 100644 --- a/snowflake/ml/monitoring/_client/BUILD.bazel +++ b/snowflake/ml/monitoring/_client/BUILD.bazel @@ -41,3 +41,14 @@ py_test( "//snowflake/ml/test_utils:mock_session", ], ) + +py_test( + name = "model_monitor_sql_client_server_test", + srcs = [ + "model_monitor_sql_client_server_test.py", + ], + deps = [ + ":model_monitor_sql_client", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/monitoring/_client/model_monitor_sql_client.py b/snowflake/ml/monitoring/_client/model_monitor_sql_client.py index b24a2d5b..29d3faf8 100644 --- a/snowflake/ml/monitoring/_client/model_monitor_sql_client.py +++ b/snowflake/ml/monitoring/_client/model_monitor_sql_client.py @@ -1,36 +1,23 @@ -import json -import string -import textwrap import typing from collections import Counter -from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, TypedDict - -from importlib_resources import files -from typing_extensions import Required +from typing import Any, Dict, List, Mapping, Optional, Set from snowflake import snowpark -from snowflake.connector import errors from snowflake.ml._internal.utils import ( db_utils, - formatting, query_result_checker, sql_identifier, table_manager, ) -from snowflake.ml.model import type_hints from snowflake.ml.model._client.sql import _base from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema -from snowflake.ml.monitoring.entities import model_monitor_interval, output_score_type -from snowflake.ml.monitoring.entities.model_monitor_interval import ( - ModelMonitorAggregationWindow, - ModelMonitorRefreshInterval, -) -from snowflake.snowpark import DataFrame, exceptions, session, types -from snowflake.snowpark._internal import type_utils +from snowflake.snowpark import session, types SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA" -_SNOWML_MONITORING_TABLE_NAME_PREFIX = "_SNOWML_OBS_MONITORING_" -_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX = "_SNOWML_OBS_ACCURACY_" + +MODEL_JSON_COL_NAME = "model" +MODEL_JSON_MODEL_NAME_FIELD = "model_name" +MODEL_JSON_VERSION_NAME_FIELD = "version_name" MONITOR_NAME_COL_NAME = "MONITOR_NAME" 
SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME" @@ -44,84 +31,10 @@ LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES" ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES" -_DASHBOARD_UDTFS_COMMON_LIST = ["record_count"] -_DASHBOARD_UDTFS_REGRESSION_LIST = ["rmse"] - - -def _initialize_monitoring_metadata_tables( - session: session.Session, - database_name: sql_identifier.SqlIdentifier, - schema_name: sql_identifier.SqlIdentifier, - statement_params: Optional[Dict[str, Any]] = None, -) -> None: - """Create tables necessary for Model Monitoring in provided schema. - - Args: - session: Active Snowpark session. - database_name: The database in which to setup resources for Model Monitoring. - schema_name: The schema in which to setup resources for Model Monitoring. - statement_params: Optional statement params for queries. - """ - table_manager.create_single_table( - session, - database_name, - schema_name, - SNOWML_MONITORING_METADATA_TABLE_NAME, - [ - (MONITOR_NAME_COL_NAME, "VARCHAR"), - (SOURCE_TABLE_NAME_COL_NAME, "VARCHAR"), - (FQ_MODEL_NAME_COL_NAME, "VARCHAR"), - (VERSION_NAME_COL_NAME, "VARCHAR"), - (FUNCTION_NAME_COL_NAME, "VARCHAR"), - (TASK_COL_NAME, "VARCHAR"), - (MONITORING_ENABLED_COL_NAME, "BOOLEAN"), - (TIMESTAMP_COL_NAME_COL_NAME, "VARCHAR"), - (PREDICTION_COL_NAMES_COL_NAME, "ARRAY"), - (LABEL_COL_NAMES_COL_NAME, "ARRAY"), - (ID_COL_NAMES_COL_NAME, "ARRAY"), - ], - statement_params=statement_params, - ) - - -def _create_baseline_table_name(model_name: str, version_name: str) -> str: - return f"_SNOWML_OBS_BASELINE_{model_name}_{version_name}" - - -def _infer_numeric_categoric_feature_column_names( - *, - source_table_schema: Mapping[str, types.DataType], - timestamp_column: sql_identifier.SqlIdentifier, - id_columns: List[sql_identifier.SqlIdentifier], - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], -) -> Tuple[List[sql_identifier.SqlIdentifier], List[sql_identifier.SqlIdentifier]]: - cols_to_remove = {timestamp_column, *id_columns, *prediction_columns, *label_columns} - cols_to_consider = [ - (col_name, source_table_schema[col_name]) for col_name in source_table_schema if col_name not in cols_to_remove - ] - numeric_cols = [ - sql_identifier.SqlIdentifier(column[0]) - for column in cols_to_consider - if isinstance(column[1], types._NumericType) - ] - categorical_cols = [ - sql_identifier.SqlIdentifier(column[0]) - for column in cols_to_consider - if isinstance(column[1], types.StringType) or isinstance(column[1], types.BooleanType) - ] - return (numeric_cols, categorical_cols) - - -class _ModelMonitorParams(TypedDict): - """Class to transfer model monitor parameters to the ModelMonitor class.""" - monitor_name: Required[str] - fully_qualified_model_name: Required[str] - version_name: Required[str] - function_name: Required[str] - prediction_columns: Required[List[sql_identifier.SqlIdentifier]] - label_columns: Required[List[sql_identifier.SqlIdentifier]] +def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str: + sql_list = ", ".join([f"'{column}'" for column in columns]) + return f"({sql_list})" class ModelMonitorSQLClient: @@ -143,38 +56,95 @@ def __init__( self._database_name = database_name self._schema_name = schema_name - @staticmethod - def initialize_monitoring_schema( - session: session.Session, - database_name: sql_identifier.SqlIdentifier, - schema_name: sql_identifier.SqlIdentifier, + def _infer_qualified_schema( + self, database_name: Optional[sql_identifier.SqlIdentifier], 
schema_name: Optional[sql_identifier.SqlIdentifier] + ) -> str: + return f"{database_name or self._database_name}.{schema_name or self._schema_name}" + + def create_model_monitor( + self, + *, + monitor_database: Optional[sql_identifier.SqlIdentifier], + monitor_schema: Optional[sql_identifier.SqlIdentifier], + monitor_name: sql_identifier.SqlIdentifier, + source_database: Optional[sql_identifier.SqlIdentifier], + source_schema: Optional[sql_identifier.SqlIdentifier], + source: sql_identifier.SqlIdentifier, + model_database: Optional[sql_identifier.SqlIdentifier], + model_schema: Optional[sql_identifier.SqlIdentifier], + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + function_name: str, + warehouse_name: sql_identifier.SqlIdentifier, + timestamp_column: sql_identifier.SqlIdentifier, + id_columns: List[sql_identifier.SqlIdentifier], + prediction_score_columns: List[sql_identifier.SqlIdentifier], + prediction_class_columns: List[sql_identifier.SqlIdentifier], + actual_score_columns: List[sql_identifier.SqlIdentifier], + actual_class_columns: List[sql_identifier.SqlIdentifier], + refresh_interval: str, + aggregation_window: str, + baseline_database: Optional[sql_identifier.SqlIdentifier] = None, + baseline_schema: Optional[sql_identifier.SqlIdentifier] = None, + baseline: Optional[sql_identifier.SqlIdentifier] = None, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - """Initialize tables for tracking metadata associated with model monitoring. - - Args: - session: The Snowpark Session to connect with Snowflake. - database_name: The database in which to setup resources for Model Monitoring. - schema_name: The schema in which to setup resources for Model Monitoring. - statement_params: Optional set of statement_params to include with query. - """ - # Create metadata management tables - _initialize_monitoring_metadata_tables(session, database_name, schema_name, statement_params) + baseline_sql = "" + if baseline: + baseline_sql = f"BASELINE='{self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}'" + query_result_checker.SqlResultValidator( + self._sql_client._session, + f""" + CREATE MODEL MONITOR {self._infer_qualified_schema(monitor_database, monitor_schema)}.{monitor_name} + WITH + MODEL='{self._infer_qualified_schema(model_database, model_schema)}.{model_name}' + VERSION='{version_name}' + FUNCTION='{function_name}' + WAREHOUSE='{warehouse_name}' + SOURCE='{self._infer_qualified_schema(source_database, source_schema)}.{source}' + ID_COLUMNS={_build_sql_list_from_columns(id_columns)} + PREDICTION_SCORE_COLUMNS={_build_sql_list_from_columns(prediction_score_columns)} + PREDICTION_CLASS_COLUMNS={_build_sql_list_from_columns(prediction_class_columns)} + ACTUAL_SCORE_COLUMNS={_build_sql_list_from_columns(actual_score_columns)} + ACTUAL_CLASS_COLUMNS={_build_sql_list_from_columns(actual_class_columns)} + TIMESTAMP_COLUMN='{timestamp_column}' + REFRESH_INTERVAL='{refresh_interval}' + AGGREGATION_WINDOW='{aggregation_window}' + {baseline_sql}""", + statement_params=statement_params, + ).has_column("status").has_dimensions(1, 1).validate() - def _validate_is_initialized(self) -> bool: - """Validates whether monitoring metadata has been initialized. 
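# --- Editor's note: illustrative only, not part of this patch -------------------
# With hypothetical identifiers, the create_model_monitor method above would
# submit roughly the following statement (the column lists are rendered by the
# _build_sql_list_from_columns helper added near the top of this file; the
# BASELINE clause is omitted when no baseline is supplied):
#
#     CREATE MODEL MONITOR MY_DB.MY_SCHEMA.MY_MONITOR
#         WITH
#         MODEL='MY_DB.MY_SCHEMA.MY_MODEL'
#         VERSION='V1'
#         FUNCTION='predict'
#         WAREHOUSE='MY_WH'
#         SOURCE='MY_DB.MY_SCHEMA.INFERENCE_LOG'
#         ID_COLUMNS=('ID')
#         PREDICTION_SCORE_COLUMNS=('PRED_SCORE')
#         PREDICTION_CLASS_COLUMNS=()
#         ACTUAL_SCORE_COLUMNS=('LABEL_SCORE')
#         ACTUAL_CLASS_COLUMNS=()
#         TIMESTAMP_COLUMN='TS'
#         REFRESH_INTERVAL='1 hour'
#         AGGREGATION_WINDOW='1 day'
#
# All database, schema, warehouse, and column names above are invented for
# illustration.
# --------------------------------------------------------------------------------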
+ def drop_model_monitor( + self, + *, + database_name: Optional[sql_identifier.SqlIdentifier] = None, + schema_name: Optional[sql_identifier.SqlIdentifier] = None, + monitor_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + search_database_name = database_name or self._database_name + search_schema_name = schema_name or self._schema_name + query_result_checker.SqlResultValidator( + self._sql_client._session, + f"DROP MODEL MONITOR {search_database_name}.{search_schema_name}.{monitor_name}", + statement_params=statement_params, + ).validate() - Returns: - boolean to indicate whether tables have been initialized. - """ - try: - return table_manager.validate_table_exist( + def show_model_monitors( + self, + *, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[snowpark.Row]: + fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()]) + return ( + query_result_checker.SqlResultValidator( self._sql_client._session, - SNOWML_MONITORING_METADATA_TABLE_NAME, - f"{self._database_name}.{self._schema_name}", + f"SHOW MODEL MONITORS IN {fully_qualified_schema_name}", + statement_params=statement_params, ) - except exceptions.SnowparkSQLException: - return False + .has_column("name", allow_empty=True) + .validate() + ) def _validate_unique_columns( self, @@ -191,53 +161,24 @@ def _validate_unique_columns( def validate_existence_by_name( self, + *, + database_name: Optional[sql_identifier.SqlIdentifier] = None, + schema_name: Optional[sql_identifier.SqlIdentifier] = None, monitor_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> bool: + search_database_name = database_name or self._database_name + search_schema_name = schema_name or self._schema_name res = ( query_result_checker.SqlResultValidator( self._sql_client._session, - f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME} - FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""", - statement_params=statement_params, - ) - .has_column(FQ_MODEL_NAME_COL_NAME, allow_empty=True) - .has_column(VERSION_NAME_COL_NAME, allow_empty=True) - .validate() - ) - return len(res) >= 1 - - def validate_existence( - self, - fully_qualified_model_name: str, - version_name: sql_identifier.SqlIdentifier, - statement_params: Optional[Dict[str, Any]] = None, - ) -> bool: - """Validate existence of a ModelMonitor on a Model Version. - - Args: - fully_qualified_model_name: Fully qualified name of model. - version_name: Name of model version. - statement_params: Optional set of statement_params to include with query. - - Returns: - Boolean indicating whether monitor exists on model version. 
- """ - res = ( - query_result_checker.SqlResultValidator( - self._sql_client._session, - f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME} - FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE {FQ_MODEL_NAME_COL_NAME} = '{fully_qualified_model_name}' - AND {VERSION_NAME_COL_NAME} = '{version_name}'""", + f"SHOW MODEL MONITORS LIKE '{monitor_name.resolved()}' IN {search_database_name}.{search_schema_name}", statement_params=statement_params, ) - .has_column(FQ_MODEL_NAME_COL_NAME, allow_empty=True) - .has_column(VERSION_NAME_COL_NAME, allow_empty=True) + .has_column("name", allow_empty=True) .validate() ) - return len(res) >= 1 + return len(res) == 1 def validate_monitor_warehouse( self, @@ -261,115 +202,47 @@ def validate_monitor_warehouse( ): raise ValueError(f"Warehouse '{warehouse_name}' not found.") - def add_dashboard_udtfs( - self, - monitor_name: sql_identifier.SqlIdentifier, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - task: type_hints.Task, - score_type: output_score_type.OutputScoreType, - output_columns: List[sql_identifier.SqlIdentifier], - ground_truth_columns: List[sql_identifier.SqlIdentifier], - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - udtf_name_query_map = self._create_dashboard_udtf_queries( - monitor_name, - model_name, - model_version_name, - task, - score_type, - output_columns, - ground_truth_columns, - ) - for udtf_query in udtf_name_query_map.values(): - query_result_checker.SqlResultValidator( - self._sql_client._session, - f"""{udtf_query}""", - statement_params=statement_params, - ).validate() - - def get_monitoring_table_fully_qualified_name( - self, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - ) -> str: - table_name = f"{_SNOWML_MONITORING_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}" - return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name) - - def get_accuracy_monitoring_table_fully_qualified_name( - self, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - ) -> str: - table_name = f"{_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}" - return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name) - - def _create_dashboard_udtf_queries( - self, - monitor_name: sql_identifier.SqlIdentifier, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - task: type_hints.Task, - score_type: output_score_type.OutputScoreType, - output_columns: List[sql_identifier.SqlIdentifier], - ground_truth_columns: List[sql_identifier.SqlIdentifier], - ) -> Mapping[str, str]: - query_files = files("snowflake.ml.monitoring._client") - # TODO(apgupta): Expand list of queries based on model objective and score type. 
- queries_list = [] - queries_list.extend(_DASHBOARD_UDTFS_COMMON_LIST) - if task == type_hints.Task.TABULAR_REGRESSION: - queries_list.extend(_DASHBOARD_UDTFS_REGRESSION_LIST) - var_map = { - "MODEL_MONITOR_NAME": monitor_name, - "MONITORING_TABLE": self.get_monitoring_table_fully_qualified_name(model_name, model_version_name), - "MONITORING_PRED_LABEL_JOINED_TABLE": self.get_accuracy_monitoring_table_fully_qualified_name( - model_name, model_version_name - ), - "OUTPUT_COLUMN_NAME": output_columns[0], - "GROUND_TRUTH_COLUMN_NAME": ground_truth_columns[0], - } - - udf_name_query_map = {} - for q in queries_list: - q_template = query_files.joinpath(f"queries/{q}.ssql").read_text() - q_actual = string.Template(q_template).substitute(var_map) - udf_name_query_map[q] = q_actual - return udf_name_query_map - - def _validate_columns_exist_in_source_table( + def _validate_columns_exist_in_source( self, *, - table_schema: Mapping[str, types.DataType], - source_table_name: sql_identifier.SqlIdentifier, + source_column_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], + prediction_score_columns: List[sql_identifier.SqlIdentifier], + prediction_class_columns: List[sql_identifier.SqlIdentifier], + actual_score_columns: List[sql_identifier.SqlIdentifier], + actual_class_columns: List[sql_identifier.SqlIdentifier], id_columns: List[sql_identifier.SqlIdentifier], ) -> None: """Ensures all columns exist in the source table. Args: - table_schema: Dictionary of column names and types in the source table. - source_table_name: Name of the table with model data to monitor. + source_column_schema: Dictionary of column names and types in the source. timestamp_column: Name of the timestamp column. - prediction_columns: List of prediction column names. - label_columns: List of label column names. + prediction_score_columns: List of prediction score column names. + prediction_class_columns: List of prediction class names. + actual_score_columns: List of actual score column names. + actual_class_columns: List of actual class column names. id_columns: List of id column names. Raises: - ValueError: If any of the columns do not exist in the source table. + ValueError: If any of the columns do not exist in the source. 
""" - if timestamp_column not in table_schema: - raise ValueError(f"Timestamp column {timestamp_column} does not exist in table {source_table_name}.") + if timestamp_column not in source_column_schema: + raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.") - if not all([column_name in table_schema for column_name in prediction_columns]): - raise ValueError(f"Prediction column(s): {prediction_columns} do not exist in table {source_table_name}.") - if not all([column_name in table_schema for column_name in label_columns]): - raise ValueError(f"Label column(s): {label_columns} do not exist in table {source_table_name}.") - if not all([column_name in table_schema for column_name in id_columns]): - raise ValueError(f"ID column(s): {id_columns} do not exist in table {source_table_name}.") + if not all([column_name in source_column_schema for column_name in prediction_score_columns]): + raise ValueError(f"Prediction Score column(s): {prediction_score_columns} do not exist in source.") + if not all([column_name in source_column_schema for column_name in prediction_class_columns]): + raise ValueError(f"Prediction Class column(s): {prediction_class_columns} do not exist in source.") + if not all([column_name in source_column_schema for column_name in actual_score_columns]): + raise ValueError(f"Actual Score column(s): {actual_score_columns} do not exist in source.") + + if not all([column_name in source_column_schema for column_name in actual_class_columns]): + raise ValueError(f"Actual Class column(s): {actual_class_columns} do not exist in source.") + + if not all([column_name in source_column_schema for column_name in id_columns]): + raise ValueError(f"ID column(s): {id_columns} do not exist in source.") def _validate_timestamp_column_type( self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier @@ -490,190 +363,37 @@ def _validate_source_table_features_shape( f"Model function expected: {inputs} but got {table_schema_without_special_columns}" ) - def get_model_monitor_by_name( - self, - monitor_name: sql_identifier.SqlIdentifier, - statement_params: Optional[Dict[str, Any]] = None, - ) -> _ModelMonitorParams: - """Fetch metadata for a Model Monitor by name. - - Args: - monitor_name: Name of ModelMonitor to fetch. - statement_params: Optional set of statement_params to include with query. - - Returns: - _ModelMonitorParams dict with Name of monitor, fully qualified model name, - model version name, model function name, prediction_col, label_col. - - Raises: - ValueError: If multiple ModelMonitors exist with the same name. - """ - try: - res = ( - query_result_checker.SqlResultValidator( - self._sql_client._session, - f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}, - {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME} - FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""", - statement_params=statement_params, - ) - .has_column(FQ_MODEL_NAME_COL_NAME) - .has_column(VERSION_NAME_COL_NAME) - .has_column(FUNCTION_NAME_COL_NAME) - .has_column(PREDICTION_COL_NAMES_COL_NAME) - .has_column(LABEL_COL_NAMES_COL_NAME) - .validate() - ) - except errors.DataError: - raise ValueError(f"Failed to find any monitor with name '{monitor_name}'") - - if len(res) > 1: - raise ValueError(f"Invalid state. 
Multiple Monitors exist with name '{monitor_name}'") - - return _ModelMonitorParams( - monitor_name=str(monitor_name), - fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME], - version_name=res[0][VERSION_NAME_COL_NAME], - function_name=res[0][FUNCTION_NAME_COL_NAME], - prediction_columns=[ - sql_identifier.SqlIdentifier(prediction_column) - for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME]) - ], - label_columns=[ - sql_identifier.SqlIdentifier(label_column) - for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME]) - ], - ) - - def get_model_monitor_by_model_version( + def validate_source( self, *, - model_db: sql_identifier.SqlIdentifier, - model_schema: sql_identifier.SqlIdentifier, - model_name: sql_identifier.SqlIdentifier, - version_name: sql_identifier.SqlIdentifier, - statement_params: Optional[Dict[str, Any]] = None, - ) -> _ModelMonitorParams: - """Fetch metadata for a Model Monitor by model version. - - Args: - model_db: Database of model. - model_schema: Schema of model. - model_name: Model name. - version_name: Model version name - statement_params: Optional set of statement_params to include with queries. - - Returns: - _ModelMonitorParams dict with Name of monitor, fully qualified model name, - model version name, model function name, prediction_col, label_col. - - Raises: - ValueError: If multiple ModelMonitors exist with the same name. - """ - res = ( - query_result_checker.SqlResultValidator( - self._sql_client._session, - f"""SELECT {MONITOR_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME}, - {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME} - FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE {FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{model_name}' - AND {VERSION_NAME_COL_NAME} = '{version_name}'""", - statement_params=statement_params, - ) - .has_column(MONITOR_NAME_COL_NAME) - .has_column(FQ_MODEL_NAME_COL_NAME) - .has_column(VERSION_NAME_COL_NAME) - .has_column(FUNCTION_NAME_COL_NAME) - .validate() - ) - if len(res) > 1: - raise ValueError( - f"Invalid state. Multiple Monitors exist for model: '{model_name}' and version: '{version_name}'" - ) - return _ModelMonitorParams( - monitor_name=res[0][MONITOR_NAME_COL_NAME], - fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME], - version_name=res[0][VERSION_NAME_COL_NAME], - function_name=res[0][FUNCTION_NAME_COL_NAME], - prediction_columns=[ - sql_identifier.SqlIdentifier(prediction_column) - for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME]) - ], - label_columns=[ - sql_identifier.SqlIdentifier(label_column) - for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME]) - ], - ) - - def get_score_type( - self, - task: type_hints.Task, - source_table_name: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - ) -> output_score_type.OutputScoreType: - """Infer score type given model task and prediction table columns. - - Args: - task: Model task - source_table_name: Source data table containing model outputs. - prediction_columns: columns in source data table corresponding to model outputs. - - Returns: - OutputScoreType for model. 
- """ - table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types( - self._sql_client._session, - self._database_name, - self._schema_name, - source_table_name, - ) - return output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_columns, task) - - def validate_source_table( - self, - source_table_name: sql_identifier.SqlIdentifier, + source_database: Optional[sql_identifier.SqlIdentifier], + source_schema: Optional[sql_identifier.SqlIdentifier], + source: sql_identifier.SqlIdentifier, timestamp_column: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], + prediction_score_columns: List[sql_identifier.SqlIdentifier], + prediction_class_columns: List[sql_identifier.SqlIdentifier], + actual_score_columns: List[sql_identifier.SqlIdentifier], + actual_class_columns: List[sql_identifier.SqlIdentifier], id_columns: List[sql_identifier.SqlIdentifier], - model_function: model_manifest_schema.ModelFunctionInfo, ) -> None: - # Validate source table exists - if not table_manager.validate_table_exist( - self._sql_client._session, - source_table_name, - f"{self._database_name}.{self._schema_name}", - ): - raise ValueError( - f"Table {source_table_name} does not exist in schema {self._database_name}.{self._schema_name}." - ) - table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types( + source_database = source_database or self._database_name + source_schema = source_schema or self._schema_name + # Get Schema of the source. Implicitly validates that the source exists. + source_column_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types( self._sql_client._session, - self._database_name, - self._schema_name, - source_table_name, + source_database, + source_schema, + source, ) - self._validate_columns_exist_in_source_table( - table_schema=table_schema, - source_table_name=source_table_name, + self._validate_columns_exist_in_source( + source_column_schema=source_column_schema, timestamp_column=timestamp_column, - prediction_columns=prediction_columns, - label_columns=label_columns, + prediction_score_columns=prediction_score_columns, + prediction_class_columns=prediction_class_columns, + actual_score_columns=actual_score_columns, + actual_class_columns=actual_class_columns, id_columns=id_columns, ) - self._validate_column_types( - table_schema=table_schema, - timestamp_column=timestamp_column, - id_columns=id_columns, - prediction_columns=prediction_columns, - label_columns=label_columns, - ) - self._validate_source_table_features_shape( - table_schema=table_schema, - special_columns={timestamp_column, *id_columns, *prediction_columns, *label_columns}, - model_function=model_function, - ) def delete_monitor_metadata( self, @@ -691,645 +411,38 @@ def delete_monitor_metadata( WHERE {MONITOR_NAME_COL_NAME} = '{name}'""", ).collect(statement_params=statement_params) - def delete_baseline_table( - self, - fully_qualified_model_name: str, - version_name: str, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - """Delete the baseline table corresponding to a particular model and version. - - Args: - fully_qualified_model_name: Fully qualified name of the model. - version_name: Name of the model version. - statement_params: Optional set of statement_params to include with query. 
- """ - table_name = _create_baseline_table_name(fully_qualified_model_name, version_name) - self._sql_client._session.sql( - f"""DROP TABLE IF EXISTS {self._database_name}.{self._schema_name}.{table_name}""" - ).collect(statement_params=statement_params) - - def delete_dynamic_tables( - self, - fully_qualified_model_name: str, - version_name: str, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - """Delete the dynamic tables corresponding to a particular model and version. - - Args: - fully_qualified_model_name: Fully qualified name of the model. - version_name: Name of the model version. - statement_params: Optional set of statement_params to include with query. - """ - _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name) - model_id = sql_identifier.SqlIdentifier(model_name) - version_id = sql_identifier.SqlIdentifier(version_name) - monitoring_table_name = self.get_monitoring_table_fully_qualified_name(model_id, version_id) - self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {monitoring_table_name}""").collect( - statement_params=statement_params - ) - accuracy_table_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_id, version_id) - self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {accuracy_table_name}""").collect( - statement_params=statement_params - ) - - def create_monitor_on_model_version( - self, - monitor_name: sql_identifier.SqlIdentifier, - source_table_name: sql_identifier.SqlIdentifier, - fully_qualified_model_name: str, - version_name: sql_identifier.SqlIdentifier, - function_name: str, - timestamp_column: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - id_columns: List[sql_identifier.SqlIdentifier], - task: type_hints.Task, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Creates a ModelMonitor on a Model Version from the Snowflake Model Registry. Creates public schema for metadata. - - Args: - monitor_name: Name of monitor object to create. - source_table_name: Name of source data table to monitor. - fully_qualified_model_name: fully qualified name of model to monitor '..'. - version_name: model version name to monitor. - function_name: function_name to monitor in model version. - timestamp_column: timestamp column name. - prediction_columns: list of prediction column names. - label_columns: list of label column names. - id_columns: list of id column names. - task: Task of the model, e.g. TABULAR_REGRESSION. - statement_params: Optional dict of statement_params to include with queries. - - Raises: - ValueError: If model version is already monitored. - """ - # Validate monitor does not already exist on model version. 
- if self.validate_existence(fully_qualified_model_name, version_name, statement_params): - raise ValueError(f"Model {fully_qualified_model_name} Version {version_name} is already monitored!") - - if self.validate_existence_by_name(monitor_name, statement_params): - raise ValueError(f"Model Monitor with name '{monitor_name}' already exists!") - - prediction_columns_for_select = formatting.format_value_for_select(prediction_columns) - label_columns_for_select = formatting.format_value_for_select(label_columns) - id_columns_for_select = formatting.format_value_for_select(id_columns) - query_result_checker.SqlResultValidator( - self._sql_client._session, - textwrap.dedent( - f"""INSERT INTO {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME} - ({MONITOR_NAME_COL_NAME}, {SOURCE_TABLE_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME}, - {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}, {TASK_COL_NAME}, - {MONITORING_ENABLED_COL_NAME}, {TIMESTAMP_COL_NAME_COL_NAME}, - {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}, - {ID_COL_NAMES_COL_NAME}) - SELECT '{monitor_name}', '{source_table_name}', '{fully_qualified_model_name}', - '{version_name}', '{function_name}', '{task.value}', TRUE, '{timestamp_column}', - {prediction_columns_for_select}, {label_columns_for_select}, {id_columns_for_select}""" - ), - statement_params=statement_params, - ).insertion_success(expected_num_rows=1).validate() - - def initialize_baseline_table( - self, - model_name: sql_identifier.SqlIdentifier, - version_name: sql_identifier.SqlIdentifier, - source_table_name: str, - columns_to_drop: Optional[List[sql_identifier.SqlIdentifier]] = None, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Initializes the baseline table for a Model Version. Creates schema for baseline data using the source table. - - Args: - model_name: name of model to monitor. - version_name: model version name to monitor. - source_table_name: name of the user's table containing their model data. - columns_to_drop: special columns in the source table to be excluded from baseline tables. - statement_params: Optional dict of statement_params to include with queries. - """ - table_schema = table_manager.get_table_schema_types( - self._sql_client._session, - database=self._database_name, - schema=self._schema_name, - table_name=source_table_name, - ) - - if columns_to_drop is None: - columns_to_drop = [] - - table_manager.create_single_table( - self._sql_client._session, - self._database_name, - self._schema_name, - _create_baseline_table_name(model_name, version_name), - [ - (k, type_utils.convert_sp_to_sf_type(v)) - for k, v in table_schema.items() - if sql_identifier.SqlIdentifier(k) not in columns_to_drop - ], - statement_params=statement_params, - ) - - def get_all_model_monitor_metadata( - self, - statement_params: Optional[Dict[str, Any]] = None, - ) -> List[snowpark.Row]: - """Get the metadata for all model monitors in the given schema. - - Args: - statement_params: Optional dict of statement_params to include with queries. - - Returns: - List of snowpark.Row containing metadata for each model monitor. 
- """ - return query_result_checker.SqlResultValidator( - self._sql_client._session, - f"""SELECT * - FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}""", - statement_params=statement_params, - ).validate() - - def materialize_baseline_dataframe( - self, - baseline_df: DataFrame, - fully_qualified_model_name: str, - model_version_name: sql_identifier.SqlIdentifier, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Materialize baseline dataframe to a permanent snowflake table. This method - truncates (overwrite without dropping) any existing data in the baseline table. - - Args: - baseline_df: dataframe containing baseline data that monitored data will be compared against. - fully_qualified_model_name: name of the model. - model_version_name: model version name to monitor. - statement_params: Optional dict of statement_params to include with queries. - - Raises: - ValueError: If no baseline table was initialized. - """ - - _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name) - baseline_table_name = _create_baseline_table_name(model_name, model_version_name) - - baseline_table_exists = db_utils.db_object_exists( - self._sql_client._session, - db_utils.SnowflakeDbObjectType.TABLE, - sql_identifier.SqlIdentifier(baseline_table_name), - database_name=self._database_name, - schema_name=self._schema_name, - statement_params=statement_params, - ) - if not baseline_table_exists: - raise ValueError( - f"Baseline table '{baseline_table_name}' does not exist for model: " - f"'{model_name}' and model_version: '{model_version_name}'" - ) - - fully_qualified_baseline_table_name = [self._database_name, self._schema_name, baseline_table_name] - - try: - # Truncate overwrites by clearing the rows in the table, instead of dropping the table. - # This lets us keep the schema to validate the baseline_df against. - baseline_df.write.mode("truncate").save_as_table( - fully_qualified_baseline_table_name, statement_params=statement_params - ) - except exceptions.SnowparkSQLException as e: - raise ValueError( - f"""Failed to save baseline dataframe. 
- Ensure that the baseline dataframe columns match those provided in your monitored table: {e}""" - ) - - def _alter_monitor_dynamic_tables( + def _alter_monitor( self, operation: str, - model_name: sql_identifier.SqlIdentifier, - version_name: sql_identifier.SqlIdentifier, + monitor_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: if operation not in {"SUSPEND", "RESUME"}: raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables") - fq_monitor_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, version_name) - query_result_checker.SqlResultValidator( - self._sql_client._session, - f"""ALTER DYNAMIC TABLE {fq_monitor_dt_name} {operation}""", - statement_params=statement_params, - ).has_column("status").has_dimensions(1, 1).validate() - - fq_accuracy_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, version_name) query_result_checker.SqlResultValidator( self._sql_client._session, - f"""ALTER DYNAMIC TABLE {fq_accuracy_dt_name} {operation}""", + f"""ALTER MODEL MONITOR {self._database_name}.{self._schema_name}.{monitor_name} {operation}""", statement_params=statement_params, ).has_column("status").has_dimensions(1, 1).validate() - def suspend_monitor_dynamic_tables( + def suspend_monitor( self, - model_name: sql_identifier.SqlIdentifier, - version_name: sql_identifier.SqlIdentifier, + monitor_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - self._alter_monitor_dynamic_tables( + self._alter_monitor( operation="SUSPEND", - model_name=model_name, - version_name=version_name, + monitor_name=monitor_name, statement_params=statement_params, ) - def resume_monitor_dynamic_tables( + def resume_monitor( self, - model_name: sql_identifier.SqlIdentifier, - version_name: sql_identifier.SqlIdentifier, + monitor_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - self._alter_monitor_dynamic_tables( + self._alter_monitor( operation="RESUME", - model_name=model_name, - version_name=version_name, + monitor_name=monitor_name, statement_params=statement_params, ) - - def create_dynamic_tables_for_monitor( - self, - *, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - task: type_hints.Task, - source_table_name: sql_identifier.SqlIdentifier, - refresh_interval: model_monitor_interval.ModelMonitorRefreshInterval, - aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow, - warehouse_name: sql_identifier.SqlIdentifier, - timestamp_column: sql_identifier.SqlIdentifier, - id_columns: List[sql_identifier.SqlIdentifier], - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - score_type: output_score_type.OutputScoreType, - ) -> None: - table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types( - self._sql_client._session, - self._database_name, - self._schema_name, - source_table_name, - ) - (numeric_features_names, categorical_feature_names) = _infer_numeric_categoric_feature_column_names( - source_table_schema=table_schema, - timestamp_column=timestamp_column, - id_columns=id_columns, - prediction_columns=prediction_columns, - label_columns=label_columns, - ) - features_dynamic_table_query = self._monitoring_dynamic_table_query( - model_name=model_name, - model_version_name=model_version_name, - source_table_name=source_table_name, - 
refresh_interval=refresh_interval, - aggregate_window=aggregation_window, - warehouse_name=warehouse_name, - timestamp_column=timestamp_column, - numeric_features=numeric_features_names, - categoric_features=categorical_feature_names, - prediction_columns=prediction_columns, - label_columns=label_columns, - ) - query_result_checker.SqlResultValidator(self._sql_client._session, features_dynamic_table_query).has_column( - "status" - ).has_dimensions(1, 1).validate() - - label_pred_join_table_query = self._monitoring_accuracy_table_query( - model_name=model_name, - model_version_name=model_version_name, - task=task, - source_table_name=source_table_name, - refresh_interval=refresh_interval, - aggregate_window=aggregation_window, - warehouse_name=warehouse_name, - timestamp_column=timestamp_column, - prediction_columns=prediction_columns, - label_columns=label_columns, - score_type=score_type, - ) - query_result_checker.SqlResultValidator(self._sql_client._session, label_pred_join_table_query).has_column( - "status" - ).has_dimensions(1, 1).validate() - - def _monitoring_dynamic_table_query( - self, - *, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - source_table_name: sql_identifier.SqlIdentifier, - refresh_interval: ModelMonitorRefreshInterval, - aggregate_window: ModelMonitorAggregationWindow, - warehouse_name: sql_identifier.SqlIdentifier, - timestamp_column: sql_identifier.SqlIdentifier, - numeric_features: List[sql_identifier.SqlIdentifier], - categoric_features: List[sql_identifier.SqlIdentifier], - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - ) -> str: - """ - Generates a dynamic table query for Observability - Monitoring. - - Args: - model_name: Model name to monitor. - model_version_name: Model version name to monitor. - source_table_name: Name of source data table to monitor. - refresh_interval: Refresh interval in minutes. - aggregate_window: Aggregate window minutes. - warehouse_name: Warehouse name to use for dynamic table. - timestamp_column: Timestamp column name. - numeric_features: List of numeric features to capture. - categoric_features: List of categoric features to capture. - prediction_columns: List of columns that contain model inference outputs. - label_columns: List of columns that contain ground truth values. - - Raises: - ValueError: If multiple output/ground truth columns are specified. MultiClass models are not yet supported. - - Returns: - Dynamic table query. - """ - # output and ground cols are list to keep interface extensible. 
- # for prpr only one label and one output col will be supported - if len(prediction_columns) != 1 or len(label_columns) != 1: - raise ValueError("Multiple Output columns are not supported in monitoring") - - monitoring_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, model_version_name) - - feature_cols_query_list = [] - for feature in numeric_features + prediction_columns + label_columns: - feature_cols_query_list.append( - """ - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE({col}), - 'count', count_if({col} is not null), - 'count_null', count_if({col} is null), - 'min', min({col}), - 'max', max({col}), - 'sum', sum({col}) - ) AS {col}""".format( - col=feature - ) - ) - - for col in categoric_features: - feature_cols_query_list.append( - f""" - {self._database_name}.{self._schema_name}.OBJECT_SUM(to_varchar({col})) AS {col}""" - ) - feature_cols_query = ",".join(feature_cols_query_list) - - return f""" - CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name} - TARGET_LAG = '{refresh_interval.minutes} minutes' - WAREHOUSE = {warehouse_name} - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS - SELECT - TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,{feature_cols_query} - FROM - {source_table_name} - GROUP BY - 1 - """ - - def _monitoring_accuracy_table_query( - self, - *, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - task: type_hints.Task, - source_table_name: sql_identifier.SqlIdentifier, - refresh_interval: ModelMonitorRefreshInterval, - aggregate_window: ModelMonitorAggregationWindow, - warehouse_name: sql_identifier.SqlIdentifier, - timestamp_column: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - score_type: output_score_type.OutputScoreType, - ) -> str: - # output and ground cols are list to keep interface extensible. 
- # for prpr only one label and one output col will be supported - if len(prediction_columns) != 1 or len(label_columns) != 1: - raise ValueError("Multiple Output columns are not supported in monitoring") - if task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION: - return self._monitoring_classification_accuracy_table_query( - model_name=model_name, - model_version_name=model_version_name, - source_table_name=source_table_name, - refresh_interval=refresh_interval, - aggregate_window=aggregate_window, - warehouse_name=warehouse_name, - timestamp_column=timestamp_column, - prediction_columns=prediction_columns, - label_columns=label_columns, - score_type=score_type, - ) - else: - return self._monitoring_regression_accuracy_table_query( - model_name=model_name, - model_version_name=model_version_name, - source_table_name=source_table_name, - refresh_interval=refresh_interval, - aggregate_window=aggregate_window, - warehouse_name=warehouse_name, - timestamp_column=timestamp_column, - prediction_columns=prediction_columns, - label_columns=label_columns, - ) - - def _monitoring_regression_accuracy_table_query( - self, - *, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - source_table_name: sql_identifier.SqlIdentifier, - refresh_interval: ModelMonitorRefreshInterval, - aggregate_window: ModelMonitorAggregationWindow, - warehouse_name: sql_identifier.SqlIdentifier, - timestamp_column: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - ) -> str: - """ - Generates a dynamic table query for Monitoring - regression model accuracy. - - Args: - model_name: Model name to monitor. - model_version_name: Model version name to monitor. - source_table_name: Name of source data table to monitor. - refresh_interval: Refresh interval in minutes. - aggregate_window: Aggregate window minutes. - warehouse_name: Warehouse name to use for dynamic table. - timestamp_column: Timestamp column name. - prediction_columns: List of output columns. - label_columns: List of ground truth columns. - - Returns: - Dynamic table query. - - Raises: - ValueError: If output columns are not same as ground truth columns. 
- - """ - - if len(prediction_columns) != len(label_columns): - raise ValueError(f"Mismatch in output & ground truth columns: {prediction_columns} != {label_columns}") - - monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name) - - output_cols_query_list = [] - - output_cols_query_list.append( - f""" - OBJECT_CONSTRUCT( - 'sum_difference_label_pred', sum({prediction_columns[0]} - {label_columns[0]}), - 'sum_log_difference_square_label_pred', - sum( - case - when {prediction_columns[0]} > -1 and {label_columns[0]} > -1 - then pow(ln({prediction_columns[0]} + 1) - ln({label_columns[0]} + 1),2) - else null - END - ), - 'sum_difference_squares_label_pred', - sum( - pow( - {prediction_columns[0]} - {label_columns[0]}, - 2 - ) - ), - 'sum_absolute_regression_labels', sum(abs({label_columns[0]})), - 'sum_absolute_percentage_error', - sum( - abs( - div0null( - ({prediction_columns[0]} - {label_columns[0]}), - {label_columns[0]} - ) - ) - ), - 'sum_absolute_difference_label_pred', - sum( - abs({prediction_columns[0]} - {label_columns[0]}) - ), - 'sum_prediction', sum({prediction_columns[0]}), - 'sum_label', sum({label_columns[0]}), - 'count', count(*) - ) AS AGGREGATE_METRICS, - APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch, - APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch""" - ) - output_cols_query = ", ".join(output_cols_query_list) - - return f""" - CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name} - TARGET_LAG = '{refresh_interval.minutes} minutes' - WAREHOUSE = {warehouse_name} - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS - SELECT - TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp, - 'class_regression' label_class,{output_cols_query} - FROM - {source_table_name} - GROUP BY - 1 - """ - - def _monitoring_classification_accuracy_table_query( - self, - *, - model_name: sql_identifier.SqlIdentifier, - model_version_name: sql_identifier.SqlIdentifier, - source_table_name: sql_identifier.SqlIdentifier, - refresh_interval: ModelMonitorRefreshInterval, - aggregate_window: ModelMonitorAggregationWindow, - warehouse_name: sql_identifier.SqlIdentifier, - timestamp_column: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - score_type: output_score_type.OutputScoreType, - ) -> str: - monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name) - - # Initialize the select clause components - select_clauses = [] - - select_clauses.append( - f""" - {prediction_columns[0]}, - {label_columns[0]}, - CASE - WHEN {label_columns[0]} = 1 THEN 'class_positive' - ELSE 'class_negative' - END AS label_class""" - ) - - # Join all the select clauses into a single string - select_clause = f"{timestamp_column} AS timestamp," + ",".join(select_clauses) - - # Create the final CTE query - cte_query = f""" - WITH filtered_data AS ( - SELECT - {select_clause} - FROM - {source_table_name} - )""" - - # Initialize the select clause components - select_clauses = [] - - score_type_agg_clause = "" - if score_type == output_score_type.OutputScoreType.PROBITS: - score_type_agg_clause = f""" - 'sum_log_loss', - CASE - WHEN label_class = 'class_positive' THEN sum(-ln({prediction_columns[0]})) - ELSE sum(-ln(1 - {prediction_columns[0]})) - END,""" - else: - score_type_agg_clause = f""" - 'tp', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 1), - 'tn', 
count_if({label_columns[0]} = 0 AND {prediction_columns[0]} = 0), - 'fp', count_if({label_columns[0]} = 0 AND {prediction_columns[0]} = 1), - 'fn', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 0),""" - - select_clauses.append( - f""" - label_class, - OBJECT_CONSTRUCT( - 'sum_prediction', sum({prediction_columns[0]}), - 'sum_label', sum({label_columns[0]}),{score_type_agg_clause} - 'count', count(*) - ) AS AGGREGATE_METRICS, - APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch, - APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch""" - ) - - # Join all the select clauses into a single string - select_clause = ",\n".join(select_clauses) - - return f""" - CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name} - TARGET_LAG = '{refresh_interval.minutes} minutes' - WAREHOUSE = {warehouse_name} - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS{cte_query} - select - time_slice(timestamp, {aggregate_window.minutes}, 'MINUTE') timestamp,{select_clause} - FROM - filtered_data - group by - 1, - 2 - """ diff --git a/snowflake/ml/monitoring/_client/model_monitor_sql_client_server_test.py b/snowflake/ml/monitoring/_client/model_monitor_sql_client_server_test.py new file mode 100644 index 00000000..daa2ef05 --- /dev/null +++ b/snowflake/ml/monitoring/_client/model_monitor_sql_client_server_test.py @@ -0,0 +1,215 @@ +from typing import Optional, cast + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.monitoring._client import model_monitor_sql_client +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Row, Session + + +class ModelMonitorSqlClientServerTest(absltest.TestCase): + """Test the ModelMonitorSqlClientServer class when calling server side Model Monitor SQL.""" + + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.test_schema = sql_identifier.SqlIdentifier("TEST_SCHEMA") + self.test_db = sql_identifier.SqlIdentifier("TEST_DB") + session = cast(Session, self.m_session) + self.monitor_sql_client = model_monitor_sql_client.ModelMonitorSQLClient( + session, database_name=self.test_db, schema_name=self.test_schema + ) + + self.model_name = sql_identifier.SqlIdentifier("MODEL") + self.model_version = sql_identifier.SqlIdentifier("VERSION") + self.model_function = sql_identifier.SqlIdentifier("FUNCTION") + self.warehouse_name = sql_identifier.SqlIdentifier("WAREHOUSE") + self.source = sql_identifier.SqlIdentifier("SOURCE") + self.id_columns = [sql_identifier.SqlIdentifier("ID")] + self.timestamp_column = sql_identifier.SqlIdentifier("TIMESTAMP_COLUMN") + self.refresh_interval = "1 day" + self.aggregation_window = "1 day" + + self.prediction_score_columns = [sql_identifier.SqlIdentifier("PRED_SCORE")] + self.prediction_class_columns = [sql_identifier.SqlIdentifier("PRED_CLASS")] + self.actual_score_columns = [sql_identifier.SqlIdentifier("ACTUAL_SCORE")] + self.actual_class_columns = [sql_identifier.SqlIdentifier("ACTUAL_CLASS")] + + def tearDown(self) -> None: + self.m_session.finalize() + + def _build_expected_create_model_monitor_sql( + self, + id_cols_sql: str, + baseline: Optional[str] = None, + db_override: Optional[str] = None, + schema_override: Optional[str] = None, + ) -> str: + fq_schema = ( + f"{db_override}.{schema_override}" + if db_override and schema_override + else f"{self.test_db}.{self.test_schema}" + ) + baseline_sql = f"BASELINE='{fq_schema}.{baseline}'" if baseline else "" 
+ return f""" + CREATE MODEL MONITOR {fq_schema}.M + WITH + MODEL='{fq_schema}.{self.model_name}' + VERSION='{self.model_version}' + FUNCTION='{self.model_function}' + WAREHOUSE='{self.warehouse_name}' + SOURCE='{fq_schema}.{self.source}' + ID_COLUMNS={id_cols_sql} + PREDICTION_SCORE_COLUMNS=('PRED_SCORE') + PREDICTION_CLASS_COLUMNS=('PRED_CLASS') + ACTUAL_SCORE_COLUMNS=('ACTUAL_SCORE') + ACTUAL_CLASS_COLUMNS=('ACTUAL_CLASS') + TIMESTAMP_COLUMN='{self.timestamp_column}' + REFRESH_INTERVAL='{self.refresh_interval}' + AGGREGATION_WINDOW='{self.aggregation_window}' + {baseline_sql} + """ + + def test_build_sql_list_from_columns(self) -> None: + columns = [sql_identifier.SqlIdentifier("col1")] + res = model_monitor_sql_client._build_sql_list_from_columns(columns) + self.assertEqual(res, "('COL1')") + + columns = [sql_identifier.SqlIdentifier("col1"), sql_identifier.SqlIdentifier("col2")] + res = model_monitor_sql_client._build_sql_list_from_columns(columns) + self.assertEqual(res, "('COL1', 'COL2')") + + columns = [] + res = model_monitor_sql_client._build_sql_list_from_columns(columns) + self.assertEqual(res, "()") + + def test_show_model_monitors(self) -> None: + self.m_session.add_mock_sql( + f"SHOW MODEL MONITORS IN {self.test_db}.{self.test_schema}", + result=mock_data_frame.MockDataFrame([Row(name="TEST")]), + ) + res = self.monitor_sql_client.show_model_monitors() + self.assertEqual(res[0]["name"], "TEST") + + def test_create_model_monitor(self) -> None: + self.m_session.add_mock_sql( + self._build_expected_create_model_monitor_sql(id_cols_sql="('ID')"), + result=mock_data_frame.MockDataFrame([Row(status="success")]), + ) + self.monitor_sql_client.create_model_monitor( + monitor_database=None, + monitor_schema=None, + monitor_name=sql_identifier.SqlIdentifier("m"), + source_database=None, + source_schema=None, + source=self.source, + model_database=None, + model_schema=None, + model_name=self.model_name, + version_name=self.model_version, + function_name=self.model_function, + warehouse_name=self.warehouse_name, + timestamp_column=self.timestamp_column, + id_columns=self.id_columns, + prediction_score_columns=self.prediction_score_columns, + prediction_class_columns=self.prediction_class_columns, + actual_score_columns=self.actual_score_columns, + actual_class_columns=self.actual_class_columns, + refresh_interval=self.refresh_interval, + aggregation_window=self.aggregation_window, + ) + + def test_create_model_monitor_multiple_id_cols(self) -> None: + self.m_session.add_mock_sql( + self._build_expected_create_model_monitor_sql(id_cols_sql="('ID1', 'ID2')"), + result=mock_data_frame.MockDataFrame([Row(status="success")]), + ) + self.monitor_sql_client.create_model_monitor( + monitor_database=None, + monitor_schema=None, + monitor_name=sql_identifier.SqlIdentifier("m"), + source_database=None, + source_schema=None, + source=self.source, + model_database=None, + model_schema=None, + model_name=self.model_name, + version_name=self.model_version, + function_name=self.model_function, + warehouse_name=self.warehouse_name, + timestamp_column=self.timestamp_column, + id_columns=[sql_identifier.SqlIdentifier("ID1"), sql_identifier.SqlIdentifier("ID2")], + prediction_score_columns=self.prediction_score_columns, + prediction_class_columns=self.prediction_class_columns, + actual_score_columns=self.actual_score_columns, + actual_class_columns=self.actual_class_columns, + refresh_interval=self.refresh_interval, + aggregation_window=self.aggregation_window, + ) + + def 
test_create_model_monitor_empty_id_cols(self) -> None: + self.m_session.add_mock_sql( + self._build_expected_create_model_monitor_sql(id_cols_sql="()"), + result=mock_data_frame.MockDataFrame([Row(status="success")]), + ) + self.monitor_sql_client.create_model_monitor( + monitor_database=None, + monitor_schema=None, + monitor_name=sql_identifier.SqlIdentifier("m"), + source_database=None, + source_schema=None, + source=self.source, + model_database=None, + model_schema=None, + model_name=self.model_name, + version_name=self.model_version, + function_name=self.model_function, + warehouse_name=self.warehouse_name, + timestamp_column=self.timestamp_column, + id_columns=[], + prediction_score_columns=self.prediction_score_columns, + prediction_class_columns=self.prediction_class_columns, + actual_score_columns=self.actual_score_columns, + actual_class_columns=self.actual_class_columns, + refresh_interval=self.refresh_interval, + aggregation_window=self.aggregation_window, + ) + + def test_create_model_monitor_objects_in_different_schemas(self) -> None: + override_db = sql_identifier.SqlIdentifier("OVERRIDE_DB") + override_schema = sql_identifier.SqlIdentifier("OVERRIDE_SCHEMA") + self.m_session.add_mock_sql( + self._build_expected_create_model_monitor_sql( + id_cols_sql="()", baseline="BASELINE", db_override=override_db, schema_override=override_schema + ), + result=mock_data_frame.MockDataFrame([Row(status="success")]), + ) + self.monitor_sql_client.create_model_monitor( + monitor_database=override_db, + monitor_schema=override_schema, + monitor_name=sql_identifier.SqlIdentifier("m"), + source_database=override_db, + source_schema=override_schema, + source=self.source, + model_database=override_db, + model_schema=override_schema, + model_name=self.model_name, + version_name=self.model_version, + function_name=self.model_function, + warehouse_name=self.warehouse_name, + timestamp_column=self.timestamp_column, + id_columns=[], + prediction_score_columns=self.prediction_score_columns, + prediction_class_columns=self.prediction_class_columns, + actual_score_columns=self.actual_score_columns, + actual_class_columns=self.actual_class_columns, + refresh_interval=self.refresh_interval, + aggregation_window=self.aggregation_window, + baseline_database=override_db, + baseline_schema=override_schema, + baseline=sql_identifier.SqlIdentifier("BASELINE"), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py b/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py index 66559422..7580e610 100644 --- a/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py +++ b/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py @@ -4,16 +4,9 @@ from absl.testing import absltest from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import model_signature, type_hints -from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.monitoring._client import model_monitor_sql_client -from snowflake.ml.monitoring.entities import output_score_type -from snowflake.ml.monitoring.entities.model_monitor_interval import ( - ModelMonitorAggregationWindow, - ModelMonitorRefreshInterval, -) from snowflake.ml.test_utils import mock_data_frame, mock_session -from snowflake.snowpark import DataFrame, Row, Session, types +from snowflake.snowpark import Row, Session, types class ModelMonitorSqlClientTest(absltest.TestCase): @@ -40,17 +33,6 @@ def setUp(self) -> None: 
session, database_name=self.test_db_name, schema_name=self.test_schema_name ) - self.mon_table_name = ( - f"{model_monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX}_" - + self.test_model_name - + f"_{self.test_model_version_name}" - ) - self.acc_table_name = ( - f"{model_monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX}_" - + self.test_model_name - + f"_{self.test_model_version_name}" - ) - def test_validate_source_table(self) -> None: mocked_table_out = mock.MagicMock(name="schema") self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) @@ -61,24 +43,16 @@ def test_validate_source_table(self) -> None: types.StructField(self.test_label_column_name, types.DoubleType()), types.StructField(self.test_id_column_name, types.StringType()), ] - - self.m_session.add_mock_sql( - query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", - result=mock_data_frame.MockDataFrame([Row(name=self.test_source_table_name)]), - ) - self.monitor_sql_client.validate_source_table( - source_table_name=self.test_source_table_name, + self.monitor_sql_client.validate_source( + source_database=None, + source_schema=None, + source=self.test_source_table_name, timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - model_function=model_manifest_schema.ModelFunctionInfo( - name="PREDICT", - target_method="predict", - target_method_function_type="FUNCTION", - signature=model_signature.ModelSignature(inputs=[], outputs=[]), - is_partitioned=False, - ), + prediction_score_columns=[], + prediction_class_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + actual_score_columns=[], + actual_class_columns=[sql_identifier.SqlIdentifier("LABEL")], ) self.m_session.table.assert_called_once_with( f"{self.test_db_name}.{self.test_schema_name}.{self.test_source_table_name}" @@ -97,79 +71,22 @@ def test_validate_source_table_shape(self) -> None: types.StructField("feature1", types.StringType()), ] - self.m_session.add_mock_sql( - query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", - result=mock_data_frame.MockDataFrame([Row(name=self.test_source_table_name)]), - ) - self.monitor_sql_client.validate_source_table( - source_table_name=self.test_source_table_name, + self.monitor_sql_client.validate_source( + source_database=None, + source_schema=None, + source=self.test_source_table_name, timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - model_function=model_manifest_schema.ModelFunctionInfo( - name="PREDICT", - target_method="predict", - target_method_function_type="FUNCTION", - signature=model_signature.ModelSignature( - inputs=[ - model_signature.FeatureSpec("input_feature_0", model_signature.DataType.STRING), - ], - outputs=[], - ), - is_partitioned=False, - ), + prediction_class_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + prediction_score_columns=[], + actual_score_columns=[sql_identifier.SqlIdentifier("LABEL")], + actual_class_columns=[], ) self.m_session.table.assert_called_once_with( f"{self.test_db_name}.{self.test_schema_name}.{self.test_source_table_name}" ) self.m_session.finalize() - def 
test_validate_source_table_shape_does_not_match_function_signature(self) -> None: - mocked_table_out = mock.MagicMock(name="schema") - self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) - mocked_table_out.schema = mock.MagicMock(name="schema") - mocked_table_out.schema.fields = [ - types.StructField(self.test_timestamp_column, types.TimestampType()), - types.StructField(self.test_prediction_column_name, types.DoubleType()), - types.StructField(self.test_label_column_name, types.DoubleType()), - types.StructField(self.test_id_column_name, types.StringType()), - types.StructField("feature1", types.StringType()), - ] - - self.m_session.add_mock_sql( - query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", - result=mock_data_frame.MockDataFrame([Row(name=self.test_source_table_name)]), - ) - - expected_msg = ( - r"Model function input types do not match the source table input columns types\. Model function expected: " - r"\[FeatureSpec\(dtype=DataType\.STRING, name='input_feature_0'\), FeatureSpec\(dtype=DataType\.STRING, " - r"name='unexpected_feature'\)\] but got \{'FEATURE1': StringType\(\)\}" - ) - with self.assertRaisesRegex(ValueError, expected_msg): - self.monitor_sql_client.validate_source_table( - source_table_name=self.test_source_table_name, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - model_function=model_manifest_schema.ModelFunctionInfo( - name="PREDICT", - target_method="predict", - target_method_function_type="FUNCTION", - signature=model_signature.ModelSignature( - inputs=[ - model_signature.FeatureSpec("input_feature_0", model_signature.DataType.STRING), - model_signature.FeatureSpec("unexpected_feature", model_signature.DataType.STRING), - ], - outputs=[], - ), - is_partitioned=False, - ), - ) - self.m_session.finalize() - def test_validate_monitor_warehouse(self) -> None: self.m_session.add_mock_sql( query=f"""SHOW WAREHOUSES LIKE '{self.test_wh_name}'""", @@ -178,34 +95,7 @@ def test_validate_monitor_warehouse(self) -> None: with self.assertRaisesRegex(ValueError, f"Warehouse '{self.test_wh_name}' not found"): self.monitor_sql_client.validate_monitor_warehouse(self.test_wh_name) - def test_validate_source_table_not_exists(self) -> None: - self.m_session.add_mock_sql( - query=f"""SHOW TABLES LIKE '{self.test_source_table_name}' IN SNOWML_OBSERVABILITY.DATA""", - result=mock_data_frame.MockDataFrame([]), - ) - expected_msg = ( - f"Table {self.test_source_table_name} does not exist in schema {self.test_db_name}.{self.test_schema_name}." 
- ) - with self.assertRaisesRegex(ValueError, expected_msg): - self.monitor_sql_client.validate_source_table( - source_table_name=self.test_source_table_name, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - model_function=model_manifest_schema.ModelFunctionInfo( - name="PREDICT", - target_method="predict", - target_method_function_type="FUNCTION", - signature=model_signature.ModelSignature(inputs=[], outputs=[]), - is_partitioned=False, - ), - ) - self.m_session.finalize() - def test_validate_columns_exist_in_source_table(self) -> None: - source_table_name = self.test_source_table_name - table_schema = { "feature1": types.StringType(), "feature2": types.StringType(), @@ -215,12 +105,13 @@ def test_validate_columns_exist_in_source_table(self) -> None: "LABEL": types.DoubleType(), "ID": types.StringType(), } - self.monitor_sql_client._validate_columns_exist_in_source_table( - table_schema=table_schema, - source_table_name=source_table_name, + self.monitor_sql_client._validate_columns_exist_in_source( + source_column_schema=table_schema, timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], + prediction_score_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + prediction_class_columns=[], + actual_score_columns=[sql_identifier.SqlIdentifier("LABEL")], + actual_class_columns=[], id_columns=[sql_identifier.SqlIdentifier("ID")], ) @@ -232,13 +123,14 @@ def test_validate_columns_exist_in_source_table(self) -> None: "LABEL": types.DoubleType(), "ID": types.StringType(), } - with self.assertRaisesRegex(ValueError, "Timestamp column TIMESTAMP does not exist in table MODEL_OUTPUTS"): - self.monitor_sql_client._validate_columns_exist_in_source_table( - table_schema=table_schema, - source_table_name=source_table_name, + with self.assertRaisesRegex(ValueError, "Timestamp column TIMESTAMP does not exist in source"): + self.monitor_sql_client._validate_columns_exist_in_source( + source_column_schema=table_schema, timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], + prediction_score_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + prediction_class_columns=[], + actual_class_columns=[sql_identifier.SqlIdentifier("LABEL")], + actual_score_columns=[], id_columns=[sql_identifier.SqlIdentifier("ID")], ) @@ -252,14 +144,15 @@ def test_validate_columns_exist_in_source_table(self) -> None: } with self.assertRaisesRegex( - ValueError, r"Prediction column\(s\): \['PREDICTION'\] do not exist in table MODEL_OUTPUTS." + ValueError, r"Prediction Class column\(s\): \['PREDICTION'\] do not exist in source." 
): - self.monitor_sql_client._validate_columns_exist_in_source_table( - table_schema=table_schema, - source_table_name=source_table_name, + self.monitor_sql_client._validate_columns_exist_in_source( + source_column_schema=table_schema, timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], + prediction_class_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + prediction_score_columns=[], + actual_class_columns=[], + actual_score_columns=[sql_identifier.SqlIdentifier("LABEL")], id_columns=[sql_identifier.SqlIdentifier("ID")], ) @@ -271,13 +164,14 @@ def test_validate_columns_exist_in_source_table(self) -> None: "PREDICTION": types.DoubleType(), "ID": types.StringType(), } - with self.assertRaisesRegex(ValueError, r"Label column\(s\): \['LABEL'\] do not exist in table MODEL_OUTPUTS."): - self.monitor_sql_client._validate_columns_exist_in_source_table( - table_schema=table_schema, - source_table_name=source_table_name, + with self.assertRaisesRegex(ValueError, r"Actual Class column\(s\): \['LABEL'\] do not exist in source."): + self.monitor_sql_client._validate_columns_exist_in_source( + source_column_schema=table_schema, timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], + prediction_score_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + prediction_class_columns=[], + actual_class_columns=[sql_identifier.SqlIdentifier("LABEL")], + actual_score_columns=[], id_columns=[sql_identifier.SqlIdentifier("ID")], ) @@ -289,13 +183,14 @@ def test_validate_columns_exist_in_source_table(self) -> None: "PREDICTION": types.DoubleType(), "LABEL": types.DoubleType(), } - with self.assertRaisesRegex(ValueError, r"ID column\(s\): \['ID'\] do not exist in table MODEL_OUTPUTS"): - self.monitor_sql_client._validate_columns_exist_in_source_table( - table_schema=table_schema, - source_table_name=source_table_name, + with self.assertRaisesRegex(ValueError, r"ID column\(s\): \['ID'\] do not exist in source."): + self.monitor_sql_client._validate_columns_exist_in_source( + source_column_schema=table_schema, timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], + prediction_score_columns=[sql_identifier.SqlIdentifier("PREDICTION")], + prediction_class_columns=[], + actual_class_columns=[sql_identifier.SqlIdentifier("LABEL")], + actual_score_columns=[], id_columns=[sql_identifier.SqlIdentifier("ID")], ) @@ -388,517 +283,35 @@ def test_validate_id_columns_types_all_string(self) -> None: ], ) - def test_monitoring_dynamic_table_query_single_numeric_single_categoric(self) -> None: - query = self.monitor_sql_client._monitoring_dynamic_table_query( - model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - source_table_name=self.test_source_table_name, - refresh_interval=ModelMonitorRefreshInterval("15 minutes"), - aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, - warehouse_name=self.test_wh_name, - timestamp_column=self.test_timestamp_column, - numeric_features=[sql_identifier.SqlIdentifier("NUM_0")], - categoric_features=[sql_identifier.SqlIdentifier("STR_COL_0")], - prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - 
) - - expected = f""" - CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.mon_table_name} - TARGET_LAG = '15 minutes' - WAREHOUSE = ML_OBS_WAREHOUSE - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS - SELECT - TIME_SLICE(TIMESTAMP, 60, 'MINUTE') timestamp, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_0), - 'count', count_if(NUM_0 is not null), - 'count_null', count_if(NUM_0 is null), - 'min', min(NUM_0), - 'max', max(NUM_0), - 'sum', sum(NUM_0) - ) AS NUM_0, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(OUTPUT), - 'count', count_if(OUTPUT is not null), - 'count_null', count_if(OUTPUT is null), - 'min', min(OUTPUT), - 'max', max(OUTPUT), - 'sum', sum(OUTPUT) - ) AS OUTPUT, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(LABEL), - 'count', count_if(LABEL is not null), - 'count_null', count_if(LABEL is null), - 'min', min(LABEL), - 'max', max(LABEL), - 'sum', sum(LABEL) - ) AS LABEL, - SNOWML_OBSERVABILITY.DATA.OBJECT_SUM(to_varchar(STR_COL_0)) AS STR_COL_0 - FROM - MODEL_OUTPUTS - GROUP BY - 1 - """ - self.assertEqual(query, expected) - - def test_monitoring_dynamic_table_query_multi_feature(self) -> None: - query = self.monitor_sql_client._monitoring_dynamic_table_query( - model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - source_table_name=self.test_source_table_name, - refresh_interval=ModelMonitorRefreshInterval("15 minutes"), - aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, - warehouse_name=self.test_wh_name, - timestamp_column=self.test_timestamp_column, - numeric_features=[ - sql_identifier.SqlIdentifier("NUM_0"), - sql_identifier.SqlIdentifier("NUM_1"), - sql_identifier.SqlIdentifier("NUM_2"), - ], - categoric_features=[sql_identifier.SqlIdentifier("STR_COL_0"), sql_identifier.SqlIdentifier("STR_COL_1")], - prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - ) - self.assertEqual( - query, - f""" - CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.mon_table_name} - TARGET_LAG = '15 minutes' - WAREHOUSE = ML_OBS_WAREHOUSE - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS - SELECT - TIME_SLICE(TIMESTAMP, 60, 'MINUTE') timestamp, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_0), - 'count', count_if(NUM_0 is not null), - 'count_null', count_if(NUM_0 is null), - 'min', min(NUM_0), - 'max', max(NUM_0), - 'sum', sum(NUM_0) - ) AS NUM_0, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_1), - 'count', count_if(NUM_1 is not null), - 'count_null', count_if(NUM_1 is null), - 'min', min(NUM_1), - 'max', max(NUM_1), - 'sum', sum(NUM_1) - ) AS NUM_1, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(NUM_2), - 'count', count_if(NUM_2 is not null), - 'count_null', count_if(NUM_2 is null), - 'min', min(NUM_2), - 'max', max(NUM_2), - 'sum', sum(NUM_2) - ) AS NUM_2, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(OUTPUT), - 'count', count_if(OUTPUT is not null), - 'count_null', count_if(OUTPUT is null), - 'min', min(OUTPUT), - 'max', max(OUTPUT), - 'sum', sum(OUTPUT) - ) AS OUTPUT, - OBJECT_CONSTRUCT( - 'sketch', APPROX_PERCENTILE_ACCUMULATE(LABEL), - 'count', count_if(LABEL is not null), - 'count_null', count_if(LABEL is null), - 'min', min(LABEL), - 'max', max(LABEL), - 'sum', sum(LABEL) - ) AS LABEL, - SNOWML_OBSERVABILITY.DATA.OBJECT_SUM(to_varchar(STR_COL_0)) AS STR_COL_0, - SNOWML_OBSERVABILITY.DATA.OBJECT_SUM(to_varchar(STR_COL_1)) AS 
STR_COL_1 - FROM - MODEL_OUTPUTS - GROUP BY - 1 - """, - ) - - def test_monitoring_accuracy_regression_dynamic_table_query_single_prediction(self) -> None: - query = self.monitor_sql_client._monitoring_regression_accuracy_table_query( - model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - source_table_name=self.test_source_table_name, - refresh_interval=ModelMonitorRefreshInterval("15 minutes"), - aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, - warehouse_name=self.test_wh_name, - timestamp_column=self.test_timestamp_column, - prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - ) - self.assertEqual( - query, - f""" - CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.acc_table_name} - TARGET_LAG = '15 minutes' - WAREHOUSE = ML_OBS_WAREHOUSE - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS - SELECT - TIME_SLICE(TIMESTAMP, 60, 'MINUTE') timestamp, - 'class_regression' label_class, - OBJECT_CONSTRUCT( - 'sum_difference_label_pred', sum(OUTPUT - LABEL), - 'sum_log_difference_square_label_pred', - sum( - case - when OUTPUT > -1 and LABEL > -1 - then pow(ln(OUTPUT + 1) - ln(LABEL + 1),2) - else null - END - ), - 'sum_difference_squares_label_pred', - sum( - pow( - OUTPUT - LABEL, - 2 - ) - ), - 'sum_absolute_regression_labels', sum(abs(LABEL)), - 'sum_absolute_percentage_error', - sum( - abs( - div0null( - (OUTPUT - LABEL), - LABEL - ) - ) - ), - 'sum_absolute_difference_label_pred', - sum( - abs(OUTPUT - LABEL) - ), - 'sum_prediction', sum(OUTPUT), - 'sum_label', sum(LABEL), - 'count', count(*) - ) AS AGGREGATE_METRICS, - APPROX_PERCENTILE_ACCUMULATE(OUTPUT) prediction_sketch, - APPROX_PERCENTILE_ACCUMULATE(LABEL) label_sketch - FROM - MODEL_OUTPUTS - GROUP BY - 1 - """, - ) - - def test_monitoring_accuracy_classification_probit_dynamic_table_query_single_prediction(self) -> None: - query = self.monitor_sql_client._monitoring_classification_accuracy_table_query( - model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - source_table_name=self.test_source_table_name, - refresh_interval=ModelMonitorRefreshInterval("15 minutes"), - aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, - warehouse_name=self.test_wh_name, - timestamp_column=self.test_timestamp_column, - prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - score_type=output_score_type.OutputScoreType.PROBITS, - ) - self.assertEqual( - query, - f""" - CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.acc_table_name} - TARGET_LAG = '15 minutes' - WAREHOUSE = ML_OBS_WAREHOUSE - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS - WITH filtered_data AS ( - SELECT - TIMESTAMP AS timestamp, - OUTPUT, - LABEL, - CASE - WHEN LABEL = 1 THEN 'class_positive' - ELSE 'class_negative' - END AS label_class - FROM - MODEL_OUTPUTS - ) - select - time_slice(timestamp, 60, 'MINUTE') timestamp, - label_class, - OBJECT_CONSTRUCT( - 'sum_prediction', sum(OUTPUT), - 'sum_label', sum(LABEL), - 'sum_log_loss', - CASE - WHEN label_class = 'class_positive' THEN sum(-ln(OUTPUT)) - ELSE sum(-ln(1 - OUTPUT)) - END, - 'count', count(*) - ) AS AGGREGATE_METRICS, - APPROX_PERCENTILE_ACCUMULATE(OUTPUT) prediction_sketch, - APPROX_PERCENTILE_ACCUMULATE(LABEL) label_sketch - FROM - filtered_data - group by - 1, - 2 - """, - ) - - def 
test_monitoring_accuracy_classification_class_dynamic_table_query_single_prediction(self) -> None: - query = self.monitor_sql_client._monitoring_classification_accuracy_table_query( - model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - source_table_name=self.test_source_table_name, - refresh_interval=ModelMonitorRefreshInterval("15 minutes"), - aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, - warehouse_name=self.test_wh_name, - timestamp_column=self.test_timestamp_column, - prediction_columns=[sql_identifier.SqlIdentifier("OUTPUT")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - score_type=output_score_type.OutputScoreType.CLASSIFICATION, - ) - self.assertEqual( - query, - f""" - CREATE DYNAMIC TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA.{self.acc_table_name} - TARGET_LAG = '15 minutes' - WAREHOUSE = ML_OBS_WAREHOUSE - REFRESH_MODE = AUTO - INITIALIZE = ON_CREATE - AS - WITH filtered_data AS ( - SELECT - TIMESTAMP AS timestamp, - OUTPUT, - LABEL, - CASE - WHEN LABEL = 1 THEN 'class_positive' - ELSE 'class_negative' - END AS label_class - FROM - MODEL_OUTPUTS - ) - select - time_slice(timestamp, 60, 'MINUTE') timestamp, - label_class, - OBJECT_CONSTRUCT( - 'sum_prediction', sum(OUTPUT), - 'sum_label', sum(LABEL), - 'tp', count_if(LABEL = 1 AND OUTPUT = 1), - 'tn', count_if(LABEL = 0 AND OUTPUT = 0), - 'fp', count_if(LABEL = 0 AND OUTPUT = 1), - 'fn', count_if(LABEL = 1 AND OUTPUT = 0), - 'count', count(*) - ) AS AGGREGATE_METRICS, - APPROX_PERCENTILE_ACCUMULATE(OUTPUT) prediction_sketch, - APPROX_PERCENTILE_ACCUMULATE(LABEL) label_sketch - FROM - filtered_data - group by - 1, - 2 - """, - ) - - def test_monitoring_accuracy_dynamic_table_query_multi_prediction(self) -> None: - with self.assertRaises(ValueError): - _ = self.monitor_sql_client._monitoring_accuracy_table_query( - model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, - source_table_name=self.test_source_table_name, - refresh_interval=ModelMonitorRefreshInterval("15 minutes"), - aggregate_window=ModelMonitorAggregationWindow.WINDOW_1_HOUR, - warehouse_name=self.test_wh_name, - timestamp_column=self.test_timestamp_column, - prediction_columns=[sql_identifier.SqlIdentifier("LABEL"), sql_identifier.SqlIdentifier("output_1")], - label_columns=[sql_identifier.SqlIdentifier("LABEL"), sql_identifier.SqlIdentifier("label_1")], - score_type=output_score_type.OutputScoreType.REGRESSION, - ) - def test_validate_existence_by_name(self) -> None: self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE MONITOR_NAME = '{self.test_monitor_name}' - """, + query=f"SHOW MODEL MONITORS LIKE '{self.test_monitor_name}' IN {self.test_db_name}.{self.test_schema_name}", result=mock_data_frame.MockDataFrame([]), ) - res = self.monitor_sql_client.validate_existence_by_name(self.test_monitor_name) - self.assertFalse(res) - - self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE MONITOR_NAME = '{self.test_monitor_name}' - """, - result=mock_data_frame.MockDataFrame( - [ - Row( - FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, - MODEL_VERSION_NAME=self.test_model_version_name, - ) - ] - ), + res = self.monitor_sql_client.validate_existence_by_name( + database_name=None, schema_name=None, 
monitor_name=self.test_monitor_name ) - res = self.monitor_sql_client.validate_existence_by_name(self.test_monitor_name) - self.assertTrue(res) - self.m_session.finalize() - - def test_validate_existence(self) -> None: - self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' - AND MODEL_VERSION_NAME = '{self.test_model_version_name}' - """, - result=mock_data_frame.MockDataFrame([]), - ) - res = self.monitor_sql_client.validate_existence(self.test_fq_model_name, self.test_model_version_name) self.assertFalse(res) self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' - AND MODEL_VERSION_NAME = '{self.test_model_version_name}' - """, - result=mock_data_frame.MockDataFrame( - [ - Row( - FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, - MODEL_VERSION_NAME=self.test_model_version_name, - ) - ] - ), + query=f"SHOW MODEL MONITORS LIKE '{self.test_monitor_name}' IN {self.test_db_name}.{self.test_schema_name}", + result=mock_data_frame.MockDataFrame([Row(name=self.test_monitor_name)]), ) - res = self.monitor_sql_client.validate_existence(self.test_fq_model_name, self.test_model_version_name) - self.assertTrue(res) - - self.m_session.finalize() - - def test_create_monitor_on_model_version(self) -> None: - self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' - AND MODEL_VERSION_NAME = '{self.test_model_version_name}' - """, - result=mock_data_frame.MockDataFrame([]), - ) - self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE MONITOR_NAME = '{self.test_monitor_name}' - """, - result=mock_data_frame.MockDataFrame([]), + res = self.monitor_sql_client.validate_existence_by_name( + database_name=None, schema_name=None, monitor_name=self.test_monitor_name ) + self.assertTrue(res) self.m_session.add_mock_sql( - query=f"""INSERT INTO SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - (MONITOR_NAME, SOURCE_TABLE_NAME, FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME, - FUNCTION_NAME, TASK, IS_ENABLED, - TIMESTAMP_COLUMN_NAME, PREDICTION_COLUMN_NAMES, LABEL_COLUMN_NAMES, ID_COLUMN_NAMES) - SELECT '{self.test_monitor_name}', '{self.test_source_table_name}', - '{self.test_fq_model_name}', '{self.test_model_version_name}', '{self.test_function_name}', - 'TABULAR_BINARY_CLASSIFICATION', TRUE, - '{self.test_timestamp_column}', ARRAY_CONSTRUCT('{self.test_prediction_column_name}'), - ARRAY_CONSTRUCT('{self.test_label_column_name}'), ARRAY_CONSTRUCT('{self.test_id_column_name}')""", - result=mock_data_frame.MockDataFrame([Row(**{"number of rows inserted": 1})]), + query=f"SHOW MODEL MONITORS LIKE '{self.test_monitor_name}' IN NEW_DB.NEW_SCHEMA", + result=mock_data_frame.MockDataFrame([Row(name=self.test_monitor_name)]), ) - - self.monitor_sql_client.create_monitor_on_model_version( + res = self.monitor_sql_client.validate_existence_by_name( + database_name=sql_identifier.SqlIdentifier("NEW_DB"), + schema_name=sql_identifier.SqlIdentifier("NEW_SCHEMA"), monitor_name=self.test_monitor_name, - 
source_table_name=self.test_source_table_name, - fully_qualified_model_name=self.test_fq_model_name, - version_name=self.test_model_version_name, - function_name=self.test_function_name, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, - statement_params=None, - ) - self.m_session.finalize() - - def test_create_monitor_on_model_version_fails_if_model_exists(self) -> None: - self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' - AND MODEL_VERSION_NAME = '{self.test_model_version_name}' - """, - result=mock_data_frame.MockDataFrame( - [ - Row( - FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, - MODEL_VERSION_NAME=self.test_model_version_name, - ) - ] - ), - ) - expected_msg = f"Model {self.test_fq_model_name} Version {self.test_model_version_name} is already monitored!" - with self.assertRaisesRegex(ValueError, expected_msg): - self.monitor_sql_client.create_monitor_on_model_version( - monitor_name=self.test_monitor_name, - source_table_name=self.test_source_table_name, - fully_qualified_model_name=self.test_fq_model_name, - version_name=self.test_model_version_name, - function_name=self.test_function_name, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, - statement_params=None, - ) - - self.m_session.finalize() - - def test_create_monitor_on_model_version_fails_if_monitor_name_exists(self) -> None: - self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE FULLY_QUALIFIED_MODEL_NAME = '{self.test_fq_model_name}' - AND MODEL_VERSION_NAME = '{self.test_model_version_name}' - """, - result=mock_data_frame.MockDataFrame([]), - ) - self.m_session.add_mock_sql( - query=f"""SELECT FULLY_QUALIFIED_MODEL_NAME, MODEL_VERSION_NAME - FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA - WHERE MONITOR_NAME = '{self.test_monitor_name}' - """, - result=mock_data_frame.MockDataFrame( - [ - Row( - FULLY_QUALIFIED_MODEL_NAME=self.test_fq_model_name, - MODEL_VERSION_NAME=self.test_model_version_name, - ) - ] - ), ) - - expected_msg = f"Model Monitor with name '{self.test_monitor_name}' already exists!" 
- with self.assertRaisesRegex(ValueError, expected_msg): - self.monitor_sql_client.create_monitor_on_model_version( - monitor_name=self.test_monitor_name, - source_table_name=self.test_source_table_name, - fully_qualified_model_name=self.test_fq_model_name, - version_name=self.test_model_version_name, - function_name=self.test_function_name, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION, - statement_params=None, - ) - + self.assertTrue(res) self.m_session.finalize() def test_validate_unique_columns(self) -> None: @@ -924,455 +337,30 @@ def test_validate_unique_columns_column_used_twice(self) -> None: label_columns=[sql_identifier.SqlIdentifier("LABEL")], ) - def test_infer_numeric_categoric_column_names(self) -> None: - from snowflake.snowpark import types - - timestamp_col = sql_identifier.SqlIdentifier("TS_COL") - id_col = sql_identifier.SqlIdentifier("ID_COL") - output_column = sql_identifier.SqlIdentifier("OUTPUT") - label_column = sql_identifier.SqlIdentifier("LABEL") - test_schema = { - timestamp_col: types.TimeType(), - id_col: types.FloatType(), - output_column: types.FloatType(), - label_column: types.FloatType(), - "STR_COL": types.StringType(16777216), - "LONG_COL": types.LongType(), - "FLOAT_COL": types.FloatType(), - "DOUBLE_COL": types.DoubleType(), - "BINARY_COL": types.BinaryType(), - "ARRAY_COL": types.ArrayType(), - "NULL_COL": types.NullType(), - } - - expected_numeric = [ - sql_identifier.SqlIdentifier("LONG_COL"), - sql_identifier.SqlIdentifier("FLOAT_COL"), - sql_identifier.SqlIdentifier("DOUBLE_COL"), - ] - expected_categoric = [ - sql_identifier.SqlIdentifier("STR_COL"), - ] - - numeric, categoric = model_monitor_sql_client._infer_numeric_categoric_feature_column_names( - source_table_schema=test_schema, - timestamp_column=timestamp_col, - id_columns=[id_col], - prediction_columns=[output_column], - label_columns=[label_column], - ) - self.assertListEqual(expected_numeric, numeric) - self.assertListEqual(expected_categoric, categoric) - - def test_initialize_baseline_table(self) -> None: - mocked_table_out = mock.MagicMock(name="schema") - self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) - mocked_table_out.schema = mock.MagicMock(name="schema") - mocked_table_out.schema.fields = [ - types.StructField(self.test_timestamp_column, types.TimestampType()), - types.StructField(self.test_prediction_column_name, types.DoubleType()), - types.StructField(self.test_label_column_name, types.DoubleType()), - types.StructField(self.test_id_column_name, types.StringType()), - ] - - self.m_session.add_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_BASELINE_""" - f"""{self.test_model_name}_{self.test_model_version_name}""" - f"""(PREDICTION DOUBLE, LABEL DOUBLE)""", - result=mock_data_frame.MockDataFrame( - [ - Row( - name="PREDICTION", - type="DOUBLE", - ), - Row( - name="LABEL", - type="DOUBLE", - ), - ] - ), - ) - - self.monitor_sql_client.initialize_baseline_table( - model_name=self.test_model_name, - version_name=self.test_model_version_name, - source_table_name=self.test_source_table_name, - columns_to_drop=[self.test_id_column_name, self.test_timestamp_column], - ) - - def test_materialize_baseline_dataframe(self) -> None: - mocked_dataframe = 
mock_data_frame.MockDataFrame( - [ - Row(TIMESTAMP="2022-01-01 00:00:00", PREDICTION=0.8, LABEL=1.0, ID="12345"), - Row(TIMESTAMP="2022-01-02 00:00:00", PREDICTION=0.6, LABEL=0.0, ID="67890"), - ] - ) + def test_suspend_monitor(self) -> None: self.m_session.add_mock_sql( - f"SHOW TABLES LIKE '{self.test_baseline_table_name_sql}' IN SNOWML_OBSERVABILITY.DATA", - mock_data_frame.MockDataFrame([Row(name=self.test_baseline_table_name_sql)]), - ) - - mocked_dataframe.write = mock.MagicMock(name="write") - save_as_table = mock.MagicMock(name="save_as_table") - mocked_dataframe.write.mode = mock.MagicMock(name="mode", return_value=save_as_table) - - self.monitor_sql_client.materialize_baseline_dataframe( - baseline_df=cast(DataFrame, mocked_dataframe), - fully_qualified_model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - ) - - mocked_dataframe.write.mode.assert_called_once_with("truncate") - save_as_table.save_as_table.assert_called_once_with( - [self.test_db_name, self.test_schema_name, self.test_baseline_table_name_sql], - statement_params=mock.ANY, - ) - - def test_materialize_baseline_dataframe_table_not_exists(self) -> None: - mocked_dataframe = mock_data_frame.MockDataFrame( - [ - Row(TIMESTAMP="2022-01-01 00:00:00", PREDICTION=0.8, LABEL=1.0, ID="12345"), - Row(TIMESTAMP="2022-01-02 00:00:00", PREDICTION=0.6, LABEL=0.0, ID="67890"), - ] - ) - self.m_session.add_mock_sql( - f"SHOW TABLES LIKE '{self.test_baseline_table_name_sql}' IN SNOWML_OBSERVABILITY.DATA", - mock_data_frame.MockDataFrame([]), - ) - - expected_msg = ( - f"Baseline table '{self.test_baseline_table_name_sql}' does not exist for model: " - "'TEST_MODEL' and model_version: 'TEST_MODEL_VERSION'" - ) - with self.assertRaisesRegex(ValueError, expected_msg): - self.monitor_sql_client.materialize_baseline_dataframe( - baseline_df=cast(DataFrame, mocked_dataframe), - fully_qualified_model_name=self.test_model_name, - model_version_name=self.test_model_version_name, - ) - - def test_initialize_baseline_table_different_data_kinds(self) -> None: - mocked_table_out = mock.MagicMock(name="schema") - self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) - mocked_table_out.schema = mock.MagicMock(name="schema") - mocked_table_out.schema.fields = [ - types.StructField(self.test_timestamp_column, types.TimestampType()), - types.StructField(self.test_prediction_column_name, types.DoubleType()), - types.StructField(self.test_label_column_name, types.DoubleType()), - types.StructField(self.test_id_column_name, types.StringType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE1"), types.StringType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE2"), types.DoubleType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE3"), types.FloatType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE4"), types.DecimalType(38, 9)), - types.StructField(sql_identifier.SqlIdentifier("FEATURE5"), types.IntegerType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE6"), types.LongType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE7"), types.ShortType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE8"), types.BinaryType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE9"), types.BooleanType()), - types.StructField(sql_identifier.SqlIdentifier("FEATURE10"), types.TimestampType()), - types.StructField( - sql_identifier.SqlIdentifier("FEATURE11"), types.TimestampType(types.TimestampTimeZone("ltz")) - ), - 
types.StructField( - sql_identifier.SqlIdentifier("FEATURE12"), types.TimestampType(types.TimestampTimeZone("ntz")) - ), - types.StructField( - sql_identifier.SqlIdentifier("FEATURE13"), types.TimestampType(types.TimestampTimeZone("tz")) - ), - ] - - self.m_session.add_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_BASELINE_""" - f"""{self.test_model_name}_{self.test_model_version_name}""" - f"""(PREDICTION DOUBLE, LABEL DOUBLE, - FEATURE1 STRING, FEATURE2 DOUBLE, FEATURE3 FLOAT, FEATURE4 NUMBER(38, 9), FEATURE5 INT, - FEATURE6 BIGINT, FEATURE7 SMALLINT, FEATURE8 BINARY, FEATURE9 BOOLEAN, FEATURE10 TIMESTAMP, - FEATURE11 TIMESTAMP_LTZ, FEATURE12 TIMESTAMP_NTZ, FEATURE13 TIMESTAMP_TZ)""", - result=mock_data_frame.MockDataFrame( - [ - Row( - name="PREDICTION", - type="DOUBLE", - ), - Row( - name="LABEL", - type="DOUBLE", - ), - Row( - name="FEATURE1", - type="STRING", - ), - Row( - name="FEATURE2", - type="DOUBLE", - ), - Row( - name="FEATURE3", - type="FLOAT", - ), - Row( - name="FEATURE4", - type="NUMBER", - ), - Row( - name="FEATURE5", - type="INTEGER", - ), - Row( - name="FEATURE6", - type="INTEGER", - ), - Row( - name="FEATURE7", - type="INTEGER", - ), - Row( - name="FEATURE8", - type="BINARY", - ), - Row( - name="FEATURE9", - type="BOOLEAN", - ), - Row( - name="FEATURE10", - type="TIMESTAMP", - ), - Row( - name="FEATURE11", - type="TIMESTAMP_LTZ", - ), - Row( - name="FEATURE12", - type="TIMESTAMP_NTZ", - ), - Row( - name="FEATURE13", - type="TIMESTAMP_TZ", - ), - ] - ), - ) - - self.monitor_sql_client.initialize_baseline_table( - model_name=self.test_model_name, - version_name=self.test_model_version_name, - source_table_name=self.test_source_table_name, - columns_to_drop=[self.test_timestamp_column, self.test_id_column_name], - ) - - def test_get_model_monitor_by_model_version(self) -> None: - model_db = sql_identifier.SqlIdentifier("MODEL_DB") - model_schema = sql_identifier.SqlIdentifier("MODEL_SCHEMA") - self.m_session.add_mock_sql( - f"""SELECT {model_monitor_sql_client.MONITOR_NAME_COL_NAME}, - {model_monitor_sql_client.FQ_MODEL_NAME_COL_NAME}, - {model_monitor_sql_client.VERSION_NAME_COL_NAME}, - {model_monitor_sql_client.FUNCTION_NAME_COL_NAME} - FROM - {self.test_db_name}.{self.test_schema_name}.{model_monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE {model_monitor_sql_client.FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{self.test_model_name}' - AND {model_monitor_sql_client.VERSION_NAME_COL_NAME} = '{self.test_model_version_name}'""", - result=mock_data_frame.MockDataFrame( - [ - Row( - MONITOR_NAME=self.test_monitor_name, - FULLY_QUALIFIED_MODEL_NAME=f"{model_db}.{model_schema}.{self.test_model_name}", - MODEL_VERSION_NAME=self.test_model_version_name, - FUNCTION_NAME=self.test_function_name, - PREDICTION_COLUMN_NAMES="[]", - LABEL_COLUMN_NAMES="[]", - ) - ] - ), - ) - # name, fq_model_name, version_name, function_name - monitor_params = self.monitor_sql_client.get_model_monitor_by_model_version( - model_db=model_db, - model_schema=model_schema, - model_name=self.test_model_name, - version_name=self.test_model_version_name, - ) - self.assertEqual(monitor_params["monitor_name"], str(self.test_monitor_name)) - self.assertEqual( - monitor_params["fully_qualified_model_name"], f"{model_db}.{model_schema}.{self.test_model_name}" - ) - self.assertEqual(monitor_params["version_name"], str(self.test_model_version_name)) - self.assertEqual(monitor_params["function_name"], str(self.test_function_name)) - - 
self.m_session.finalize() # TODO: Move to tearDown() for all tests. - - def test_get_model_monitor_by_model_version_fails_if_multiple(self) -> None: - model_db = sql_identifier.SqlIdentifier("MODEL_DB") - model_schema = sql_identifier.SqlIdentifier("MODEL_SCHEMA") - self.m_session.add_mock_sql( - f"""SELECT {model_monitor_sql_client.MONITOR_NAME_COL_NAME}, - {model_monitor_sql_client.FQ_MODEL_NAME_COL_NAME}, - {model_monitor_sql_client.VERSION_NAME_COL_NAME}, - {model_monitor_sql_client.FUNCTION_NAME_COL_NAME} - FROM - {self.test_db_name}.{self.test_schema_name}.{model_monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE {model_monitor_sql_client.FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{self.test_model_name}' - AND {model_monitor_sql_client.VERSION_NAME_COL_NAME} = '{self.test_model_version_name}'""", - result=mock_data_frame.MockDataFrame( - [ - Row( - MONITOR_NAME=self.test_monitor_name, - FULLY_QUALIFIED_MODEL_NAME=f"{model_db}.{model_schema}.{self.test_model_name}", - MODEL_VERSION_NAME=self.test_model_version_name, - FUNCTION_NAME=self.test_function_name, - ), - Row( - MONITOR_NAME=self.test_monitor_name, - FULLY_QUALIFIED_MODEL_NAME=f"{model_db}.{model_schema}.{self.test_model_name}", - MODEL_VERSION_NAME=self.test_model_version_name, - FUNCTION_NAME=self.test_function_name, - ), - ] - ), - ) - with self.assertRaisesRegex(ValueError, "Invalid state. Multiple Monitors exist for model:"): - self.monitor_sql_client.get_model_monitor_by_model_version( - model_db=model_db, - model_schema=model_schema, - model_name=self.test_model_name, - version_name=self.test_model_version_name, - ) - - self.m_session.finalize() # TODO: Move to tearDown() for all tests. - - def test_dashboard_udtf_queries(self) -> None: - queries_map = self.monitor_sql_client._create_dashboard_udtf_queries( - self.test_monitor_name, - self.test_model_version_name, - self.test_model_name, - type_hints.Task.TABULAR_REGRESSION, - output_score_type.OutputScoreType.REGRESSION, - output_columns=[self.test_prediction_column_name], - ground_truth_columns=[self.test_label_column_name], - ) - self.assertIn("rmse", queries_map) - EXPECTED_RMSE = """CREATE OR REPLACE FUNCTION TEST_RMSE() - RETURNS TABLE(event_timestamp TIMESTAMP_NTZ, value FLOAT) - AS -$$ -WITH metric_of_interest as ( - select - time_slice(timestamp, 1, 'hour') as event_timestamp, - AGGREGATE_METRICS:"sum_difference_squares_label_pred" as aggregate_field, - AGGREGATE_METRICS:"count" as "count" - from - SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_ACCURACY__TEST_MODEL_VERSION_TEST_MODEL -), metric_combine as ( - select - event_timestamp, - CAST(SUM(NVL(aggregate_field, 0)) as DOUBLE) as metric_sum, - SUM("count") as metric_count - from - metric_of_interest - where - cast(aggregate_field as varchar) not in ('inf','-inf','NaN') - group by - 1 -) select - event_timestamp, - SQRT(DIV0(metric_sum,metric_count)) as VALUE -from metric_combine -order by 1 desc -$$; -""" - self.assertEqual(queries_map["rmse"], EXPECTED_RMSE) - - self.assertIn("record_count", queries_map) - EXPECTED_RECORD_COUNT = """CREATE OR REPLACE FUNCTION TEST_PREDICTION_COUNT() - RETURNS TABLE(event_timestamp TIMESTAMP_NTZ, count FLOAT) - AS - $$ -SELECT - time_slice(timestamp, 1, 'hour') as "event_timestamp", - sum(get(PREDICTION,'count')) as count -from - SNOWML_OBSERVABILITY.DATA._SNOWML_OBS_MONITORING__TEST_MODEL_VERSION_TEST_MODEL -group by - 1 -order by - 1 desc - $$; -""" - self.assertEqual(queries_map["record_count"], EXPECTED_RECORD_COUNT) - - def 
test_get_all_model_monitor_metadata(self) -> None:
-        expected_result = [Row(MONITOR_NAME="monitor")]
-        self.m_session.add_mock_sql(
-            query="SELECT * FROM SNOWML_OBSERVABILITY.DATA._SYSTEM_MONITORING_METADATA",
-            result=mock_data_frame.MockDataFrame(expected_result),
-        )
-        res = self.monitor_sql_client.get_all_model_monitor_metadata()
-        self.assertEqual(res, expected_result)
-
-    def test_suspend_monitor_dynamic_tables(self) -> None:
-        self.m_session.add_mock_sql(
-            f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.mon_table_name} SUSPEND""",
+            f"""ALTER MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name} SUSPEND""",
             result=mock_data_frame.MockDataFrame([Row(status="Success")]),
         )
-        self.m_session.add_mock_sql(
-            f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.acc_table_name} SUSPEND""",
-            result=mock_data_frame.MockDataFrame([Row(status="Success")]),
-        )
-        self.monitor_sql_client.suspend_monitor_dynamic_tables(self.test_model_name, self.test_model_version_name)
+        self.monitor_sql_client.suspend_monitor(self.test_monitor_name)
         self.m_session.finalize()

-    def test_resume_monitor_dynamic_tables(self) -> None:
-        self.m_session.add_mock_sql(
-            f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.mon_table_name} RESUME""",
-            result=mock_data_frame.MockDataFrame([Row(status="Success")]),
-        )
+    def test_resume_monitor(self) -> None:
         self.m_session.add_mock_sql(
-            f"""ALTER DYNAMIC TABLE {self.test_db_name}.{self.test_schema_name}.{self.acc_table_name} RESUME""",
+            f"""ALTER MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name} RESUME""",
             result=mock_data_frame.MockDataFrame([Row(status="Success")]),
         )
-        self.monitor_sql_client.resume_monitor_dynamic_tables(self.test_model_name, self.test_model_version_name)
+        self.monitor_sql_client.resume_monitor(self.test_monitor_name)
         self.m_session.finalize()

-    def test_delete_monitor_metadata(self) -> None:
-        monitor = "TEST_MONITOR"
-        self.m_session.add_mock_sql(
-            query=f"DELETE FROM {self.test_db_name}.{self.test_schema_name}."
- f"{model_monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} WHERE " - f"{model_monitor_sql_client.MONITOR_NAME_COL_NAME} = '{monitor}'", - result=mock_data_frame.MockDataFrame([]), - ) - self.monitor_sql_client.delete_monitor_metadata(monitor) - - def test_delete_baseline_table(self) -> None: - model = "TEST_MODEL" - version = "TEST_VERSION" - table = model_monitor_sql_client._create_baseline_table_name(model, version) - self.m_session.add_mock_sql( - query=f"DROP TABLE IF EXISTS {self.test_db_name}.{self.test_schema_name}.{table}", - result=mock_data_frame.MockDataFrame([]), - ) - self.monitor_sql_client.delete_baseline_table(model, version) - - def test_delete_dynamic_tables(self) -> None: - model = "TEST_MODEL" - model_id = sql_identifier.SqlIdentifier(model) - fully_qualified_model = f"{self.test_db_name}.{self.test_schema_name}.{model}" - version = "TEST_VERSION" - version_id = sql_identifier.SqlIdentifier(version) - monitoring_table = self.monitor_sql_client.get_monitoring_table_fully_qualified_name(model_id, version_id) - accuracy_table = self.monitor_sql_client.get_accuracy_monitoring_table_fully_qualified_name( - model_id, version_id - ) + # TODO: Move to new test class + def test_drop_model_monitor(self) -> None: self.m_session.add_mock_sql( - query=f"DROP DYNAMIC TABLE IF EXISTS {monitoring_table}", - result=mock_data_frame.MockDataFrame([]), - ) - self.m_session.add_mock_sql( - query=f"DROP DYNAMIC TABLE IF EXISTS {accuracy_table}", - result=mock_data_frame.MockDataFrame([]), + f"""DROP MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name}""", + result=mock_data_frame.MockDataFrame([Row(status="Success")]), ) - self.monitor_sql_client.delete_dynamic_tables(fully_qualified_model, version) + self.monitor_sql_client.drop_model_monitor(monitor_name=self.test_monitor_name) + self.m_session.finalize() if __name__ == "__main__": diff --git a/snowflake/ml/monitoring/_manager/model_monitor_manager.py b/snowflake/ml/monitoring/_manager/model_monitor_manager.py index 1b8a47da..1a8d8e4a 100644 --- a/snowflake/ml/monitoring/_manager/model_monitor_manager.py +++ b/snowflake/ml/monitoring/_manager/model_monitor_manager.py @@ -1,59 +1,20 @@ +import json from typing import Any, Dict, List, Optional from snowflake import snowpark -from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import db_utils, sql_identifier +from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import type_hints from snowflake.ml.model._client.model import model_version_impl -from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.monitoring import model_monitor from snowflake.ml.monitoring._client import model_monitor_sql_client -from snowflake.ml.monitoring.entities import ( - model_monitor_config, - model_monitor_interval, -) +from snowflake.ml.monitoring.entities import model_monitor_config from snowflake.snowpark import session -def _validate_name_constraints(model_version: model_version_impl.ModelVersion) -> None: - system_table_prefixes = [ - model_monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX, - model_monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX, - ] - - max_allowed_model_name_and_version_length = ( - db_utils.MAX_IDENTIFIER_LENGTH - max(len(prefix) for prefix in system_table_prefixes) - 1 - ) # -1 includes '_' between model_name + model_version - if len(model_version.model_name) + len(model_version.version_name) > 
max_allowed_model_name_and_version_length: - error_msg = f"Model name and version name exceeds maximum length of {max_allowed_model_name_and_version_length}" - raise ValueError(error_msg) - - class ModelMonitorManager: - """Class to manage internal operations for Model Monitor workflows.""" # TODO: Move to Registry. - - @staticmethod - def setup(session: session.Session, database_name: str, schema_name: str) -> None: - """Static method to set up schema for Model Monitoring resources. - - Args: - session: The Snowpark Session to connect with Snowflake. - database_name: The name of the database. If None, the current database of the session - will be used. Defaults to None. - schema_name: The name of the schema. If None, the current schema of the session - will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None. - """ - statement_params = telemetry.get_statement_params( - project=telemetry.TelemetryProject.MLOPS.value, - subproject=telemetry.TelemetrySubProject.MONITORING.value, - ) - database_name_id = sql_identifier.SqlIdentifier(database_name) - schema_name_id = sql_identifier.SqlIdentifier(schema_name) - model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema( - session, database_name_id, schema_name_id, statement_params=statement_params - ) + """Class to manage internal operations for Model Monitor workflows.""" - def _fetch_task_from_model_version( + def _validate_task_from_model_version( self, model_version: model_version_impl.ModelVersion, ) -> type_hints.Task: @@ -68,7 +29,6 @@ def __init__( database_name: sql_identifier.SqlIdentifier, schema_name: sql_identifier.SqlIdentifier, *, - create_if_not_exists: bool = False, statement_params: Optional[Dict[str, Any]] = None, ) -> None: """ @@ -79,233 +39,156 @@ def __init__( session: The Snowpark Session to connect with Snowflake. database_name: The name of the database. schema_name: The name of the schema. - create_if_not_exists: Flag whether to initialize resources in the schema needed for Model Monitoring. statement_params: Optional set of statement params. - - Raises: - ValueError: When there is no specified or active database in the session. """ self._database_name = database_name self._schema_name = schema_name self.statement_params = statement_params + self._model_monitor_client = model_monitor_sql_client.ModelMonitorSQLClient( session, database_name=self._database_name, schema_name=self._schema_name, ) - if create_if_not_exists: - model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema( - session, self._database_name, self._schema_name, self.statement_params - ) - elif not self._model_monitor_client._validate_is_initialized(): - raise ValueError( - "Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup" - ) - def _get_and_validate_model_function_from_model_version( + def _validate_model_function_from_model_version( self, function: str, model_version: model_version_impl.ModelVersion - ) -> model_manifest_schema.ModelFunctionInfo: + ) -> None: functions = model_version.show_functions() for f in functions: if f["target_method"] == function: - return f + return existing_target_methods = {f["target_method"] for f in functions} raise ValueError( f"Function with name {function} does not exist in the given model version. " f"Found: {existing_target_methods}." 
) - def _validate_monitor_config_or_raise( - self, - table_config: model_monitor_config.ModelMonitorTableConfig, - model_monitor_config: model_monitor_config.ModelMonitorConfig, - ) -> None: - """Validate provided config for model monitor. - - Args: - table_config: Config for model monitor tables. - model_monitor_config: Config for ModelMonitor. - - Raises: - ValueError: If warehouse provided does not exist. - """ - - # Validate naming will not exceed 255 chars - _validate_name_constraints(model_monitor_config.model_version) - - if len(table_config.prediction_columns) != len(table_config.label_columns): - raise ValueError("Prediction and Label column names must be of the same length.") - # output and ground cols are list to keep interface extensible. - # for prpr only one label and one output col will be supported - if len(table_config.prediction_columns) != 1 or len(table_config.label_columns) != 1: - raise ValueError("Multiple Output columns are not supported in monitoring") - - # Validate warehouse exists. - warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name) - self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params) - - # Validate refresh interval. - try: - num_units, time_units = model_monitor_config.refresh_interval.strip().split(" ") - int(num_units) # try to cast - if time_units.lower() not in {"seconds", "minutes", "hours", "days"}: - raise ValueError( - """Invalid time unit in refresh interval. Provide ' '. -See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info.""" - ) - except Exception as e: # TODO: Link to DT page. - raise ValueError( - f"""Failed to parse refresh interval with exception {e}. - Provide ' '. -See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info.""" - ) + def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]: + return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else [] def add_monitor( self, name: str, - table_config: model_monitor_config.ModelMonitorTableConfig, + source_config: model_monitor_config.ModelMonitorSourceConfig, model_monitor_config: model_monitor_config.ModelMonitorConfig, - *, - add_dashboard_udtfs: bool = False, ) -> model_monitor.ModelMonitor: """Add a new Model Monitor. Args: name: Name of Model Monitor to create. - table_config: Configuration options for the source table used in ModelMonitor. + source_config: Configuration options for the source table used in ModelMonitor. model_monitor_config: Configuration options of ModelMonitor. - add_dashboard_udtfs: Add UDTFs useful for creating a dashboard. Returns: The newly added ModelMonitor object. """ - # Validates configuration or raise. 
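# Illustrative usage sketch (not part of this patch; names and values are hypothetical
# placeholders) of how a caller would drive the reworked add_monitor() with the new
# ModelMonitorSourceConfig. Only the config fields and the
# add_monitor(name, source_config, model_monitor_config) signature are taken from this change;
# the source table, warehouse, manager instance, and `mv` ModelVersion are assumptions.
#
# source_config = model_monitor_config.ModelMonitorSourceConfig(
#     source="MY_DB.MY_SCHEMA.INFERENCE_LOG",  # hypothetical source table or view
#     timestamp_column="TS",
#     id_columns=["ID"],
#     prediction_score_columns=["PREDICTION"],
#     actual_score_columns=["LABEL"],
# )
# monitor_config = model_monitor_config.ModelMonitorConfig(
#     model_version=mv,  # a logged ModelVersion whose task is set
#     model_function_name="predict",
#     background_compute_warehouse_name="ML_OBS_WAREHOUSE",
# )
# # Monitor names may be bare ("MY_MONITOR") or fully qualified ("DB.SCHEMA.MY_MONITOR").
# manager.add_monitor("MY_MONITOR", source_config, monitor_config)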
- self._validate_monitor_config_or_raise(table_config, model_monitor_config) - model_function = self._get_and_validate_model_function_from_model_version( + warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name) + self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params) + self._validate_model_function_from_model_version( model_monitor_config.model_function_name, model_monitor_config.model_version ) - monitor_refresh_interval = model_monitor_interval.ModelMonitorRefreshInterval( - model_monitor_config.refresh_interval + self._validate_task_from_model_version(model_monitor_config.model_version) + monitor_database_name_id, monitor_schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name( + name + ) + source_database_name_id, source_schema_name_id, source_name_id = sql_identifier.parse_fully_qualified_name( + source_config.source + ) + baseline_database_name_id, baseline_schema_name_id, baseline_name_id = ( + sql_identifier.parse_fully_qualified_name(source_config.baseline) + if source_config.baseline + else (None, None, None) ) - name_id = sql_identifier.SqlIdentifier(name) - source_table_name_id = sql_identifier.SqlIdentifier(table_config.source_table) - prediction_columns = [ - sql_identifier.SqlIdentifier(column_name) for column_name in table_config.prediction_columns - ] - label_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.label_columns] - id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.id_columns] - ts_column = sql_identifier.SqlIdentifier(table_config.timestamp_column) + model_database_name_id, model_schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name( + model_monitor_config.model_version.fully_qualified_model_name + ) + + prediction_score_columns = self._build_column_list_from_input(source_config.prediction_score_columns) + prediction_class_columns = self._build_column_list_from_input(source_config.prediction_class_columns) + actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns) + actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns) + + id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns] + ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column) # Validate source table - self._model_monitor_client.validate_source_table( - source_table_name=source_table_name_id, + self._model_monitor_client.validate_source( + source_database=source_database_name_id, + source_schema=source_schema_name_id, + source=source_name_id, timestamp_column=ts_column, - prediction_columns=prediction_columns, - label_columns=label_columns, + prediction_score_columns=prediction_score_columns, + prediction_class_columns=prediction_class_columns, + actual_score_columns=actual_score_columns, + actual_class_columns=actual_class_columns, id_columns=id_columns, - model_function=model_function, ) - task = self._fetch_task_from_model_version(model_version=model_monitor_config.model_version) - score_type = self._model_monitor_client.get_score_type(task, source_table_name_id, prediction_columns) - - # Insert monitoring metadata for new model version. 
- self._model_monitor_client.create_monitor_on_model_version( - monitor_name=name_id, - source_table_name=source_table_name_id, - fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name, + self._model_monitor_client.create_model_monitor( + monitor_database=monitor_database_name_id, + monitor_schema=monitor_schema_name_id, + monitor_name=monitor_name_id, + source_database=source_database_name_id, + source_schema=source_schema_name_id, + source=source_name_id, + model_database=model_database_name_id, + model_schema=model_schema_name_id, + model_name=model_name_id, version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name), function_name=model_monitor_config.model_function_name, + warehouse_name=warehouse_name_id, timestamp_column=ts_column, - prediction_columns=prediction_columns, - label_columns=label_columns, id_columns=id_columns, - task=task, - statement_params=self.statement_params, - ) - - # Create Dynamic tables for model monitor. - self._model_monitor_client.create_dynamic_tables_for_monitor( - model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name), - model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name), - task=task, - source_table_name=source_table_name_id, - refresh_interval=monitor_refresh_interval, + prediction_score_columns=prediction_score_columns, + prediction_class_columns=prediction_class_columns, + actual_score_columns=actual_score_columns, + actual_class_columns=actual_class_columns, + refresh_interval=model_monitor_config.refresh_interval, aggregation_window=model_monitor_config.aggregation_window, - warehouse_name=sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name), - timestamp_column=sql_identifier.SqlIdentifier(table_config.timestamp_column), - id_columns=id_columns, - prediction_columns=prediction_columns, - label_columns=label_columns, - score_type=score_type, - ) - - # Initialize baseline table. - self._model_monitor_client.initialize_baseline_table( - model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name), - version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name), - source_table_name=table_config.source_table, - columns_to_drop=[ts_column, *id_columns], + baseline_database=baseline_database_name_id, + baseline_schema=baseline_schema_name_id, + baseline=baseline_name_id, statement_params=self.statement_params, ) - - # Add udtfs helpful for dashboard queries. - # TODO(apgupta) Make this true by default. 
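# Illustrative sketch (not part of this patch) of the name handling used above; the behavior is
# inferred from the manager tests later in this diff. A bare monitor name parses to
# (None, None, name), so create_model_monitor() receives monitor_database/monitor_schema of None
# and presumably falls back to the manager's database and schema, while a fully qualified name
# overrides them.
#
# db, schema, name = sql_identifier.parse_fully_qualified_name("TEST")
# # -> (None, None, SqlIdentifier("TEST"))
# db, schema, name = sql_identifier.parse_fully_qualified_name("TEST_DB.TEST_SCHEMA.TEST")
# # -> (SqlIdentifier("TEST_DB"), SqlIdentifier("TEST_SCHEMA"), SqlIdentifier("TEST"))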
- if add_dashboard_udtfs: - self._model_monitor_client.add_dashboard_udtfs( - name_id, - model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name), - model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name), - task=task, - score_type=score_type, - output_columns=prediction_columns, - ground_truth_columns=label_columns, - ) - return model_monitor.ModelMonitor._ref( model_monitor_client=self._model_monitor_client, - name=name_id, - fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name, - version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name), - function_name=sql_identifier.SqlIdentifier(model_monitor_config.model_function_name), - prediction_columns=prediction_columns, - label_columns=label_columns, + name=monitor_name_id, ) def get_monitor_by_model_version( self, model_version: model_version_impl.ModelVersion ) -> model_monitor.ModelMonitor: - fq_model_name = model_version.fully_qualified_model_name - version_name = sql_identifier.SqlIdentifier(model_version.version_name) - if self._model_monitor_client.validate_existence(fq_model_name, version_name, self.statement_params): - model_db, model_schema, model_name = sql_identifier.parse_fully_qualified_name(fq_model_name) - if model_db is None or model_schema is None: - raise ValueError("Failed to parse model name") - - model_monitor_params: model_monitor_sql_client._ModelMonitorParams = ( - self._model_monitor_client.get_model_monitor_by_model_version( - model_db=model_db, - model_schema=model_schema, - model_name=model_name, - version_name=version_name, - statement_params=self.statement_params, - ) - ) - return model_monitor.ModelMonitor._ref( - model_monitor_client=self._model_monitor_client, - name=sql_identifier.SqlIdentifier(model_monitor_params["monitor_name"]), - fully_qualified_model_name=fq_model_name, - version_name=version_name, - function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]), - prediction_columns=model_monitor_params["prediction_columns"], - label_columns=model_monitor_params["label_columns"], - ) + """Get a Model Monitor by Model Version. - else: - raise ValueError( - f"ModelMonitor not found for model version {model_version.model_name} - {model_version.version_name}" + Args: + model_version: ModelVersion to retrieve Model Monitor for. + + Returns: + The fetched ModelMonitor. + + Raises: + ValueError: If model monitor is not found. 
+ """ + rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params) + + def model_match_fn(model_details: Dict[str, str]) -> bool: + return ( + model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name + and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name ) + rows = [row for row in rows if model_match_fn(json.loads(row[model_monitor_sql_client.MODEL_JSON_COL_NAME]))] + if len(rows) == 0: + raise ValueError("Unable to find model monitor for the given model version.") + if len(rows) > 1: + raise ValueError("Found multiple model monitors for the given model version.") + + return model_monitor.ModelMonitor._ref( + model_monitor_client=self._model_monitor_client, + name=sql_identifier.SqlIdentifier(rows[0]["name"]), + ) + def get_monitor(self, name: str) -> model_monitor.ModelMonitor: """Get a Model Monitor from the Registry @@ -318,25 +201,18 @@ def get_monitor(self, name: str) -> model_monitor.ModelMonitor: Returns: The fetched ModelMonitor. """ - name_id = sql_identifier.SqlIdentifier(name) + database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name) if not self._model_monitor_client.validate_existence_by_name( - monitor_name=name_id, + database_name=database_name_id, + schema_name=schema_name_id, + monitor_name=monitor_name_id, statement_params=self.statement_params, ): raise ValueError(f"Unable to find model monitor '{name}'") - model_monitor_params: model_monitor_sql_client._ModelMonitorParams = ( - self._model_monitor_client.get_model_monitor_by_name(name_id, statement_params=self.statement_params) - ) - return model_monitor.ModelMonitor._ref( model_monitor_client=self._model_monitor_client, - name=name_id, - fully_qualified_model_name=model_monitor_params["fully_qualified_model_name"], - version_name=sql_identifier.SqlIdentifier(model_monitor_params["version_name"]), - function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]), - prediction_columns=model_monitor_params["prediction_columns"], - label_columns=model_monitor_params["label_columns"], + name=monitor_name_id, ) def show_model_monitors(self) -> List[snowpark.Row]: @@ -345,7 +221,7 @@ def show_model_monitors(self) -> List[snowpark.Row]: Returns: List of snowpark.Row containing metadata for each model monitor. """ - return self._model_monitor_client.get_all_model_monitor_metadata() + return self._model_monitor_client.show_model_monitors(statement_params=self.statement_params) def delete_monitor(self, name: str) -> None: """Delete a Model Monitor from the Registry @@ -353,10 +229,10 @@ def delete_monitor(self, name: str) -> None: Args: name: Name of the Model Monitor to delete. 
""" - name_id = sql_identifier.SqlIdentifier(name) - monitor_params = self._model_monitor_client.get_model_monitor_by_name(name_id) - _, _, model = sql_identifier.parse_fully_qualified_name(monitor_params["fully_qualified_model_name"]) - version = sql_identifier.SqlIdentifier(monitor_params["version_name"]) - self._model_monitor_client.delete_monitor_metadata(name_id) - self._model_monitor_client.delete_baseline_table(model, version) - self._model_monitor_client.delete_dynamic_tables(model, version) + database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name) + self._model_monitor_client.drop_model_monitor( + database_name=database_name_id, + schema_name=schema_name_id, + monitor_name=monitor_name_id, + statement_params=self.statement_params, + ) diff --git a/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py b/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py index a45b7eeb..20cd61cb 100644 --- a/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py +++ b/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py @@ -8,14 +8,9 @@ from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema -from snowflake.ml.monitoring._client import model_monitor_sql_client from snowflake.ml.monitoring._manager import model_monitor_manager -from snowflake.ml.monitoring.entities import ( - model_monitor_config, - model_monitor_interval, - output_score_type, -) -from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.ml.monitoring.entities import model_monitor_config +from snowflake.ml.test_utils import mock_session from snowflake.snowpark import Row, Session @@ -70,49 +65,25 @@ def setUp(self) -> None: model_function_name="predict", background_compute_warehouse_name=self.test_warehouse, ) - self.test_table_config = model_monitor_config.ModelMonitorTableConfig( - prediction_columns=["A"], - label_columns=["B"], + self.test_source_config = model_monitor_config.ModelMonitorSourceConfig( + prediction_score_columns=["A"], + actual_score_columns=["B"], id_columns=["C"], timestamp_column="D", - source_table=self.test_source_table_name, + source=self.test_source_table_name, ) self._init_mm_with_patch() def tearDown(self) -> None: self.m_session.finalize() - def test_validate_monitor_config(self) -> None: - malformed_refresh = "BAD BAD" - mm_config = model_monitor_config.ModelMonitorConfig( - model_version=_build_mock_model_version(self.test_fq_model_name, self.test_version_name), - model_function_name="predict", - background_compute_warehouse_name=self.test_warehouse, - refresh_interval=malformed_refresh, - ) - with self.assertRaisesRegex(ValueError, "Failed to parse refresh interval with exception"): - self.mm._validate_monitor_config_or_raise(self.test_table_config, mm_config) - - def test_validate_name_constraints(self) -> None: - model_name, version_name = "M" * 231, "V" - m_model_version = _build_mock_model_version(model_name, version_name) - with self.assertRaisesRegex( - ValueError, - "Model name and version name exceeds maximum length of 231", - ): - model_monitor_manager._validate_name_constraints(m_model_version) - - good_model_name = "M" * 230 - m_model_version = _build_mock_model_version(good_model_name, version_name) - model_monitor_manager._validate_name_constraints(m_model_version) - - def test_fetch_task(self) -> None: + def 
test_validate_task_from_model_version(self) -> None: model_version = _build_mock_model_version( self.test_fq_model_name, self.test_version_name, task=type_hints.Task.UNKNOWN ) expected_msg = "Registry model must be logged with task in order to be monitored." with self.assertRaisesRegex(ValueError, expected_msg): - self.mm._fetch_task_from_model_version(model_version) + self.mm._validate_task_from_model_version(model_version) def test_validate_function_name(self) -> None: model_version = _build_mock_model_version(self.test_fq_model_name, self.test_version_name) @@ -121,43 +92,7 @@ def test_validate_function_name(self) -> None: f"Function with name {bad_function_name} does not exist in the given model version. Found: {{'predict'}}." ) with self.assertRaisesRegex(ValueError, re.escape(expected_message)): - self.mm._get_and_validate_model_function_from_model_version(bad_function_name, model_version) - - def test_get_monitor_by_model_version(self) -> None: - self.mock_model_monitor_sql_client.validate_existence.return_value = True - self.mock_model_monitor_sql_client.get_model_monitor_by_model_version.return_value = ( - model_monitor_sql_client._ModelMonitorParams( - monitor_name="TEST_MONITOR_NAME", - fully_qualified_model_name=self.test_fq_model_name, - version_name=self.test_model_version, - function_name="PREDICT", - prediction_columns=[], - label_columns=[], - ) - ) - model_monitor = self.mm.get_monitor_by_model_version(self.mv) - - self.mock_model_monitor_sql_client.validate_existence.assert_called_once_with( - self.test_fq_model_name, self.test_model_version, None - ) - self.mock_model_monitor_sql_client.get_model_monitor_by_model_version.assert_called_once_with( - model_db=self.test_db, - model_schema=self.test_schema, - model_name=self.test_model, - version_name=self.test_model_version, - statement_params=None, - ) - self.assertEqual(model_monitor.name, "TEST_MONITOR_NAME") - self.assertEqual(model_monitor._function_name, "PREDICT") - - def test_get_monitor_by_model_version_not_exists(self) -> None: - with self.assertRaisesRegex(ValueError, "ModelMonitor not found for model version"): - with mock.patch.object( - self.mm._model_monitor_client, "validate_existence", return_value=False - ) as mock_validate_existence: - self.mm.get_monitor_by_model_version(self.mv) - - mock_validate_existence.assert_called_once_with(self.test_fq_model_name, self.test_model_version, None) + self.mm._validate_model_function_from_model_version(bad_function_name, model_version) def _init_mm_with_patch(self) -> None: patcher = patch("snowflake.ml.monitoring._client.model_monitor_sql_client.ModelMonitorSQLClient", autospec=True) @@ -178,17 +113,17 @@ def setUp(self) -> None: self.test_model_version = "TEST_VERSION" self.test_model = "TEST_MODEL" - self.test_fq_model_name = f"db1.schema1.{self.test_model}" + self.test_fq_model_name = f"model_db.model_schema.{self.test_model}" self.test_source_table_name = "TEST_TABLE" self.mv = _build_mock_model_version(self.test_fq_model_name, self.test_model_version) - self.test_table_config = model_monitor_config.ModelMonitorTableConfig( - prediction_columns=["PREDICTION"], - label_columns=["LABEL"], + self.test_table_config = model_monitor_config.ModelMonitorSourceConfig( + prediction_score_columns=["PREDICTION"], + actual_score_columns=["LABEL"], id_columns=["ID"], timestamp_column="TS", - source_table=self.test_source_table_name, + source=self.test_source_table_name, ) self.test_monitor_config = model_monitor_config.ModelMonitorConfig( model_version=self.mv, @@ -196,10 
+131,6 @@ def setUp(self) -> None: background_compute_warehouse_name=self.test_warehouse, ) session = cast(Session, self.m_session) - self.m_session.add_mock_sql( - query=f"""SHOW TABLES LIKE '_SYSTEM_MONITORING_METADATA' IN {self.test_db}.{self.test_schema}""", - result=mock_data_frame.MockDataFrame([Row(name="_SYSTEM_MONITORING_METADATA")]), - ) self.mm = model_monitor_manager.ModelMonitorManager( session, database_name=self.test_db, schema_name=self.test_schema ) @@ -208,100 +139,74 @@ def setUp(self) -> None: def tearDown(self) -> None: self.m_session.finalize() - def test_manual_init(self) -> None: - self.m_session.add_mock_sql( - query=f"""CREATE TABLE IF NOT EXISTS {self.test_db}.{self.test_schema}._SYSTEM_MONITORING_METADATA - (MONITOR_NAME VARCHAR, SOURCE_TABLE_NAME VARCHAR, FULLY_QUALIFIED_MODEL_NAME VARCHAR, - MODEL_VERSION_NAME VARCHAR, FUNCTION_NAME VARCHAR, TASK VARCHAR, IS_ENABLED BOOLEAN, - TIMESTAMP_COLUMN_NAME VARCHAR, PREDICTION_COLUMN_NAMES ARRAY, - LABEL_COLUMN_NAMES ARRAY, ID_COLUMN_NAMES ARRAY) - """, - result=mock_data_frame.MockDataFrame([Row(status="Table successfully created.")]), - ) - self.m_session.add_mock_sql( - query=f"""SHOW TABLES LIKE '_SYSTEM_MONITORING_METADATA' IN {self.test_db}.{self.test_schema}""", - result=mock_data_frame.MockDataFrame([Row(name="_SYSTEM_MONITORING_METADATA")]), - ) - session = cast(Session, self.m_session) - model_monitor_manager.ModelMonitorManager.setup(session, self.test_db, self.test_schema) - model_monitor_manager.ModelMonitorManager( - session, database_name=self.test_db, schema_name=self.test_schema, create_if_not_exists=False - ) + def test_show_monitors(self) -> None: + with mock.patch.object( + self.mm._model_monitor_client, "show_model_monitors", return_value=[] + ) as mock_show_model_monitors: + self.mm.show_model_monitors() + mock_show_model_monitors.assert_called_once_with(statement_params=None) - def test_init_fails_not_initialized(self) -> None: - self.m_session.add_mock_sql( - query=f"""SHOW TABLES LIKE '_SYSTEM_MONITORING_METADATA' IN {self.test_db}.{self.test_schema}""", - result=mock_data_frame.MockDataFrame([]), - ) - session = cast(Session, self.m_session) - expected_msg = "Monitoring has not been setup. 
Set create_if_not_exists or call ModelMonitorManager.setup" + def test_get_monitor_by_model_version(self) -> None: + with mock.patch.object( + self.mm._model_monitor_client, "show_model_monitors", return_value=[] + ) as mock_show_model_monitors: + with self.assertRaisesRegex(ValueError, "Unable to find model monitor for the given model version."): + self.mm.get_monitor_by_model_version(self.mv) + mock_show_model_monitors.assert_called_once_with(statement_params=None) - with self.assertRaisesRegex(ValueError, expected_msg): - model_monitor_manager.ModelMonitorManager( - session, database_name=self.test_db, schema_name=self.test_schema, create_if_not_exists=False - ) + mock_return = [Row(name="TEST", model='{"model_name": "TEST_MODEL", "version_name": "TEST_VERSION"}')] + with mock.patch.object( + self.mm._model_monitor_client, "show_model_monitors", return_value=mock_return + ) as mock_show_model_monitors: + m = self.mm.get_monitor_by_model_version(self.mv) + mock_show_model_monitors.assert_called_once_with(statement_params=None) + self.assertEqual(m.name, "TEST") def test_add_monitor(self) -> None: with mock.patch.object( - self.mm._model_monitor_client, "validate_source_table" - ) as mock_validate_source_table, mock.patch.object( + self.mm._model_monitor_client, "validate_source" + ) as mock_validate_source, mock.patch.object( self.mv, "get_model_task", return_value=type_hints.Task.TABULAR_REGRESSION ) as mock_get_model_task, mock.patch.object( - self.mm._model_monitor_client, - "get_score_type", - return_value=output_score_type.OutputScoreType.REGRESSION, - ) as mock_get_score_type, mock.patch.object( - self.mm._model_monitor_client, "create_monitor_on_model_version", return_value=None - ) as mock_create_monitor_on_model_version, mock.patch.object( - self.mm._model_monitor_client, "create_dynamic_tables_for_monitor", return_value=None - ) as mock_create_dynamic_tables_for_monitor, mock.patch.object( - self.mm._model_monitor_client, - "initialize_baseline_table", - return_value=None, - ) as mock_initialize_baseline_table: + self.mm._model_monitor_client, "create_model_monitor", return_value=None + ) as mock_create_model_monitor: self.mm.add_monitor("TEST", self.test_table_config, self.test_monitor_config) - mock_validate_source_table.assert_called_once_with( - source_table_name=self.test_source_table_name, + mock_validate_source.assert_called_once_with( + source_database=None, + source_schema=None, + source=self.test_source_table_name, timestamp_column="TS", - prediction_columns=["PREDICTION"], - label_columns=["LABEL"], + prediction_score_columns=["PREDICTION"], + prediction_class_columns=[], + actual_score_columns=["LABEL"], + actual_class_columns=[], id_columns=["ID"], - model_function=self.mv.show_functions()[0], ) mock_get_model_task.assert_called_once() - mock_get_score_type.assert_called_once() - mock_create_monitor_on_model_version.assert_called_once_with( + mock_create_model_monitor.assert_called_once_with( + monitor_database=None, + monitor_schema=None, monitor_name=sql_identifier.SqlIdentifier("TEST"), - source_table_name=sql_identifier.SqlIdentifier(self.test_source_table_name), - fully_qualified_model_name=self.test_fq_model_name, + source_database=None, + source_schema=None, + source=sql_identifier.SqlIdentifier(self.test_source_table_name), + model_database=sql_identifier.SqlIdentifier("MODEL_DB"), + model_schema=sql_identifier.SqlIdentifier("MODEL_SCHEMA"), + model_name=self.test_model, version_name=sql_identifier.SqlIdentifier(self.test_model_version), 
function_name="predict", - timestamp_column="TS", - prediction_columns=["PREDICTION"], - label_columns=["LABEL"], - id_columns=["ID"], - task=type_hints.Task.TABULAR_REGRESSION, - statement_params=None, - ) - mock_create_dynamic_tables_for_monitor.assert_called_once_with( - model_name="TEST_MODEL", - model_version_name="TEST_VERSION", - task=type_hints.Task.TABULAR_REGRESSION, - source_table_name=self.test_source_table_name, - refresh_interval=model_monitor_interval.ModelMonitorRefreshInterval("1 days"), - aggregation_window=model_monitor_interval.ModelMonitorAggregationWindow.WINDOW_1_DAY, - warehouse_name="TEST_WAREHOUSE", + warehouse_name=sql_identifier.SqlIdentifier(self.test_warehouse), timestamp_column="TS", id_columns=["ID"], - prediction_columns=["PREDICTION"], - label_columns=["LABEL"], - score_type=output_score_type.OutputScoreType.REGRESSION, - ) - mock_initialize_baseline_table.assert_called_once_with( - model_name="TEST_MODEL", - version_name="TEST_VERSION", - source_table_name=self.test_source_table_name, - columns_to_drop=[self.test_table_config.timestamp_column, *self.test_table_config.id_columns], + prediction_score_columns=["PREDICTION"], + prediction_class_columns=[], + actual_score_columns=["LABEL"], + actual_class_columns=[], + refresh_interval="1 hour", + aggregation_window="1 day", + baseline_database=None, + baseline_schema=None, + baseline=None, statement_params=None, ) @@ -317,63 +222,99 @@ def test_add_monitor_fails_no_task(self) -> None: self.mm.add_monitor("TEST", self.test_table_config, self.test_monitor_config) mock_validate_source_table.assert_called_once() - def test_add_monitor_fails_multiple_predictions(self) -> None: - bad_table_config = model_monitor_config.ModelMonitorTableConfig( - source_table=self.test_source_table_name, - prediction_columns=["PREDICTION1", "PREDICTION2"], - label_columns=["LABEL1", "LABEL2"], - id_columns=["ID"], - timestamp_column="TIMESTAMP", - ) - expected_error = "Multiple Output columns are not supported in monitoring" - with self.assertRaisesRegex(ValueError, expected_error): - self.mm.add_monitor("test", bad_table_config, self.test_monitor_config) - self.m_session.finalize() - - def test_add_monitor_fails_column_lengths_do_not_match(self) -> None: - bad_table_config = model_monitor_config.ModelMonitorTableConfig( - source_table=self.test_source_table_name, - prediction_columns=["PREDICTION"], - label_columns=["LABEL1", "LABEL2"], - id_columns=["ID"], - timestamp_column="TIMESTAMP", - ) - expected_msg = "Prediction and Label column names must be of the same length." 
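The tests that follow allow a monitor name to be fully qualified ("DB.SCHEMA.NAME") and expect it to be split into database, schema and object parts. As a rough illustration only (not part of the patch), that splitting can be sketched with the sql_identifier helper that appears elsewhere in this diff; whether the new manager calls exactly this helper, and its precise return types, are assumptions.

# Illustrative sketch only: splitting a fully qualified name into its parts.
# parse_fully_qualified_name is the helper used by the removed ModelMonitor
# code in this same patch; its exact return types are an assumption here.
from snowflake.ml._internal.utils import sql_identifier

db, schema, name = sql_identifier.parse_fully_qualified_name("TEST_DB.TEST_SCHEMA.TEST")
# The tests below expect the equivalent of SqlIdentifier("TEST_DB"),
# SqlIdentifier("TEST_SCHEMA") and SqlIdentifier("TEST") for these parts.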
- with self.assertRaisesRegex(ValueError, expected_msg): - self.mm.add_monitor( - "test", - bad_table_config, - self.test_monitor_config, + def test_add_monitor_fully_qualified_monitor_name(self) -> None: + with mock.patch.object(self.mm._model_monitor_client, "validate_source_table"), mock.patch.object( + self.mv, "get_model_task", return_value=type_hints.Task.TABULAR_REGRESSION + ), mock.patch.object(self.mm._model_monitor_client, "create_model_monitor") as mock_create_model_monitor: + self.mm.add_monitor("TEST_DB.TEST_SCHEMA.TEST", self.test_table_config, self.test_monitor_config) + mock_create_model_monitor.assert_called_once_with( + monitor_database=sql_identifier.SqlIdentifier("TEST_DB"), + monitor_schema=sql_identifier.SqlIdentifier("TEST_SCHEMA"), + monitor_name=sql_identifier.SqlIdentifier("TEST"), + source_database=None, + source_schema=None, + source=sql_identifier.SqlIdentifier(self.test_source_table_name), + model_database=sql_identifier.SqlIdentifier("MODEL_DB"), + model_schema=sql_identifier.SqlIdentifier("MODEL_SCHEMA"), + model_name=self.test_model, + version_name=sql_identifier.SqlIdentifier(self.test_model_version), + function_name="predict", + warehouse_name=sql_identifier.SqlIdentifier(self.test_warehouse), + timestamp_column="TS", + id_columns=["ID"], + prediction_score_columns=["PREDICTION"], + prediction_class_columns=[], + actual_score_columns=["LABEL"], + actual_class_columns=[], + refresh_interval="1 hour", + aggregation_window="1 day", + baseline_database=None, + baseline_schema=None, + baseline=None, + statement_params=None, ) - self.m_session.finalize() - def test_delete_monitor(self) -> None: monitor = "TEST" - model = "TEST" - version = "V1" - monitor_params = model_monitor_sql_client._ModelMonitorParams( - monitor_name=monitor, - fully_qualified_model_name=f"TEST_DB.TEST_SCHEMA.{model}", - version_name=version, - function_name="predict", - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - ) - with mock.patch.object( - self.mm._model_monitor_client, "get_model_monitor_by_name", return_value=monitor_params - ) as mock_get_model_monitor_by_name, mock.patch.object( - self.mm._model_monitor_client, "delete_monitor_metadata" - ) as mock_delete_monitor_metadata, mock.patch.object( - self.mm._model_monitor_client, "delete_baseline_table" - ) as mock_delete_baseline_table, mock.patch.object( - self.mm._model_monitor_client, "delete_dynamic_tables" - ) as mock_delete_dynamic_tables: + with mock.patch.object(self.mm._model_monitor_client, "drop_model_monitor") as mock_drop_model_monitor: self.mm.delete_monitor(monitor) - mock_get_model_monitor_by_name.assert_called_once_with(monitor) - mock_delete_monitor_metadata.assert_called_once_with(sql_identifier.SqlIdentifier(monitor)) - mock_delete_baseline_table.assert_called_once_with(model, version) - mock_delete_dynamic_tables.assert_called_once_with(model, version) + mock_drop_model_monitor.assert_called_once_with( + database_name=None, schema_name=None, monitor_name="TEST", statement_params=mock.ANY + ) + + monitor = "TEST_DB.TEST_SCHEMA.TEST" + with mock.patch.object(self.mm._model_monitor_client, "drop_model_monitor") as mock_drop_model_monitor: + self.mm.delete_monitor(monitor) + mock_drop_model_monitor.assert_called_once_with( + database_name="TEST_DB", schema_name="TEST_SCHEMA", monitor_name="TEST", statement_params=mock.ANY + ) + + def test_add_monitor_objects_in_different_schemas(self) -> None: + source_config = 
model_monitor_config.ModelMonitorSourceConfig( + prediction_score_columns=["PREDICTION"], + actual_score_columns=["LABEL"], + id_columns=["ID"], + timestamp_column="TS", + source="SOURCE_DB.SOURCE_SCHEMA.SOURCE", + baseline="BASELINE_DB.BASELINE_SCHEMA.BASELINE", + ) + monitor_config = model_monitor_config.ModelMonitorConfig( + model_version=_build_mock_model_version("MODEL_DB.MODEL_SCHEMA.MODEL", self.test_model_version), + model_function_name="predict", + background_compute_warehouse_name=self.test_warehouse, + ) + with mock.patch.object(self.mm._model_monitor_client, "validate_source_table"), mock.patch.object( + self.mv, "get_model_task", return_value=type_hints.Task.TABULAR_REGRESSION + ), mock.patch.object( + self.mm._model_monitor_client, "create_model_monitor", return_value=None + ) as mock_create_model_monitor: + self.mm.add_monitor("MONITOR_DB.MONITOR_SCHEMA.MONITOR", source_config, monitor_config) + mock_create_model_monitor.assert_called_once_with( + monitor_database=sql_identifier.SqlIdentifier("MONITOR_DB"), + monitor_schema=sql_identifier.SqlIdentifier("MONITOR_SCHEMA"), + monitor_name=sql_identifier.SqlIdentifier("MONITOR"), + source_database=sql_identifier.SqlIdentifier("SOURCE_DB"), + source_schema=sql_identifier.SqlIdentifier("SOURCE_SCHEMA"), + source=sql_identifier.SqlIdentifier("SOURCE"), + model_database=sql_identifier.SqlIdentifier("MODEL_DB"), + model_schema=sql_identifier.SqlIdentifier("MODEL_SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier(self.test_model_version), + function_name="predict", + warehouse_name=sql_identifier.SqlIdentifier(self.test_warehouse), + timestamp_column="TS", + id_columns=["ID"], + prediction_score_columns=["PREDICTION"], + prediction_class_columns=[], + actual_score_columns=["LABEL"], + actual_class_columns=[], + refresh_interval="1 hour", + aggregation_window="1 day", + baseline_database=sql_identifier.SqlIdentifier("BASELINE_DB"), + baseline_schema=sql_identifier.SqlIdentifier("BASELINE_SCHEMA"), + baseline=sql_identifier.SqlIdentifier("BASELINE"), + statement_params=None, + ) if __name__ == "__main__": diff --git a/snowflake/ml/monitoring/entities/BUILD.bazel b/snowflake/ml/monitoring/entities/BUILD.bazel index 77faa62c..99ee6c9d 100644 --- a/snowflake/ml/monitoring/entities/BUILD.bazel +++ b/snowflake/ml/monitoring/entities/BUILD.bazel @@ -6,7 +6,6 @@ py_library( name = "entities_lib", srcs = [ "model_monitor_config.py", - "model_monitor_interval.py", "output_score_type.py", ], deps = [ @@ -24,14 +23,3 @@ py_test( ":entities_lib", ], ) - -py_test( - name = "model_monitor_interval_test", - srcs = [ - "model_monitor_interval_test.py", - ], - deps = [ - ":entities_lib", - "//snowflake/ml/test_utils:mock_session", - ], -) diff --git a/snowflake/ml/monitoring/entities/model_monitor_config.py b/snowflake/ml/monitoring/entities/model_monitor_config.py index f4083d14..d030fe16 100644 --- a/snowflake/ml/monitoring/entities/model_monitor_config.py +++ b/snowflake/ml/monitoring/entities/model_monitor_config.py @@ -1,17 +1,19 @@ from dataclasses import dataclass -from typing import List +from typing import List, Optional from snowflake.ml.model._client.model import model_version_impl -from snowflake.ml.monitoring.entities import model_monitor_interval @dataclass -class ModelMonitorTableConfig: - source_table: str +class ModelMonitorSourceConfig: + source: str timestamp_column: str - prediction_columns: List[str] - label_columns: List[str] id_columns: List[str] + prediction_score_columns: 
Optional[List[str]] = None + prediction_class_columns: Optional[List[str]] = None + actual_score_columns: Optional[List[str]] = None + actual_class_columns: Optional[List[str]] = None + baseline: Optional[str] = None @dataclass @@ -22,7 +24,5 @@ class ModelMonitorConfig: model_function_name: str background_compute_warehouse_name: str # TODO: Add support for pythonic notion of time. - refresh_interval: str = model_monitor_interval.ModelMonitorRefreshInterval.DAILY - aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow = ( - model_monitor_interval.ModelMonitorAggregationWindow.WINDOW_1_DAY - ) + refresh_interval: str = "1 hour" + aggregation_window: str = "1 day" diff --git a/snowflake/ml/monitoring/entities/model_monitor_interval.py b/snowflake/ml/monitoring/entities/model_monitor_interval.py deleted file mode 100644 index f9ec1ddd..00000000 --- a/snowflake/ml/monitoring/entities/model_monitor_interval.py +++ /dev/null @@ -1,46 +0,0 @@ -from enum import Enum - - -class ModelMonitorAggregationWindow(Enum): - WINDOW_1_HOUR = 60 - WINDOW_1_DAY = 24 * 60 - - def __init__(self, minutes: int) -> None: - super().__init__() - self.minutes = minutes - - -class ModelMonitorRefreshInterval: - EVERY_30_MINUTES = "30 minutes" - HOURLY = "1 hours" - EVERY_6_HOURS = "6 hours" - EVERY_12_HOURS = "12 hours" - DAILY = "1 days" - WEEKLY = "7 days" - BIWEEKLY = "14 days" - MONTHLY = "30 days" - - _ALLOWED_TIME_UNITS = {"minutes": 1, "hours": 60, "days": 24 * 60} - - def __init__(self, raw_time_str: str) -> None: - try: - num_units_raw, time_units = raw_time_str.strip().split(" ") - num_units = int(num_units_raw) # try to cast - except Exception as e: - raise ValueError( - f"""Failed to parse refresh interval with exception {e}. - Provide ' '. -See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info.""" - ) - if time_units.lower() not in self._ALLOWED_TIME_UNITS: - raise ValueError( - """Invalid time unit in refresh interval. Provide ' '. 
-See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info.""" - ) - minutes_multiplier = self._ALLOWED_TIME_UNITS[time_units.lower()] - self.minutes = num_units * minutes_multiplier - - def __eq__(self, value: object) -> bool: - if not isinstance(value, ModelMonitorRefreshInterval): - return False - return self.minutes == value.minutes diff --git a/snowflake/ml/monitoring/entities/model_monitor_interval_test.py b/snowflake/ml/monitoring/entities/model_monitor_interval_test.py deleted file mode 100644 index e8a2f913..00000000 --- a/snowflake/ml/monitoring/entities/model_monitor_interval_test.py +++ /dev/null @@ -1,41 +0,0 @@ -from absl.testing import absltest - -from snowflake.ml.monitoring.entities import model_monitor_interval - - -class ModelMonitorIntervalTest(absltest.TestCase): - def setUp(self) -> None: - super().setUp() - - def test_validate_monitor_config(self) -> None: - with self.assertRaisesRegex(ValueError, "Failed to parse refresh interval with exception"): - model_monitor_interval.ModelMonitorRefreshInterval("UNINITIALIZED") - - with self.assertRaisesRegex(ValueError, "Invalid time unit in refresh interval."): - model_monitor_interval.ModelMonitorRefreshInterval("4 years") - - with self.assertRaisesRegex(ValueError, "Failed to parse refresh interval with exception."): - model_monitor_interval.ModelMonitorRefreshInterval("2.5 hours") - ri = model_monitor_interval.ModelMonitorRefreshInterval("1 hours") - self.assertEqual(ri.minutes, 60) - - def test_predefined_refresh_intervals(self) -> None: - min_30 = model_monitor_interval.ModelMonitorRefreshInterval.EVERY_30_MINUTES - hr_1 = model_monitor_interval.ModelMonitorRefreshInterval.HOURLY - hr_6 = model_monitor_interval.ModelMonitorRefreshInterval.EVERY_6_HOURS - day_1 = model_monitor_interval.ModelMonitorRefreshInterval.DAILY - day_7 = model_monitor_interval.ModelMonitorRefreshInterval.WEEKLY - day_14 = model_monitor_interval.ModelMonitorRefreshInterval.BIWEEKLY - day_30 = model_monitor_interval.ModelMonitorRefreshInterval.MONTHLY - - self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(min_30).minutes, 30) - self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(hr_1).minutes, 60) - self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(hr_6).minutes, 6 * 60) - self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_1).minutes, 24 * 60) - self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_7).minutes, 7 * 24 * 60) - self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_14).minutes, 14 * 24 * 60) - self.assertEqual(model_monitor_interval.ModelMonitorRefreshInterval(day_30).minutes, 30 * 24 * 60) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/monitoring/model_monitor.py b/snowflake/ml/monitoring/model_monitor.py index 61287a6c..869280eb 100644 --- a/snowflake/ml/monitoring/model_monitor.py +++ b/snowflake/ml/monitoring/model_monitor.py @@ -1,8 +1,3 @@ -from typing import List, Union - -import pandas as pd - -from snowflake import snowpark from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.monitoring._client import model_monitor_sql_client @@ -13,11 +8,11 @@ class ModelMonitor: name: sql_identifier.SqlIdentifier _model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient - _fully_qualified_model_name: str - _version_name: sql_identifier.SqlIdentifier - _function_name: 
sql_identifier.SqlIdentifier - _prediction_columns: List[sql_identifier.SqlIdentifier] - _label_columns: List[sql_identifier.SqlIdentifier] + + statement_params = telemetry.get_statement_params( + telemetry.TelemetryProject.MLOPS.value, + telemetry.TelemetrySubProject.MONITORING.value, + ) def __init__(self) -> None: raise RuntimeError("ModelMonitor's initializer is not meant to be used.") @@ -27,100 +22,16 @@ def _ref( cls, model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient, name: sql_identifier.SqlIdentifier, - *, - fully_qualified_model_name: str, - version_name: sql_identifier.SqlIdentifier, - function_name: sql_identifier.SqlIdentifier, - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], ) -> "ModelMonitor": self: "ModelMonitor" = object.__new__(cls) self.name = name self._model_monitor_client = model_monitor_client - self._fully_qualified_model_name = fully_qualified_model_name - self._version_name = version_name - self._function_name = function_name - self._prediction_columns = prediction_columns - self._label_columns = label_columns return self - @telemetry.send_api_usage_telemetry( - project=telemetry.TelemetryProject.MLOPS.value, - subproject=telemetry.TelemetrySubProject.MONITORING.value, - ) - def set_baseline(self, baseline_df: Union[pd.DataFrame, snowpark.DataFrame]) -> None: - """ - The baseline dataframe is compared with the monitored data once monitoring is enabled. - The columns of the dataframe should match the columns of the source table that the - ModelMonitor was configured with. Calling this method overwrites any existing baseline split data. - - Args: - baseline_df: Snowpark dataframe containing baseline data. - - Raises: - ValueError: baseline_df does not contain prediction or label columns - """ - statement_params = telemetry.get_statement_params( - project=telemetry.TelemetryProject.MLOPS.value, - subproject=telemetry.TelemetrySubProject.MONITORING.value, - ) - - if isinstance(baseline_df, pd.DataFrame): - baseline_df = self._model_monitor_client._sql_client._session.create_dataframe(baseline_df) - - column_names_identifiers: List[sql_identifier.SqlIdentifier] = [ - sql_identifier.SqlIdentifier(column_name) for column_name in baseline_df.columns - ] - prediction_cols_not_found = any( - [prediction_col not in column_names_identifiers for prediction_col in self._prediction_columns] - ) - label_cols_not_found = any( - [label_col.identifier() not in column_names_identifiers for label_col in self._label_columns] - ) - - if prediction_cols_not_found: - raise ValueError( - "Specified prediction columns were not found in the baseline dataframe. " - f"Columns provided were: {column_names_identifiers}. " - f"Configured prediction columns were: {self._prediction_columns}." - ) - if label_cols_not_found: - raise ValueError( - "Specified label columns were not found in the baseline dataframe." - f"Columns provided in the baseline dataframe were: {column_names_identifiers}." - f"Configured label columns were: {self._label_columns}." 
- ) - - # Create the table by materializing the df - self._model_monitor_client.materialize_baseline_dataframe( - baseline_df, - self._fully_qualified_model_name, - self._version_name, - statement_params=statement_params, - ) - def suspend(self) -> None: """Suspend pipeline for ModelMonitor""" - statement_params = telemetry.get_statement_params( - telemetry.TelemetryProject.MLOPS.value, - telemetry.TelemetrySubProject.MONITORING.value, - ) - _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name) - self._model_monitor_client.suspend_monitor_dynamic_tables( - model_name=model_name, - version_name=self._version_name, - statement_params=statement_params, - ) + self._model_monitor_client.suspend_monitor(self.name, statement_params=self.statement_params) def resume(self) -> None: """Resume pipeline for ModelMonitor""" - statement_params = telemetry.get_statement_params( - telemetry.TelemetryProject.MLOPS.value, - telemetry.TelemetrySubProject.MONITORING.value, - ) - _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name) - self._model_monitor_client.resume_monitor_dynamic_tables( - model_name=model_name, - version_name=self._version_name, - statement_params=statement_params, - ) + self._model_monitor_client.resume_monitor(self.name, statement_params=self.statement_params) diff --git a/snowflake/ml/monitoring/model_monitor_test.py b/snowflake/ml/monitoring/model_monitor_test.py index eaaaa23b..e1af4e48 100644 --- a/snowflake/ml/monitoring/model_monitor_test.py +++ b/snowflake/ml/monitoring/model_monitor_test.py @@ -1,13 +1,10 @@ -from typing import cast from unittest import mock -import pandas as pd from absl.testing import absltest from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.monitoring import model_monitor -from snowflake.ml.test_utils import mock_data_frame, mock_session -from snowflake.snowpark import DataFrame, Row +from snowflake.ml.test_utils import mock_session class ModelMonitorInstanceTest(absltest.TestCase): @@ -17,140 +14,22 @@ def setUp(self) -> None: self.test_schema_name = sql_identifier.SqlIdentifier("METADATA") self.test_monitor_name = sql_identifier.SqlIdentifier("TEST") - self.test_model_version_name = sql_identifier.SqlIdentifier("TEST_MODEL_VERSION") - self.test_model_name = sql_identifier.SqlIdentifier("TEST_MODEL") - self.test_fq_model_name = f"{self.test_db_name}.{self.test_schema_name}.{self.test_model_name}" - self.test_prediction_column_name = sql_identifier.SqlIdentifier("PREDICTION") - self.test_label_column_name = sql_identifier.SqlIdentifier("LABEL") self.monitor_sql_client = mock.MagicMock(name="sql_client") self.model_monitor = model_monitor.ModelMonitor._ref( model_monitor_client=self.monitor_sql_client, name=self.test_monitor_name, - fully_qualified_model_name=self.test_fq_model_name, - version_name=self.test_model_version_name, - function_name=sql_identifier.SqlIdentifier("predict"), - prediction_columns=[sql_identifier.SqlIdentifier(self.test_prediction_column_name)], - label_columns=[sql_identifier.SqlIdentifier(self.test_label_column_name)], ) - def test_set_baseline(self) -> None: - baseline_df = mock_data_frame.MockDataFrame( - [ - Row( - ID=1, - TIMESTAMP=1, - PREDICTION=0.5, - LABEL=1, - ), - Row( - ID=2, - TIMESTAMP=2, - PREDICTION=0.6, - LABEL=0, - ), - ], - columns=[ - "ID", - "TIMESTAMP", - "PREDICTION", - "LABEL", - ], - ) - with mock.patch.object(self.monitor_sql_client, "materialize_baseline_dataframe") as mock_materialize: - 
self.model_monitor.set_baseline(cast(DataFrame, baseline_df)) - mock_materialize.assert_called_once_with( - baseline_df, self.test_fq_model_name, self.test_model_version_name, statement_params=mock.ANY - ) - - def test_set_baseline_pandas_df(self) -> None: - # Initialize a test pandas dataframe - pandas_baseline_df = pd.DataFrame( - { - "ID": [1, 2], - "TIMESTAMP": [1, 2], - "PREDICTION": [0.5, 0.6], - "LABEL": [1, 0], - } - ) - snowflake_baseline_df = mock_data_frame.MockDataFrame( - [ - Row( - ID=1, - TIMESTAMP=1, - PREDICTION=0.5, - LABEL=1, - ), - Row( - ID=2, - TIMESTAMP=2, - PREDICTION=0.6, - LABEL=0, - ), - ], - columns=[ - "ID", - "TIMESTAMP", - "PREDICTION", - "LABEL", - ], - ) - - with mock.patch.object( - self.monitor_sql_client, "materialize_baseline_dataframe" - ) as mock_materialize, mock.patch.object(self.monitor_sql_client._sql_client, "_session"), mock.patch.object( - self.monitor_sql_client._sql_client._session, "create_dataframe", return_value=snowflake_baseline_df - ) as mock_create_df: - self.model_monitor.set_baseline(pandas_baseline_df) - mock_materialize.assert_called_once_with( - snowflake_baseline_df, self.test_fq_model_name, self.test_model_version_name, statement_params=mock.ANY - ) - mock_create_df.assert_called_once_with(pandas_baseline_df) - - def test_set_baseline_missing_columns(self) -> None: - baseline_df = mock_data_frame.MockDataFrame( - [ - Row( - ID=1, - TIMESTAMP=1, - PREDICTION=0.5, - LABEL=1, - ), - Row( - ID=2, - TIMESTAMP=2, - PREDICTION=0.6, - LABEL=0, - ), - ], - columns=[ - "ID", - "TIMESTAMP", - "LABEL", - ], - ) - - expected_msg = "Specified prediction columns were not found in the baseline dataframe. Columns provided were: " - with self.assertRaisesRegex(ValueError, expected_msg): - self.model_monitor.set_baseline(cast(DataFrame, baseline_df)) - def test_suspend(self) -> None: - with mock.patch.object( - self.model_monitor._model_monitor_client, "suspend_monitor_dynamic_tables" - ) as mock_suspend: + with mock.patch.object(self.model_monitor._model_monitor_client, "suspend_monitor") as mock_suspend: self.model_monitor.suspend() - mock_suspend.assert_called_once_with( - model_name=self.test_model_name, version_name=self.test_model_version_name, statement_params=mock.ANY - ) + mock_suspend.assert_called_once_with(self.test_monitor_name, statement_params=mock.ANY) def test_resume(self) -> None: - with mock.patch.object( - self.model_monitor._model_monitor_client, "resume_monitor_dynamic_tables" - ) as mock_suspend: + with mock.patch.object(self.model_monitor._model_monitor_client, "resume_monitor") as mock_resume: self.model_monitor.resume() - mock_suspend.assert_called_once_with( - model_name=self.test_model_name, version_name=self.test_model_version_name, statement_params=mock.ANY - ) + mock_resume.assert_called_once_with(self.test_monitor_name, statement_params=mock.ANY) if __name__ == "__main__": diff --git a/snowflake/ml/registry/registry.py b/snowflake/ml/registry/registry.py index f7e4b4fd..f920dc9d 100644 --- a/snowflake/ml/registry/registry.py +++ b/snowflake/ml/registry/registry.py @@ -23,6 +23,11 @@ _TELEMETRY_PROJECT = "MLOps" _MODEL_TELEMETRY_SUBPROJECT = "ModelManagement" +_MODEL_MONITORING_UNIMPLEMENTED_ERROR = "Model Monitoring is not implemented in python yet." +_MODEL_MONITORING_DISABLED_ERROR = ( + """Must enable monitoring to use this method. 
Please set `options={"enable_monitoring": True}` in the Registry""" +) + class Registry: def __init__( @@ -84,7 +89,6 @@ def __init__( session=session, database_name=self._database_name, schema_name=self._schema_name, - create_if_not_exists=True, # TODO: Support static setup method to configure schema for monitoring. statement_params=monitor_statement_params, ) @@ -381,34 +385,25 @@ def delete_model(self, model_name: str) -> None: def add_monitor( self, name: str, - table_config: model_monitor_config.ModelMonitorTableConfig, + source_config: model_monitor_config.ModelMonitorSourceConfig, model_monitor_config: model_monitor_config.ModelMonitorConfig, - *, - add_dashboard_udtfs: bool = False, ) -> model_monitor.ModelMonitor: """Add a Model Monitor to the Registry Args: name: Name of Model Monitor to create - table_config: Configuration options of table for ModelMonitor. + source_config: Configuration options of table for ModelMonitor. model_monitor_config: Configuration options of ModelMonitor. - add_dashboard_udtfs: Add UDTFs useful for creating a dashboard. Returns: The newly added ModelMonitor object. Raises: - ValueError: If monitoring feature flag is not enabled. + ValueError: If monitoring is not enabled in the Registry. """ if not self.enable_monitoring: - raise ValueError( - "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" - ) - - # TODO: Change to fully qualified source table reference to allow table to live in different DB. - return self._model_monitor_manager.add_monitor( - name, table_config, model_monitor_config, add_dashboard_udtfs=add_dashboard_udtfs - ) + raise ValueError(_MODEL_MONITORING_DISABLED_ERROR) + return self._model_monitor_manager.add_monitor(name, source_config, model_monitor_config) @overload def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor: @@ -446,17 +441,14 @@ def get_monitor( The fetched ModelMonitor. Raises: - ValueError: If monitoring feature flag is not enabled. - ValueError: If neither name nor model_version specified. + ValueError: If monitoring is not enabled in the Registry. """ if not self.enable_monitoring: - raise ValueError( - "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" - ) + raise ValueError(_MODEL_MONITORING_DISABLED_ERROR) if name is not None: return self._model_monitor_manager.get_monitor(name=name) elif model_version is not None: - return self._model_monitor_manager.get_monitor_by_model_version(model_version=model_version) + return self._model_monitor_manager.get_monitor_by_model_version(model_version) else: raise ValueError("Must provide either `name` or `model_version` to get ModelMonitor") @@ -472,12 +464,10 @@ def show_model_monitors(self) -> List[snowpark.Row]: List of snowpark.Row containing metadata for each model monitor. Raises: - ValueError: If monitoring feature flag is not enabled. + ValueError: If monitoring is not enabled in the Registry. """ if not self.enable_monitoring: - raise ValueError( - "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" - ) + raise ValueError(_MODEL_MONITORING_DISABLED_ERROR) return self._model_monitor_manager.show_model_monitors() @telemetry.send_api_usage_telemetry( @@ -492,10 +482,8 @@ def delete_monitor(self, name: str) -> None: name: Name of the Model Monitor to delete. Raises: - ValueError: If monitoring feature flag is not enabled. 
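For orientation, a minimal usage sketch of the monitoring API exercised by this patch follows. It is not part of the change itself: the session and every object name are placeholders, and only calls that appear in the diff (Registry, add_monitor, get_monitor, show_model_monitors, delete_monitor, suspend, resume) are used.

# Minimal sketch of the reworked monitoring API; all names are placeholders.
from snowflake.ml.monitoring.entities import model_monitor_config
from snowflake.ml.registry import registry

reg = registry.Registry(
    session,  # an existing snowpark Session
    database_name="MY_DB",
    schema_name="MY_SCHEMA",
    options={"enable_monitoring": True},  # monitoring is opt-in
)
mv = reg.get_model("MY_MODEL").version("V1")

monitor = reg.add_monitor(
    "MY_MONITOR",  # may also be fully qualified, e.g. "DB.SCHEMA.MY_MONITOR"
    model_monitor_config.ModelMonitorSourceConfig(
        source="MODEL_OUTPUTS",
        timestamp_column="TIMESTAMP",
        id_columns=["ID"],
        prediction_score_columns=["PREDICTION"],
        actual_score_columns=["LABEL"],
    ),
    model_monitor_config.ModelMonitorConfig(
        model_version=mv,
        model_function_name="predict",
        background_compute_warehouse_name="MY_WAREHOUSE",
        # refresh_interval / aggregation_window default to "1 hour" / "1 day"
    ),
)

# Monitors can also be looked up by name or by model version.
same_monitor = reg.get_monitor(name="MY_MONITOR")
same_monitor = reg.get_monitor(model_version=mv)

monitor.suspend()
monitor.resume()
reg.show_model_monitors()
reg.delete_monitor("MY_MONITOR")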
+ ValueError: If monitoring is not enabled in the registry. """ if not self.enable_monitoring: - raise ValueError( - "Must enable monitoring in Registry to use this method. Please set the `enable_monitoring=True` option" - ) + raise ValueError(_MODEL_MONITORING_DISABLED_ERROR) self._model_monitor_manager.delete_monitor(name) diff --git a/snowflake/ml/registry/registry_test.py b/snowflake/ml/registry/registry_test.py index 1f231129..3e048b34 100644 --- a/snowflake/ml/registry/registry_test.py +++ b/snowflake/ml/registry/registry_test.py @@ -1,17 +1,16 @@ from typing import cast from unittest import mock -from unittest.mock import patch from absl.testing import absltest -from snowflake.ml.model import model_signature, type_hints +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import type_hints from snowflake.ml.model._client.model import model_version_impl -from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.monitoring import model_monitor from snowflake.ml.monitoring.entities import model_monitor_config from snowflake.ml.registry import registry -from snowflake.ml.test_utils import mock_data_frame, mock_session -from snowflake.snowpark import Row, Session, types +from snowflake.ml.test_utils import mock_session +from snowflake.snowpark import Row, Session class RegistryNameTest(absltest.TestCase): @@ -204,180 +203,111 @@ def test_delete_model(self) -> None: class MonitorRegistryTest(absltest.TestCase): def setUp(self) -> None: self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.test_db_name = "TEST_DB" + self.test_schema_name = "TEST_SCHEMA" self.test_monitor_name = "TEST" - self.test_source_table_name = "MODEL_OUTPUTS" - self.test_db_name = "SNOWML_OBSERVABILITY" - self.test_schema_name = "METADATA" - self.test_model_name = "test_model" - self.test_model_name_sql = "TEST_MODEL" - self.test_model_version_name = "test_model_version" - self.test_model_version_name_sql = "TEST_MODEL_VERSION" - self.test_fq_model_name = f"{self.test_db_name}.{self.test_schema_name}.{self.test_model_name}" + self.test_source = "MODEL_OUTPUTS" + self.test_warehouse = "TEST_WAREHOUSE" self.test_timestamp_column = "TIMESTAMP" - self.test_prediction_column_name = "PREDICTION" - self.test_label_column_name = "LABEL" + self.test_pred_score_column_name = "PREDICTION" + self.test_label_score_column_name = "LABEL" self.test_id_column_name = "ID" - self.test_baseline_table_name_sql = "_SNOWML_OBS_BASELINE_TEST_MODEL_TEST_MODEL_VERSION" - - model_version = mock.MagicMock() - model_version.version_name = self.test_model_version_name - model_version.model_name = self.test_model_name - model_version.fully_qualified_model_name = self.test_fq_model_name - model_version.show_functions.return_value = [ - model_manifest_schema.ModelFunctionInfo( - name="PREDICT", - target_method="predict", - target_method_function_type="FUNCTION", - signature=model_signature.ModelSignature(inputs=[], outputs=[]), - is_partitioned=False, - ) - ] - model_version.get_model_task.return_value = type_hints.Task.TABULAR_REGRESSION - self.m_model_version: model_version_impl.ModelVersion = model_version + self.m_model_version: model_version_impl.ModelVersion = mock.MagicMock() + self.test_monitor_config = model_monitor_config.ModelMonitorConfig( model_version=self.m_model_version, model_function_name="predict", background_compute_warehouse_name=self.test_warehouse, ) - self.test_table_config = model_monitor_config.ModelMonitorTableConfig( - 
prediction_columns=[self.test_prediction_column_name], - label_columns=[self.test_label_column_name], + self.test_table_config = model_monitor_config.ModelMonitorSourceConfig( + source=self.test_source, id_columns=[self.test_id_column_name], timestamp_column=self.test_timestamp_column, - source_table=self.test_source_table_name, - ) - - mock_struct_fields = [] - for col in ["NUM_0"]: - mock_struct_fields.append(types.StructField(col, types.FloatType(), True)) - for col in ["CAT_0"]: - mock_struct_fields.append(types.StructField(col, types.StringType(), True)) - self.mock_schema = types.StructType._from_attributes(mock_struct_fields) - - mock_struct_fields = [] - for col in ["NUM_0"]: - mock_struct_fields.append(types.StructField(col, types.FloatType(), True)) - for col in ["CAT_0"]: - mock_struct_fields.append(types.StructField(col, types.StringType(), True)) - self.mock_schema = types.StructType._from_attributes(mock_struct_fields) - - def _add_expected_monitoring_init_calls(self, model_monitor_create_if_not_exists: bool = False) -> None: - self.m_session.add_mock_sql( - query="""CREATE TABLE IF NOT EXISTS SNOWML_OBSERVABILITY.METADATA._SYSTEM_MONITORING_METADATA - (MONITOR_NAME VARCHAR, SOURCE_TABLE_NAME VARCHAR, FULLY_QUALIFIED_MODEL_NAME VARCHAR, - MODEL_VERSION_NAME VARCHAR, FUNCTION_NAME VARCHAR, TASK VARCHAR, IS_ENABLED BOOLEAN, - TIMESTAMP_COLUMN_NAME VARCHAR, PREDICTION_COLUMN_NAMES ARRAY, - LABEL_COLUMN_NAMES ARRAY, ID_COLUMN_NAMES ARRAY) - """, - result=mock_data_frame.MockDataFrame([Row(status="Table successfully created.")]), + prediction_score_columns=[self.test_pred_score_column_name], + actual_score_columns=[self.test_label_score_column_name], ) - if not model_monitor_create_if_not_exists: # this code path does validation on whether tables exist. 
- self.m_session.add_mock_sql( - query="""SHOW TABLES LIKE '_SYSTEM_MONITORING_METADATA' IN SNOWML_OBSERVABILITY.METADATA""", - result=mock_data_frame.MockDataFrame([Row(name="_SYSTEM_MONITORING_METADATA")]), - ) - - def test_init(self) -> None: - self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) session = cast(Session, self.m_session) - r1 = registry.Registry( + self.m_r = registry.Registry( session, database_name=self.test_db_name, schema_name=self.test_schema_name, options={"enable_monitoring": True}, ) - self.assertEqual(r1.enable_monitoring, True) - - r2 = registry.Registry( - session, - database_name=self.test_db_name, - schema_name=self.test_schema_name, - ) - self.assertEqual(r2.enable_monitoring, False) - self.m_session.finalize() - - def test_add_monitor(self) -> None: - self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) + def test_registry_monitoring_disabled_by_default(self) -> None: session = cast(Session, self.m_session) m_r = registry.Registry( session, database_name=self.test_db_name, schema_name=self.test_schema_name, - options={"enable_monitoring": True}, ) + + with self.assertRaisesRegex(ValueError, registry._MODEL_MONITORING_DISABLED_ERROR): + m_r.add_monitor( + self.test_monitor_name, + self.test_table_config, + self.test_monitor_config, + ) + + with self.assertRaisesRegex(ValueError, registry._MODEL_MONITORING_DISABLED_ERROR): + m_r.show_model_monitors() + + with self.assertRaisesRegex(ValueError, registry._MODEL_MONITORING_DISABLED_ERROR): + m_r.delete_monitor(self.test_monitor_name) + + with self.assertRaisesRegex(ValueError, registry._MODEL_MONITORING_DISABLED_ERROR): + m_r.get_monitor(name=self.test_monitor_name) + + with self.assertRaisesRegex(ValueError, registry._MODEL_MONITORING_DISABLED_ERROR): + m_r.get_monitor(model_version=self.m_model_version) + + def test_add_monitor(self) -> None: m_monitor = mock.Mock() m_monitor.name = self.test_monitor_name - with mock.patch.object(m_r._model_monitor_manager, "add_monitor", return_value=m_monitor) as mock_add_monitor: - monitor: model_monitor.ModelMonitor = m_r.add_monitor( + with mock.patch.object( + self.m_r._model_monitor_manager, "add_monitor", return_value=m_monitor + ) as mock_add_monitor: + self.m_r.add_monitor( self.test_monitor_name, self.test_table_config, self.test_monitor_config, ) mock_add_monitor.assert_called_once_with( - self.test_monitor_name, self.test_table_config, self.test_monitor_config, add_dashboard_udtfs=False + self.test_monitor_name, + self.test_table_config, + self.test_monitor_config, ) - self.assertEqual(monitor.name, self.test_monitor_name) self.m_session.finalize() def test_get_monitor(self) -> None: - self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) - - session = cast(Session, self.m_session) - m_r = registry.Registry( - session, - database_name=self.test_db_name, - schema_name=self.test_schema_name, - options={"enable_monitoring": True}, - ) m_model_monitor: model_monitor.ModelMonitor = mock.MagicMock() - with mock.patch.object( - m_r._model_monitor_manager, "get_monitor", return_value=m_model_monitor - ) as mock_get_monitor: - m_r.get_monitor(name=self.test_monitor_name) - mock_get_monitor.assert_called_once_with(name=self.test_monitor_name) + m_model_monitor.name = sql_identifier.SqlIdentifier(self.test_monitor_name) + with mock.patch.object(self.m_r._model_monitor_manager, "get_monitor", return_value=m_model_monitor): + monitor = self.m_r.get_monitor(name=self.test_monitor_name) + 
self.assertEqual(f"{monitor.name}", self.test_monitor_name) self.m_session.finalize() def test_get_monitor_by_model_version(self) -> None: - self._add_expected_monitoring_init_calls(model_monitor_create_if_not_exists=True) - session = cast(Session, self.m_session) - m_r = registry.Registry( - session, - database_name=self.test_db_name, - schema_name=self.test_schema_name, - options={"enable_monitoring": True}, - ) m_model_monitor: model_monitor.ModelMonitor = mock.MagicMock() + m_model_monitor.name = sql_identifier.SqlIdentifier(self.test_monitor_name) with mock.patch.object( - m_r._model_monitor_manager, "get_monitor_by_model_version", return_value=m_model_monitor - ) as mock_get_monitor: - m_r.get_monitor(model_version=self.m_model_version) - mock_get_monitor.assert_called_once_with(model_version=self.m_model_version) + self.m_r._model_monitor_manager, "get_monitor_by_model_version", return_value=m_model_monitor + ): + monitor = self.m_r.get_monitor(model_version=self.m_model_version) + self.assertEqual(f"{monitor.name}", self.test_monitor_name) + self.m_session.finalize() - @patch("snowflake.ml.monitoring._manager.model_monitor_manager.ModelMonitorManager", autospec=True) - def test_show_model_monitors(self, m_model_monitor_manager_class: mock.MagicMock) -> None: - # Dont need to call self._add_expected_monitoring_init_calls since ModelMonitorManager.__init__ is - # auto mocked. - m_model_monitor_manager = m_model_monitor_manager_class.return_value - sql_result = [ - Row( - col1="val1", - col2="val2", - ) - ] - m_model_monitor_manager.show_model_monitors.return_value = sql_result - session = cast(Session, self.m_session) - m_r = registry.Registry( - session, - database_name=self.test_db_name, - schema_name=self.test_schema_name, - options={"enable_monitoring": True}, - ) - self.assertEqual(m_r.show_model_monitors(), sql_result) + def test_show_model_monitors(self) -> None: + sql_result = [Row(name="monitor")] + with mock.patch.object( + self.m_r._model_monitor_manager, "show_model_monitors", return_value=sql_result + ) as mock_show_model_monitors: + self.assertEqual(self.m_r.show_model_monitors(), sql_result) + mock_show_model_monitors.assert_called_once_with() if __name__ == "__main__": diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index 54685a21..20d98634 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. 
-VERSION = "1.7.0" +VERSION = "1.7.1" diff --git a/tests/integ/snowflake/cortex/BUILD.bazel b/tests/integ/snowflake/cortex/BUILD.bazel new file mode 100644 index 00000000..14725565 --- /dev/null +++ b/tests/integ/snowflake/cortex/BUILD.bazel @@ -0,0 +1,30 @@ +load("//bazel:py_rules.bzl", "py_test") + +package(default_visibility = [ + "//bazel:snowml_public_common", +]) + +py_test( + name = "complete_test", + timeout = "long", + srcs = ["complete_test.py"], + deps = [ + "//snowflake/cortex:init", + "//snowflake/ml/_internal/utils:snowflake_env", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:common_test_base", + "//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) + +py_test( + name = "embed_text_test", + timeout = "long", + srcs = ["embed_text_test.py"], + deps = [ + "//snowflake/cortex:init", + "//snowflake/ml/_internal/utils:snowflake_env", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) diff --git a/tests/integ/snowflake/cortex/complete_test.py b/tests/integ/snowflake/cortex/complete_test.py index 7e22f8de..e808cddc 100644 --- a/tests/integ/snowflake/cortex/complete_test.py +++ b/tests/integ/snowflake/cortex/complete_test.py @@ -4,8 +4,10 @@ from snowflake import snowpark from snowflake.cortex import Complete, CompleteOptions +from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.utils import connection_params from snowflake.snowpark import Session, functions +from tests.integ.snowflake.ml.test_utils import test_env_utils _OPTIONS = CompleteOptions( # random params max_tokens=10, @@ -20,6 +22,10 @@ ] +@absltest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "Complete SQL only available in AWS", +) class CompleteSQLTest(absltest.TestCase): def setUp(self) -> None: self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() diff --git a/tests/integ/snowflake/cortex/embed_text_test.py b/tests/integ/snowflake/cortex/embed_text_test.py index b4128137..18faea2f 100644 --- a/tests/integ/snowflake/cortex/embed_text_test.py +++ b/tests/integ/snowflake/cortex/embed_text_test.py @@ -4,12 +4,18 @@ from snowflake import snowpark from snowflake.cortex import EmbedText768, EmbedText1024 +from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.utils import connection_params from snowflake.snowpark import Session, functions +from tests.integ.snowflake.ml.test_utils import test_env_utils _TEXT = "Text to embed" +@absltest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "Embed text only available in AWS", +) class EmbedTextTest(absltest.TestCase): def setUp(self) -> None: self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() @@ -17,23 +23,25 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - def text_embed_text_768(self) -> None: + def test_embed_text_768(self) -> None: df_in = self._session.create_dataframe([snowpark.Row(model="e5-base-v2", text=_TEXT)]) df_out = df_in.select(EmbedText768(functions.col("model"), functions.col("text"))) res = df_out.collect()[0][0] self.assertIsInstance(res, List) self.assertEqual(len(res), 768) # Check a subset. 
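The embed-text hunks below replace exact equality on the first few embedding values with tolerance-based checks (assertAlmostEqual with delta=0.01). An equivalent way to express the same check, shown here only as an illustration with placeholder numbers, is numpy's assert_allclose:

# Illustration only: tolerance-based comparison of an embedding prefix.
# The observed values are placeholders standing in for res[:4] in the test.
import numpy as np

observed = [-0.0171, -0.0450, -0.0284, 0.0192]
expected = [-0.0174, -0.04528, -0.02869, 0.0189]
np.testing.assert_allclose(observed, expected, atol=0.01)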
- self.assertEqual(res[:4], [-0.001, 0.002, -0.003, 0.004]) + for first, second in zip(res[:4], [-0.0174, -0.04528, -0.02869, 0.0189]): + self.assertAlmostEqual(first, second, delta=0.01) - def text_embed_text_1024(self) -> None: + def test_embed_text_1024(self) -> None: df_in = self._session.create_dataframe([snowpark.Row(model="multilingual-e5-large", text=_TEXT)]) df_out = df_in.select(EmbedText1024(functions.col("model"), functions.col("text"))) res = df_out.collect()[0][0] self.assertIsInstance(res, List) self.assertEqual(len(res), 1024) # Check a subset. - self.assertEqual(res[:4], [-0.001, 0.002, -0.003, 0.004]) + for first, second in zip(res[:4], [0.0253, 0.0085, 0.0143, -0.0387]): + self.assertAlmostEqual(first, second, delta=0.01) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py index b3b9909b..c0cb6aae 100644 --- a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py +++ b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py @@ -47,6 +47,10 @@ "PDAYS", "PREVIOUS", ] +# Columns in imputer must be of same type, so that we can infer output types. +COLUMNS_TO_IMPUTE = [ + "DURATION", +] label_column = ["LABEL"] IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME" feature_cols = categorical_columns + numerical_columns @@ -70,7 +74,7 @@ def _get_preprocessor(self, categorical_columns, numerical_columns, use_knn_impu ] if use_knn_imputer: - transformers.append(("knn_imputer", KNNImputer(), numerical_columns)) + transformers.append(("knn_imputer", KNNImputer(), COLUMNS_TO_IMPUTE)) return ColumnTransformer( transformers=transformers, @@ -111,7 +115,7 @@ def _get_pipeline(self, categorical_columns, numerical_columns, label_column, us ] if use_knn_imputer: - steps.insert(2, ("KNNImputer", KNNImputer(input_cols=numerical_columns, output_cols=numerical_columns))) + steps.insert(2, ("KNNImputer", KNNImputer(input_cols=COLUMNS_TO_IMPUTE, output_cols=COLUMNS_TO_IMPUTE))) return Pipeline(steps=steps) @@ -132,7 +136,7 @@ def test_fit_and_compare_results(self) -> None: [ ("cat_transformer", SkOneHotEncoder(), categorical_columns), ("num_transforms", SkMinMaxScaler(), numerical_columns), - ("num_imputer", SkKNNImputer(), numerical_columns), + ("num_imputer", SkKNNImputer(), COLUMNS_TO_IMPUTE), ] ), ), @@ -169,9 +173,10 @@ def test_fit_predict_proba_and_compare_results(self) -> None: SkColumnTransformer( [ ("cat_transformer", SkOneHotEncoder(), categorical_columns), - ("num_transforms", SkMinMaxScaler(), numerical_columns), - ("num_imputer", SkKNNImputer(), numerical_columns), - ] + ("num_transforms", SkMinMaxScaler(clip=True), numerical_columns), + ("num_imputer", SkKNNImputer(), COLUMNS_TO_IMPUTE), + ], + remainder="passthrough", ), ), ("Training", XGB_XGBClassifier()), diff --git a/tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py index 8bd08070..4524d880 100644 --- a/tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py +++ b/tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py @@ -7,7 +7,7 @@ from snowflake.ml.model import custom_model, model_signature from snowflake.ml.registry import registry from snowflake.ml.utils import connection_params -from snowflake.snowpark import Session +from snowflake.snowpark import Session, exceptions from tests.integ.snowflake.ml.test_utils import 
dataframe_utils, db_manager MODEL_NAME = "TEST_MODEL" @@ -110,10 +110,14 @@ def test_default_non_strict(self) -> None: y_df_expected = pd.DataFrame([[1, 2, 3, 1], [4, 2, 5, 4]], columns=["c1", "c2", "c3", "output"]) dataframe_utils.check_sp_df_res(self._mv.run(sp_df), y_df_expected, check_dtype=False) - sp_df = self._session.create_dataframe([[1, 2, 3], [257, 2, 5]], schema=['"c1"', '"c2"', '"c3"']) - y_df_expected = pd.DataFrame([[1, 2, 3, 1], [257, 2, 5, 1]], columns=["c1", "c2", "c3", "output"]) + sp_df = self._session.create_dataframe([[None, 2, 3], [257, 2, 5]], schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[None, 2, 3, None], [257, 2, 5, 1]], columns=["c1", "c2", "c3", "output"]) dataframe_utils.check_sp_df_res(self._mv.run(sp_df), y_df_expected, check_dtype=False) + sp_df = self._session.create_dataframe([[1, 2, 3], [257, 2, 5]], schema=['"c1"', '"c2"', '"c3"']) + with self.assertRaisesRegex(exceptions.SnowparkSQLException, "Python Interpreter Error"): + self._mv.run(sp_df).collect() + def test_strict(self) -> None: pd.testing.assert_frame_equal( self._mv.run(pd.DataFrame([[1, 2, 3], [4, 2, 5]]), strict_input_validation=True), diff --git a/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py index 8e0c3f54..4fbdce6b 100644 --- a/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py +++ b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py @@ -1,21 +1,21 @@ +import unittest import uuid -import pandas as pd from absl.testing import absltest, parameterized -from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model._client.model import model_version_impl from snowflake.ml.monitoring import model_monitor -from snowflake.ml.monitoring._client import model_monitor_sql_client from snowflake.ml.monitoring.entities import model_monitor_config from snowflake.ml.registry import registry from snowflake.ml.utils import connection_params from snowflake.snowpark import Session +from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.snowflake.ml.test_utils import db_manager, model_factory INPUT_FEATURE_COLUMNS_NAMES = [f"input_feature_{i}" for i in range(64)] +# NOTE: In order to run the tests, must remove the DATA_RETENTION_TIME_IN_DAYS=0 parameter in test_utils/db_manager.py class ModelMonitorRegistryIntegrationTest(parameterized.TestCase): def _create_test_table(self, fully_qualified_table_name: str, id_column_type: str = "STRING") -> None: s = ", ".join([f"{i} FLOAT" for i in INPUT_FEATURE_COLUMNS_NAMES]) @@ -25,6 +25,15 @@ def _create_test_table(self, fully_qualified_table_name: str, id_column_type: st {s}, id {id_column_type}, timestamp TIMESTAMP)""" ).collect() + # Needed to create DT against this table + self._session.sql(f"ALTER TABLE {fully_qualified_table_name} SET CHANGE_TRACKING=TRUE").collect() + + self._session.sql( + f"""INSERT INTO {fully_qualified_table_name} + (label, prediction, {", ".join(INPUT_FEATURE_COLUMNS_NAMES)}, id, timestamp) + VALUES (1, 1, {", ".join(["1"] * 64)}, '1', CURRENT_TIMESTAMP())""" + ).collect() + @classmethod def setUpClass(cls) -> None: """Creates Snowpark and Snowflake environments for testing.""" @@ -35,22 +44,23 @@ def setUp(self) -> None: self.run_id = uuid.uuid4().hex self._db_manager = db_manager.DBManager(self._session) self._schema_name = "PUBLIC" - # TODO(jfishbein): Investigate whether conversion to sql identifier requires uppercase. 
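The setUp and _create_test_table changes in this file encode the environment a model monitor needs: a database with time travel, change tracking and at least one row on the source table, and the monitoring session parameter. A rough standalone sketch of those prerequisites follows; all names are placeholders, and expressing them as raw SQL (rather than through db_manager) is an assumption based on this test.

# Sketch only: prerequisites these integration tests set up for monitoring.
# Object names are placeholders; the raw-SQL form of database creation with
# a retention period is an assumption based on the db_manager call below.
session.sql("CREATE DATABASE IF NOT EXISTS MONITOR_TEST_DB DATA_RETENTION_TIME_IN_DAYS = 1").collect()
session.sql("ALTER TABLE MONITOR_TEST_DB.PUBLIC.SOURCE_TABLE SET CHANGE_TRACKING = TRUE").collect()
session.sql("ALTER SESSION SET FEATURE_ML_OBSERVABILITY = enabled").collect()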
self._db_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( - self.run_id, "monitor_registry" + self.run_id, "TEST_MODEL_MONITORING" ).upper() - self._session.sql(f"CREATE DATABASE IF NOT EXISTS {self._db_name}").collect() - self.perm_stage = "@" + self._db_manager.create_stage( - stage_name="model_registry_test_stage", - schema_name=self._schema_name, - db_name=self._db_name, - sse_encrypted=True, - ) + # Time-travel is required for model monitoring. + self._db_manager.create_database(self._db_name, data_retention_time_in_days=1) self._warehouse_name = "REGTEST_ML_SMALL" self._db_manager.set_warehouse(self._warehouse_name) self._db_manager.cleanup_databases(expire_hours=6) - self.registry = registry.Registry(self._session, database_name=self._db_name, schema_name=self._schema_name) + self.registry = registry.Registry( + self._session, + database_name=self._db_name, + schema_name=self._schema_name, + options={"enable_monitoring": True}, + ) + + self._session.sql("ALTER SESSION SET FEATURE_ML_OBSERVABILITY=enabled").collect() def tearDown(self) -> None: self._db_manager.drop_database(self._db_name) @@ -60,28 +70,24 @@ def tearDown(self) -> None: def tearDownClass(cls) -> None: cls._session.close() - def _add_sample_model_version_and_monitor( - self, - monitor_registry: registry.Registry, - source_table: str, - model_name: str, - version_name: str, - monitor_name: str, - ) -> model_monitor.ModelMonitor: + def _add_sample_model_version(self, model_name: str, version_name: str) -> model_version_impl.ModelVersion: model, features, _ = model_factory.ModelFactory.prepare_sklearn_model() - model_version: model_version_impl.ModelVersion = self.registry.log_model( + return self.registry.log_model( model=model, model_name=model_name, version_name=version_name, sample_input_data=features, ) - return monitor_registry.add_monitor( + def _add_sample_monitor( + self, monitor_name: str, source: str, model_version: model_version_impl.ModelVersion + ) -> model_monitor.ModelMonitor: + return self.registry.add_monitor( name=monitor_name, - table_config=model_monitor_config.ModelMonitorTableConfig( - source_table=source_table, - prediction_columns=["prediction"], - label_columns=["label"], + source_config=model_monitor_config.ModelMonitorSourceConfig( + source=source, + prediction_score_columns=["prediction"], + actual_class_columns=["label"], id_columns=["id"], timestamp_column="timestamp", ), @@ -92,209 +98,130 @@ def _add_sample_model_version_and_monitor( ), ) - def test_add_model_monitor(self) -> None: - # Create an instance of the Registry class with Monitoring enabled. 
- _monitor_registry = registry.Registry( - session=self._session, - database_name=self._db_name, - schema_name=self._schema_name, - options={"enable_monitoring": True}, - ) - - source_table_name = "TEST_TABLE" + def _create_sample_table_model_and_monitor( + self, monitor_name: str, table_name: str, model_name: str, version_name: str = "V1" + ): + self._create_test_table(fully_qualified_table_name=f"{self._db_name}.{self._schema_name}.{table_name}") + mv = self._add_sample_model_version(model_name=model_name, version_name=version_name) + self._add_sample_monitor(monitor_name=monitor_name, source=table_name, model_version=mv) + + @unittest.skip("Not implemented in the backend yet") + def test_show_model_monitors(self): + res = self.registry.show_model_monitors() + self.assertEqual(len(res), 0) + self._create_sample_table_model_and_monitor( + monitor_name="monitor", table_name="source_table", model_name="model_name" + ) + res = self.registry.show_model_monitors() + self.assertEqual(len(res), 1) + self.assertEqual(res[0]["name"], "MONITOR") + + @unittest.skip("Not implemented in the backend yet") + def test_add_monitor_duplicate_fails(self): + source_table_name = "source_table" + model_name_original = "model_name" self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_name}") - - model_name = "TEST_MODEL" - version_name = "TEST_VERSION" - monitor_name = f"TEST_MONITOR_{model_name}_{version_name}_{self.run_id}" - monitor = self._add_sample_model_version_and_monitor( - _monitor_registry, source_table_name, model_name, version_name, monitor_name - ) - - self.assertEqual( - self._session.sql( - f"""SELECT * - FROM - {self._db_name}.{self._schema_name}.{model_monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE FULLY_QUALIFIED_MODEL_NAME = '{self._db_name}.{self._schema_name}.{model_name}' AND - MODEL_VERSION_NAME = '{version_name}'""" - ).count(), - 1, + mv = self._add_sample_model_version( + model_name=model_name_original, + version_name="V1", ) - - self.assertEqual( - self._session.sql( - f"""SELECT * - FROM {self._db_name}.{self._schema_name}. - _SNOWML_OBS_BASELINE_{model_name}_{version_name}""" - ).count(), - 0, + self._add_sample_monitor( + monitor_name="test_monitor_name", + source="source_table", + model_version=mv, ) - table_columns = self._session.sql( - f"""DESCRIBE TABLE - {self._db_name}.{self._schema_name}._SNOWML_OBS_BASELINE_{model_name}_{version_name}""" - ).collect() - - for col in table_columns: - self.assertTrue( - col["name"].upper() - in ["PREDICTION", "LABEL", "ID", "TIMESTAMP", *[i.upper() for i in INPUT_FEATURE_COLUMNS_NAMES]] + with self.assertRaisesRegex(SnowparkSQLException, ".*already exists.*"): + self._add_sample_monitor( + monitor_name="test_monitor_name", + source="source_table", + model_version=mv, ) - df = self._session.create_dataframe( - [ - (1.0, 1.0, *[1.0] * 64), - (1.0, 1.0, *[1.0] * 64), - ], - ["LABEL", "PREDICTION", *[i.upper() for i in INPUT_FEATURE_COLUMNS_NAMES]], - ) - monitor.set_baseline(df) - self.assertEqual( - self._session.sql( - f"""SELECT * - FROM {self._db_name}.{self._schema_name}. - _SNOWML_OBS_BASELINE_{model_name}_{version_name}""" - ).count(), - 2, - ) - - pandas_cols = { - "LABEL": [1.0, 2.0, 3.0], - "PREDICTION": [1.0, 2.0, 3.0], - } - for i in range(64): - pandas_cols[f"INPUT_FEATURE_{i}"] = [1.0, 2.0, 3.0] - - pandas_df = pd.DataFrame(pandas_cols) - monitor.set_baseline(pandas_df) - self.assertEqual( - self._session.sql( - f"""SELECT * - FROM {self._db_name}.{self._schema_name}. 
- _SNOWML_OBS_BASELINE_{model_name}_{version_name}""" - ).count(), - 3, - ) - - # create a snowpark dataframe that does not conform to the existing schema - df = self._session.create_dataframe( - [ - (1.0, "bad", *[1.0] * 64), - (2.0, "very_bad", *[2.0] * 64), - ], - ["LABEL", "PREDICTION", *[i.upper() for i in INPUT_FEATURE_COLUMNS_NAMES]], - ) - with self.assertRaises(ValueError) as e: - monitor.set_baseline(df) - - expected_msg = "Ensure that the baseline dataframe columns match those provided in your monitored table" - self.assertTrue(expected_msg in str(e.exception)) - expected_msg = "Numeric value 'bad' is not recognized" - self.assertTrue(expected_msg in str(e.exception)) - - # Delete monitor - _monitor_registry.delete_monitor(monitor_name) - - # Validate that metadata entry is deleted - self.assertEqual( - self._session.sql( - f"""SELECT * - FROM - {self._db_name}.{self._schema_name}.{model_monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE MONITOR_NAME = '{monitor.name}'""" - ).count(), - 0, - ) - - # Validate that baseline table is deleted - self.assertEqual( - self._session.sql( - f"""SHOW TABLES LIKE '%{self._db_name}.{self._schema_name}. - _SNOWML_OBS_BASELINE_{model_name}_{version_name}%'""" - ).count(), - 0, - ) + with self.assertRaisesRegex(SnowparkSQLException, ".*already exists.*"): + self._add_sample_monitor( + monitor_name="test_monitor_name2", + source="source_table", + model_version=mv, + ) - # Validate that dynamic tables are deleted - self.assertEqual( - self._session.sql( - f"""SHOW TABLES LIKE '%{self._db_name}.{self._schema_name}. - _SNOWML_OBS_MONITORING_{model_name}_{version_name}%'""" - ).count(), - 0, - ) - self.assertEqual( - self._session.sql( - f"""SHOW TABLES LIKE '%{self._db_name}.{self._schema_name}. - _SNOWML_OBS_ACCURACY_{model_name}_{version_name}%'""" - ).count(), - 0, - ) + @unittest.skip("Not implemented in the backend yet") + def test_suspend_resume_monitor(self): + self._create_sample_table_model_and_monitor( + monitor_name="monitor", table_name="source_table", model_name="model_name" + ) + monitor = self.registry.get_monitor(name="monitor") + self.assertEqual(self.registry.show_model_monitors()[0]["monitor_state"], "RUNNING") + + monitor.resume() # resume while already running + self.assertEqual(self.registry.show_model_monitors()[0]["monitor_state"], "RUNNING") + monitor.suspend() # suspend after running + self.assertEqual(self.registry.show_model_monitors()[0]["monitor_state"], "SUSPENDED") + monitor.suspend() # suspend while already suspended + self.assertEqual(self.registry.show_model_monitors()[0]["monitor_state"], "SUSPENDED") + monitor.resume() # resume after suspending + self.assertEqual(self.registry.show_model_monitors()[0]["monitor_state"], "RUNNING") + + @unittest.skip("Not implemented in the backend yet") + def test_get_monitor(self): + self._create_sample_table_model_and_monitor( + monitor_name="monitor", table_name="source_table", model_name="model_name", version_name="V1" + ) + # Test get by name. + monitor = self.registry.get_monitor(name="monitor") + self.assertEqual(monitor.name, "MONITOR") + + # Test get by model_version. 
+ model_version = self.registry.get_model("model_name").version("V1") + monitor2 = self.registry.get_monitor(model_version=model_version) + self.assertEqual(monitor2.name, "MONITOR") + + # Test get by name, doesn't exist + with self.assertRaisesRegex(ValueError, "Unable to find model monitor 'non_existent_monitor'"): + self.registry.get_monitor(name="non_existent_monitor") + + # Test get by model_version, doesn't exist + model_version_not_monitored = self._add_sample_model_version(model_name="fake_model_name", version_name="V2") + with self.assertRaisesRegex(ValueError, "Unable to find model monitor for the given model version."): + self.registry.get_monitor(model_version=model_version_not_monitored) + + @unittest.skip("Not implemented in the backend yet") + def test_delete_monitor(self) -> None: + self._create_sample_table_model_and_monitor( + monitor_name="monitor", table_name="source_table", model_name="model_name" + ) + self.assertEqual(len(self.registry.show_model_monitors()), 1) + self.registry.delete_monitor(name="monitor") + with self.assertRaisesRegex(ValueError, "Unable to find model monitor 'monitor'"): + self.registry.get_monitor(name="monitor") + self.assertEqual(len(self.registry.show_model_monitors()), 0) + + @unittest.skip("Not implemented in the backend yet") + def test_create_model_monitor_from_view(self): + res = self.registry.show_model_monitors() + self.assertEqual(len(res), 0) + + source_table_name = "source_table" + model_name_original = "model_name" + self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_name}") + self._session.sql( + f"""CREATE OR REPLACE VIEW {self._db_name}.{self._schema_name}.{source_table_name}_view + AS SELECT * FROM {self._db_name}.{self._schema_name}.{source_table_name}""" + ).collect() - def test_add_model_monitor_varchar(self) -> None: - _monitor_registry = registry.Registry( - session=self._session, - database_name=self._db_name, - schema_name=self._schema_name, - options={"enable_monitoring": True}, + mv = self._add_sample_model_version( + model_name=model_name_original, + version_name="V1", ) - source_table = "TEST_TABLE" - self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table}", id_column_type="VARCHAR(64)") - - model_name = "TEST_MODEL" - version_name = "TEST_VERSION" - monitor_name = f"TEST_MONITOR_{model_name}_{version_name}_{self.run_id}" - self._add_sample_model_version_and_monitor( - _monitor_registry, source_table, model_name, version_name, monitor_name + self._add_sample_monitor( + monitor_name="monitor", + source="source_table_view", + model_version=mv, ) - self.assertEqual( - self._session.sql( - f"""SELECT * - FROM - {self._db_name}.{self._schema_name}.{model_monitor_sql_client.SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE FULLY_QUALIFIED_MODEL_NAME = '{self._db_name}.{self._schema_name}.{model_name}' AND - MODEL_VERSION_NAME = '{version_name}'""" - ).count(), - 1, - ) - - def test_show_model_monitors(self) -> None: - _monitor_registry = registry.Registry( - session=self._session, - database_name=self._db_name, - schema_name=self._schema_name, - options={"enable_monitoring": True}, - ) - source_table_1 = "TEST_TABLE_1" - self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_1}") - - source_table_2 = "TEST_TABLE_2" - self._create_test_table(f"{self._db_name}.{self._schema_name}.{source_table_2}") - - model_1 = "TEST_MODEL_1" - version_1 = "TEST_VERSION_1" - monitor_1 = f"TEST_MONITOR_{model_1}_{version_1}_{self.run_id}" - 
self._add_sample_model_version_and_monitor(_monitor_registry, source_table_1, model_1, version_1, monitor_1) - - model_2 = "TEST_MODEL_2" - version_2 = "TEST_VERSION_2" - monitor_2 = f"TEST_MONITOR_{model_2}_{version_2}_{self.run_id}" - self._add_sample_model_version_and_monitor(_monitor_registry, source_table_2, model_2, version_2, monitor_2) - - stored_monitors = sorted(_monitor_registry.show_model_monitors(), key=lambda x: x["MONITOR_NAME"]) - self.assertEqual(len(stored_monitors), 2) - row_1 = stored_monitors[0] - self.assertEqual(row_1["MONITOR_NAME"], sql_identifier.SqlIdentifier(monitor_1)) - self.assertEqual(row_1["SOURCE_TABLE_NAME"], source_table_1) - self.assertEqual(row_1["MODEL_VERSION_NAME"], version_1) - self.assertEqual(row_1["IS_ENABLED"], True) - row_2 = stored_monitors[1] - self.assertEqual(row_2["MONITOR_NAME"], sql_identifier.SqlIdentifier(monitor_2)) - self.assertEqual(row_2["SOURCE_TABLE_NAME"], source_table_2) - self.assertEqual(row_2["MODEL_VERSION_NAME"], version_2) - self.assertEqual(row_2["IS_ENABLED"], True) + res = self.registry.show_model_monitors() + self.assertEqual(len(res), 1) + self.assertEqual(res[0]["name"], "MONITOR") if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/registry/model/random_version_name_test.py b/tests/integ/snowflake/ml/registry/model/random_version_name_test.py index e69aafb9..e135ec6a 100644 --- a/tests/integ/snowflake/ml/registry/model/random_version_name_test.py +++ b/tests/integ/snowflake/ml/registry/model/random_version_name_test.py @@ -1,4 +1,4 @@ -import numpy as np +import pandas as pd from absl.testing import absltest from sklearn import datasets, linear_model @@ -13,8 +13,10 @@ def test_random_version_name(self) -> None: regr.fit(iris_X, iris_y) name = f"model_{self._run_id}" mv = self.registry.log_model(regr, model_name=name, sample_input_data=iris_X) - np.testing.assert_allclose( - mv.run(iris_X, function_name="predict")["output_feature_0"].values, regr.predict(iris_X) + pd.testing.assert_series_equal( + mv.run(iris_X, function_name="predict")["output_feature_0"], + pd.Series(regr.predict(iris_X), name="output_feature_0"), + check_dtype=False, ) self.registry._model_manager._hrid_generator.hrid_to_id(mv.version_name.lower()) diff --git a/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py index 264240c6..6550a891 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_catboost_model_test.py @@ -1,10 +1,15 @@ import catboost import inflection -import numpy as np import pandas as pd import shap from absl.testing import absltest, parameterized -from sklearn import datasets, model_selection +from sklearn import ( + compose, + datasets, + model_selection, + pipeline as SK_pipeline, + preprocessing, +) from snowflake.ml.model import model_signature from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema @@ -35,13 +40,72 @@ def test_catboost_classifier_no_explain( prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict(cal_X_test), columns=res.columns), + check_dtype=False, ), ), "predict_proba": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(cal_X_test)), + lambda res: 
pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_proba(cal_X_test), columns=res.columns), + check_dtype=False, + ), + ), + }, + options={"enable_explainability": False}, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_catboost_classifier_pipeline_no_explain( + self, + registry_test_fn: str, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + + classifier = SK_pipeline.Pipeline( + steps=[ + ("regressor", catboost.CatBoostClassifier()), + ] + ) + classifier.fit(cal_X_train, cal_y_train) + + y_df_expected = pd.concat( + [ + cal_X_test.reset_index(drop=True), + pd.DataFrame(classifier.predict(cal_X_test), columns=["output_feature_0"]), + ], + axis=1, + ) + y_df_expected_proba = pd.concat( + [ + cal_X_test.reset_index(drop=True), + pd.DataFrame(classifier.predict_proba(cal_X_test), columns=["output_feature_0", "output_feature_1"]), + ], + axis=1, + ) + + cal_data_sp_df_train = self.session.create_dataframe(cal_X_train) + cal_data_sp_df_test = self.session.create_dataframe(cal_X_test) + getattr(self, registry_test_fn)( + model=classifier, + sample_input_data=cal_data_sp_df_train, + prediction_assert_fns={ + "predict": ( + cal_data_sp_df_test, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ), + "predict_proba": ( + cal_data_sp_df_test, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected_proba, check_dtype=False), ), }, options={"enable_explainability": False}, @@ -70,17 +134,27 @@ def test_catboost_classifier_explain( prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict(cal_X_test), columns=res.columns), + check_dtype=False, ), ), "predict_proba": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(cal_X_test)), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_proba(cal_X_test), columns=res.columns), + check_dtype=False, + ), ), "explain": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, expected_explanations), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(expected_explanations, columns=res.columns), + check_dtype=False, + ), ), }, function_type_assert={ @@ -238,21 +312,35 @@ def test_catboost_with_signature_and_sample_data( prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict(cal_X_test), columns=res.columns), + check_dtype=False, ), ), "predict_proba": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(cal_X_test)), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_proba(cal_X_test), columns=res.columns), + check_dtype=False, + ), ), "predict_log_proba": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, classifier.predict_log_proba(cal_X_test)), + lambda res: 
pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_log_proba(cal_X_test), columns=res.columns), + check_dtype=False, + ), ), "explain": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, expected_explanations), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(expected_explanations, columns=res.columns), + check_dtype=False, + ), ), }, options={"enable_explainability": True}, @@ -265,6 +353,64 @@ def test_catboost_with_signature_and_sample_data( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_catboost_model_with_categorical_dtype_columns( + self, + registry_test_fn: str, + ) -> None: + data = { + "color": ["red", "blue", "green", "red"], + "size": [1, 2, 2, 4], + "price": [10, 15, 20, 25], + "target": [0, 1, 1, 0], + } + input_features = ["color", "size", "price"] + + df = pd.DataFrame(data) + df["color"] = df["color"].astype("category") + df["size"] = df["size"].astype("category") + + # Define categorical columns + categorical_columns = ["color", "size"] + + # Create a column transformer + preprocessor = compose.ColumnTransformer( + transformers=[ + ("cat", preprocessing.OneHotEncoder(), categorical_columns), + ], + remainder="passthrough", + ) + + pipeline = SK_pipeline.Pipeline( + [ + ("preprocessor", preprocessor), + ("classifier", catboost.CatBoostClassifier()), + ] + ) + pipeline.fit(df[input_features], df["target"]) + + def _check_predict_fn(res) -> None: + pd.testing.assert_frame_equal( + res["output_feature_0"].to_frame(), + pd.DataFrame(pipeline.predict(df[input_features]), columns=["output_feature_0"]), + check_dtype=False, + ) + + getattr(self, registry_test_fn)( + model=pipeline, + sample_input_data=df[input_features], + prediction_assert_fns={ + "predict": ( + df[input_features], + _check_predict_fn, + ), + }, + # TODO(SNOW-1677301): Add support for explainability for categorical columns + options={"enable_explainability": False}, + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py index 44e77dfe..e3df0fa9 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py @@ -62,6 +62,15 @@ def predict(self, input: pd.DataFrame) -> pd.DataFrame: return pd.DataFrame({"output": (input["c1"] + self.bias) > 12}) +class DemoModelWithPdSeriesOutPut(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + def predict(self, input: pd.DataFrame) -> pd.Series: + return input["c1"] + + class TestRegistryCustomModelInteg(registry_model_test_base.RegistryModelTestBase): @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, @@ -89,7 +98,7 @@ async def _test(self: "TestRegistryCustomModelInteg") -> None: pd_df, lambda res: pd.testing.assert_frame_equal( res, - pd.DataFrame(arr[:, 0], columns=["output"], dtype=float), + pd.DataFrame(arr[:, 0], columns=["output"], dtype=pd.Float64Dtype()), ), ), }, @@ -116,7 +125,7 @@ def test_large_input( pd_df, lambda res: pd.testing.assert_frame_equal( res, - pd.DataFrame(arr[:, 0], columns=["output"]), + pd.DataFrame(arr[:, 0], columns=["output"], dtype=pd.Int64Dtype()), ), ), }, @@ 
-187,6 +196,118 @@ def test_custom_demo_model_sp_one_query( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_custom_demo_model_none( + self, + registry_test_fn: str, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [None, 2, 5]] + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + getattr(self, registry_test_fn)( + model=lm, + sample_input_data=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame([1, None], columns=["output"], dtype=pd.Int64Dtype()), + ), + ) + }, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_custom_demo_model_none_multi_block( + self, + registry_test_fn: str, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2]] * 5000 + [[None, 2]] * 5000 + pd_df = pd.DataFrame(arr, columns=["c1", "c2"]) + getattr(self, registry_test_fn)( + model=lm, + sample_input_data=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame([1] * 5000 + [None] * 5000, columns=["output"], dtype=pd.Float64Dtype()), + ), + ) + }, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_custom_demo_model_none_sp( + self, + registry_test_fn: str, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [None, 2, 5]] + sp_df = self.session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1, 2, 3, 1], [None, 2, 5, None]], columns=["c1", "c2", "c3", "output"]) + getattr(self, registry_test_fn)( + model=lm, + sample_input_data=sp_df, + prediction_assert_fns={ + "predict": (sp_df, lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False)) + }, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_custom_demo_model_none_sp_mix1( + self, + registry_test_fn: str, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [None, 2, 5]] + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = self.session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1, 2, 3, 1], [None, 2, 5, None]], columns=["c1", "c2", "c3", "output"]) + getattr(self, registry_test_fn)( + model=lm, + sample_input_data=pd_df, + prediction_assert_fns={ + "predict": (sp_df, lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False)) + }, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_custom_demo_model_none_sp_mix2( + self, + registry_test_fn: str, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [None, 2, 5]] + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = self.session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + getattr(self, registry_test_fn)( + model=lm, + sample_input_data=sp_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame([1, None], columns=["output"], dtype=pd.Int64Dtype()), + ), + ) + }, + ) + @parameterized.product( # type: 
ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) @@ -206,7 +327,7 @@ def test_custom_demo_model_sp_quote( pd_df, lambda res: pd.testing.assert_frame_equal( res, - pd.DataFrame([1, 4], columns=['"output"'], dtype=np.int8), + pd.DataFrame([1, 4], columns=['"output"'], dtype=pd.Int8Dtype()), ), ) }, @@ -254,7 +375,7 @@ def test_custom_demo_model_sp_mix_2( pd_df, lambda res: pd.testing.assert_frame_equal( res, - pd.DataFrame([1, 4], columns=["output"], dtype=np.int8), + pd.DataFrame([1, 4], columns=["output"], dtype=pd.Int8Dtype()), ), ) }, @@ -284,6 +405,31 @@ def test_custom_demo_model_array( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + @absltest.skip("Skip until we support pd.Series as output") + def test_custom_demo_model_pd_series( + self, + registry_test_fn: str, + ) -> None: + lm = DemoModelWithPdSeriesOutPut(custom_model.ModelContext()) + arr = np.array([[1, 2, 3], [4, 2, 5]]) + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + getattr(self, registry_test_fn)( + model=lm, + sample_input_data=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_series_equal( + res["output"], + pd.Series([1, 4], name="output"), + ), + ) + }, + ) + @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) @@ -301,12 +447,37 @@ def test_custom_demo_model_str( pd_df, lambda res: pd.testing.assert_frame_equal( res, - pd.DataFrame(data={"output": ["Yogiri", "Artia"]}), + pd.DataFrame(data={"output": ["Yogiri", "Artia"]}, dtype=pd.StringDtype()), ), ) }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_custom_demo_model_str_sp_none( + self, + registry_test_fn: str, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + pd_df = pd.DataFrame([[None, "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]], columns=["c1", "c2", "c3"]) + sp_df = self.session.create_dataframe(pd_df) + y_df_expected = pd.DataFrame( + [[None, "Civia", "Echo", None], ["Artia", "Doris", "Rosalyn", "Artia"]], + columns=["c1", "c2", "c3", "output"], + ) + getattr(self, registry_test_fn)( + model=lm, + sample_input_data=sp_df, + prediction_assert_fns={ + "predict": ( + sp_df, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ) + }, + ) + @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) @@ -398,7 +569,7 @@ def test_custom_model_with_artifacts( pd_df, lambda res: pd.testing.assert_frame_equal( res, - pd.DataFrame([False, True], columns=["output"]), + pd.DataFrame([False, True], columns=["output"], dtype=pd.BooleanDtype()), ), ) }, @@ -453,6 +624,7 @@ def predict(self, input: pd.DataFrame) -> pd.DataFrame: lambda res: pd.testing.assert_frame_equal( res, output, + check_dtype=False, ), ) }, @@ -490,6 +662,7 @@ def predict(self, input: pd.DataFrame) -> pd.DataFrame: lambda res: pd.testing.assert_frame_equal( res, pd.DataFrame([False, True], columns=["output"]), + check_dtype=False, ), ) }, diff --git a/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py index 62892226..ad51bd39 100644 --- 
a/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py @@ -189,7 +189,7 @@ def check_res(res: pd.DataFrame) -> None: self.assertEqual(res["score"].dtype.type, np.float64) self.assertEqual(res["start"].dtype.type, np.int64) self.assertEqual(res["end"].dtype.type, np.int64) - self.assertEqual(res["answer"].dtype.type, np.object_) + self.assertEqual(res["answer"].dtype.type, str) getattr(self, registry_test_fn)( model=model, @@ -281,7 +281,7 @@ def test_summarization_pipeline( def check_res(res: pd.DataFrame) -> None: pd.testing.assert_index_equal(res.columns, pd.Index(["summary_text"])) - self.assertEqual(res["summary_text"].dtype.type, np.object_) + self.assertEqual(res["summary_text"].dtype.type, str) getattr(self, registry_test_fn)( model=model, @@ -341,12 +341,12 @@ def test_table_question_answering_pipeline( def check_res(res: pd.DataFrame) -> None: pd.testing.assert_index_equal(res.columns, pd.Index(["answer", "coordinates", "cells", "aggregator"])) - self.assertEqual(res["answer"].dtype.type, np.object_) + self.assertEqual(res["answer"].dtype.type, str) self.assertEqual(res["coordinates"].dtype.type, np.object_) self.assertIsInstance(res["coordinates"][0], list) self.assertEqual(res["cells"].dtype.type, np.object_) self.assertIsInstance(res["cells"][0], list) - self.assertEqual(res["aggregator"].dtype.type, np.object_) + self.assertEqual(res["aggregator"].dtype.type, str) getattr(self, registry_test_fn)( model=model, @@ -376,7 +376,7 @@ def test_text_classification_pair_pipeline( def check_res(res: pd.DataFrame) -> None: pd.testing.assert_index_equal(res.columns, pd.Index(["label", "score"])) - self.assertEqual(res["label"].dtype.type, np.object_) + self.assertEqual(res["label"].dtype.type, str) self.assertEqual(res["score"].dtype.type, np.float64) getattr(self, registry_test_fn)( @@ -490,7 +490,7 @@ def test_text2text_generation_pipeline( def check_res(res: pd.DataFrame) -> None: pd.testing.assert_index_equal(res.columns, pd.Index(["generated_text"])) - self.assertEqual(res["generated_text"].dtype.type, np.object_) + self.assertEqual(res["generated_text"].dtype.type, str) getattr(self, registry_test_fn)( model=model, @@ -532,7 +532,7 @@ def test_translation_pipeline( def check_res(res: pd.DataFrame) -> None: pd.testing.assert_index_equal(res.columns, pd.Index(["translation_text"])) - self.assertEqual(res["translation_text"].dtype.type, np.object_) + self.assertEqual(res["translation_text"].dtype.type, str) getattr(self, registry_test_fn)( model=model, @@ -573,7 +573,7 @@ def test_zero_shot_classification_pipeline( def check_res(res: pd.DataFrame) -> None: pd.testing.assert_index_equal(res.columns, pd.Index(["sequence", "labels", "scores"])) - self.assertEqual(res["sequence"].dtype.type, np.object_) + self.assertEqual(res["sequence"].dtype.type, str) self.assertEqual( res["sequence"][0], "I have a problem with Snowflake that needs to be resolved asap!!", @@ -586,8 +586,8 @@ def check_res(res: pd.DataFrame) -> None: self.assertListEqual(sorted(res["labels"][0]), sorted(["urgent", "not urgent"])) self.assertListEqual(sorted(res["labels"][1]), sorted(["English", "Japanese"])) self.assertEqual(res["scores"].dtype.type, np.object_) - self.assertIsInstance(res["labels"][0], list) - self.assertIsInstance(res["labels"][1], list) + self.assertIsInstance(res["scores"][0], list) + self.assertIsInstance(res["scores"][1], list) getattr(self, registry_test_fn)( model=model, diff 
--git a/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py index e74addc6..b54dd386 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_lightgbm_model_test.py @@ -7,7 +7,13 @@ import pandas as pd import shap from absl.testing import absltest, parameterized -from sklearn import datasets, model_selection +from sklearn import ( + compose, + datasets, + model_selection, + pipeline as SK_pipeline, + preprocessing, +) from snowflake.ml.model import model_signature from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema @@ -39,13 +45,69 @@ def test_lightgbm_classifier_no_explain( prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict(cal_X_test), columns=res.columns), + check_dtype=False, ), ), "predict_proba": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(cal_X_test)), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_proba(cal_X_test), columns=res.columns), + check_dtype=False, + ), + ), + }, + options={"enable_explainability": False}, + ) + + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_lightgbm_classifier_pipeline_no_explain( + self, + registry_test_fn: str, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + + pipeline = SK_pipeline.Pipeline( + [ + ("classifier", lightgbm.LGBMClassifier()), + ] + ) + pipeline.fit(cal_X_train, cal_y_train) + + def _check_predict_fn(res: pd.DataFrame) -> None: + pd.testing.assert_frame_equal( + res, + pd.DataFrame(pipeline.predict(cal_X_test), columns=res.columns), + check_dtype=False, + ) + + def _check_predict_proba_fn(res: pd.DataFrame) -> None: + pd.testing.assert_frame_equal( + res, + pd.DataFrame(pipeline.predict_proba(cal_X_test), columns=res.columns), + check_dtype=False, + ) + + getattr(self, registry_test_fn)( + model=pipeline, + sample_input_data=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + _check_predict_fn, + ), + "predict_proba": ( + cal_X_test, + _check_predict_proba_fn, ), }, options={"enable_explainability": False}, @@ -70,23 +132,36 @@ def test_lightgbm_classifier_explain( if expected_explanations.ndim == 3 and expected_explanations.shape[2] == 2: expected_explanations = np.apply_along_axis(lambda arr: arr[1], -1, expected_explanations) + def check_explain_fn(res) -> None: + pd.testing.assert_frame_equal( + res, + pd.DataFrame(expected_explanations, columns=res.columns), + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=classifier, sample_input_data=cal_X_test, prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict(cal_X_test), columns=res.columns), + check_dtype=False, ), ), "predict_proba": ( 
cal_X_test, - lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(cal_X_test)), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_proba(cal_X_test), columns=res.columns), + check_dtype=False, + ), ), "explain": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-5), + check_explain_fn, ), }, function_type_assert={ @@ -236,7 +311,11 @@ def test_lightgbm_booster_no_explain( prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, np.expand_dims(y_pred, axis=1), rtol=1e-6), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(y_pred, columns=res.columns), + check_dtype=False, + ), ), }, options={"enable_explainability": False}, @@ -265,11 +344,19 @@ def test_lightgbm_booster_explain( prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, np.expand_dims(y_pred, axis=1), rtol=1e-6), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(y_pred, columns=res.columns), + check_dtype=False, + ), ), "explain": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-5), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(expected_explanations, columns=res.columns), + check_dtype=False, + ), ), }, function_type_assert={ @@ -398,13 +485,19 @@ def test_lightgbm_with_signature_and_sample_data( prediction_assert_fns={ "predict": ( cal_X_test, - lambda res: np.testing.assert_allclose( - res.values, np.expand_dims(classifier.predict(cal_X_test), axis=1) + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict(cal_X_test), columns=res.columns), + check_dtype=False, ), ), "explain": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-5), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(expected_explanations, columns=res.columns), + check_dtype=False, + ), ), }, options={"enable_explainability": True}, @@ -415,6 +508,62 @@ def test_lightgbm_with_signature_and_sample_data( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_lightgbm_model_with_categorical_dtype_columns( + self, + registry_test_fn: str, + ) -> None: + data = { + "color": ["red", "blue", "green", "red"], + "size": [1, 2, 2, 4], + "price": [10, 15, 20, 25], + "target": [0, 1, 1, 0], + } + input_features = ["color", "size", "price"] + + df = pd.DataFrame(data) + df["color"] = df["color"].astype("category") + df["size"] = df["size"].astype("category") + + # Define categorical columns + categorical_columns = ["color", "size"] + + # Create a column transformer + preprocessor = compose.ColumnTransformer( + transformers=[("cat", preprocessing.OneHotEncoder(), categorical_columns)], + remainder="passthrough", + ) + + pipeline = SK_pipeline.Pipeline( + [ + ("preprocessor", preprocessor), + ("classifier", lightgbm.LGBMClassifier()), + ] + ) + pipeline.fit(df.drop("target", axis=1), df["target"]) + + def _check_predict_fn(res) -> None: + pd.testing.assert_frame_equal( + res["output_feature_0"].to_frame(), + pd.DataFrame(pipeline.predict(df[input_features]), columns=["output_feature_0"]), + check_dtype=False, + ) + + getattr(self, registry_test_fn)( + model=pipeline, + sample_input_data=df[input_features], + prediction_assert_fns={ + "predict": ( + df[input_features], + _check_predict_fn, 
+ ), + }, + # TODO(SNOW-1677301): Add support for explainability for categorical columns + options={"enable_explainability": False}, + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_mlflow_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_mlflow_model_test.py index a87e51ce..8103fba9 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_mlflow_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_mlflow_model_test.py @@ -1,7 +1,7 @@ from importlib import metadata as importlib_metadata import mlflow -import numpy as np +import pandas as pd from absl.testing import absltest, parameterized from sklearn import datasets, ensemble, model_selection @@ -58,7 +58,11 @@ def test_mlflow_model_deploy_sklearn_df( prediction_assert_fns={ "predict": ( X_test, - lambda res: np.testing.assert_allclose(np.expand_dims(predictions, axis=1), res.to_numpy()), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(predictions, columns=res.columns), + check_dtype=False, + ), ), }, options={"relax_version": False}, @@ -113,7 +117,11 @@ def test_mlflow_model_deploy_sklearn( prediction_assert_fns={ "predict": ( X_test_df, - lambda res: np.testing.assert_allclose(np.expand_dims(predictions, axis=1), res.to_numpy()), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(predictions, columns=res.columns), + check_dtype=False, + ), ), }, options={"relax_version": False}, diff --git a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py index 61c0b10e..54b010ef 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py @@ -1,7 +1,7 @@ import os import posixpath -import numpy as np +import pandas as pd import shap import yaml from absl.testing import absltest, parameterized @@ -12,13 +12,14 @@ from snowflake.ml.model import model_signature from snowflake.ml.model._model_composer import model_composer from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils from snowflake.ml.modeling.lightgbm import LGBMRegressor from snowflake.ml.modeling.linear_model import LogisticRegression from snowflake.ml.modeling.pipeline import Pipeline from snowflake.ml.modeling.xgboost import XGBRegressor from snowflake.snowpark import types as T from tests.integ.snowflake.ml.registry.model import registry_model_test_base -from tests.integ.snowflake.ml.test_utils import dataframe_utils, test_env_utils +from tests.integ.snowflake.ml.test_utils import test_env_utils class TestRegistryModelingModelInteg(registry_model_test_base.RegistryModelTestBase): @@ -44,8 +45,10 @@ def test_snowml_model_deploy_snowml_sklearn_explain_disabled( prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[OUTPUT_COLUMNS], + regr.predict(test_features)[OUTPUT_COLUMNS], + check_dtype=False, ), ), }, @@ -65,7 +68,7 @@ def test_snowml_model_deploy_snowml_sklearn_explain_default( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, 
"_explanation"]) for feature in INPUT_COLUMNS] + # EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] regr = LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) test_features = iris_X regr.fit(test_features) @@ -73,22 +76,35 @@ def test_snowml_model_deploy_snowml_sklearn_explain_default( test_data = test_features[INPUT_COLUMNS] expected_explanations = shap.Explainer(regr.to_sklearn(), masker=test_data)(test_data).values + def _check_explain(res: pd.DataFrame) -> None: + actual_explain_df = handlers_utils.convert_explanations_to_2D_df(regr, expected_explanations) + rename_columns = { + old_col_name: new_col_name for old_col_name, new_col_name in zip(actual_explain_df.columns, res.columns) + } + actual_explain_df.rename(columns=rename_columns, inplace=True) + pd.testing.assert_frame_equal( + res, + actual_explain_df, + check_dtype=False, + ) + + def _check_predict(res) -> None: + pd.testing.assert_series_equal( + res[OUTPUT_COLUMNS], + regr.predict(test_features)[OUTPUT_COLUMNS], + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values - ), + _check_predict, ), "explain": ( test_features, - lambda res: np.testing.assert_allclose( - dataframe_utils.convert2D_json_to_3D(res[EXPLAIN_OUTPUT_COLUMNS].values), - expected_explanations, - rtol=1e-4, - ), + _check_explain, ), }, sample_input_data=test_data, @@ -107,11 +123,30 @@ def test_snowml_model_deploy_snowml_sklearn_explain_enabled( INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] LABEL_COLUMNS = "TARGET" OUTPUT_COLUMNS = "PREDICTED_TARGET" - EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] + # EXPLAIN_OUTPUT_COLUMNS = [identifier.concat_names([feature, "_explanation"]) for feature in INPUT_COLUMNS] regr = LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) test_features = iris_X regr.fit(test_features) + def _check_explain(res: pd.DataFrame) -> None: + actual_explain_df = handlers_utils.convert_explanations_to_2D_df(regr, expected_explanations) + rename_columns = { + old_col_name: new_col_name for old_col_name, new_col_name in zip(actual_explain_df.columns, res.columns) + } + actual_explain_df.rename(columns=rename_columns, inplace=True) + pd.testing.assert_frame_equal( + res, + actual_explain_df, + check_dtype=False, + ) + + def _check_predict(res) -> None: + pd.testing.assert_series_equal( + res[OUTPUT_COLUMNS], + regr.predict(test_features)[OUTPUT_COLUMNS], + check_dtype=False, + ) + test_data = test_features[INPUT_COLUMNS] expected_explanations = shap.Explainer(regr.to_sklearn(), masker=test_data)(test_data).values getattr(self, registry_test_fn)( @@ -119,17 +154,11 @@ def test_snowml_model_deploy_snowml_sklearn_explain_enabled( prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values - ), + _check_predict, ), "explain": ( test_features, - lambda res: np.testing.assert_allclose( - dataframe_utils.convert2D_json_to_3D(res[EXPLAIN_OUTPUT_COLUMNS].values), - expected_explanations, - rtol=1e-4, - ), + _check_explain, ), }, sample_input_data=test_data, @@ -158,8 +187,10 @@ def 
test_snowml_model_deploy_xgboost_explain_disabled( prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[OUTPUT_COLUMNS], + regr.predict(test_features)[OUTPUT_COLUMNS], + check_dtype=False, ), ), }, @@ -192,14 +223,18 @@ def test_snowml_model_deploy_xgboost_explain_default( prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[PRED_OUTPUT_COLUMNS], + regr.predict(test_features)[PRED_OUTPUT_COLUMNS], + check_dtype=False, ), ), "explain": ( test_features, - lambda res: np.testing.assert_allclose( - res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 + lambda res: pd.testing.assert_frame_equal( + res[EXPLAIN_OUTPUT_COLUMNS], + pd.DataFrame(expected_explanations, columns=EXPLAIN_OUTPUT_COLUMNS), + check_dtype=False, ), ), }, @@ -231,14 +266,18 @@ def test_snowml_model_deploy_xgboost_explain_enabled( prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[PRED_OUTPUT_COLUMNS], + regr.predict(test_features)[PRED_OUTPUT_COLUMNS], + check_dtype=False, ), ), "explain": ( test_features, - lambda res: np.testing.assert_allclose( - res[EXPLAIN_OUTPUT_COLUMNS].values, expected_explanations, rtol=1e-4 + lambda res: pd.testing.assert_frame_equal( + res[EXPLAIN_OUTPUT_COLUMNS], + pd.DataFrame(expected_explanations, columns=EXPLAIN_OUTPUT_COLUMNS), + check_dtype=False, ), ), }, @@ -249,6 +288,57 @@ def test_snowml_model_deploy_xgboost_explain_enabled( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_xgboost_explain( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_xgboost())(test_features[INPUT_COLUMNS]).values + + def _check_explain(res: pd.DataFrame) -> None: + expected_explanations_df = pd.DataFrame( + expected_explanations, + columns=EXPLAIN_OUTPUT_COLUMNS, + ) + res.columns = EXPLAIN_OUTPUT_COLUMNS + pd.testing.assert_frame_equal( + res, + expected_explanations_df, + check_dtype=False, + ) + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: pd.testing.assert_series_equal( + res[PRED_OUTPUT_COLUMNS], + regr.predict(test_features)[PRED_OUTPUT_COLUMNS], + check_dtype=False, + ), + ), + "explain": ( + test_features, + _check_explain, + ), + }, + ) + @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, ) @@ -271,8 
+361,10 @@ def test_snowml_model_deploy_lightgbm_explain_disabled( prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[OUTPUT_COLUMNS], + regr.predict(test_features)[OUTPUT_COLUMNS], + check_dtype=False, ), ), }, @@ -299,22 +391,32 @@ def test_snowml_model_deploy_lightgbm_explain_default( expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values + def _check_explain(res: pd.DataFrame) -> None: + expected_explanations_df = pd.DataFrame( + expected_explanations, + columns=EXPLAIN_OUTPUT_COLUMNS, + ) + res.columns = EXPLAIN_OUTPUT_COLUMNS + pd.testing.assert_frame_equal( + res, + expected_explanations_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[PRED_OUTPUT_COLUMNS], + regr.predict(test_features)[PRED_OUTPUT_COLUMNS], + check_dtype=False, ), ), "explain": ( test_features, - lambda res: np.testing.assert_allclose( - res[EXPLAIN_OUTPUT_COLUMNS].values, - expected_explanations, - rtol=1e-5, - ), + _check_explain, ), }, function_type_assert={ @@ -343,22 +445,32 @@ def test_snowml_model_deploy_lightgbm_explain_enabled( expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values + def _check_explain(res: pd.DataFrame) -> None: + expected_explanations_df = pd.DataFrame( + expected_explanations, + columns=EXPLAIN_OUTPUT_COLUMNS, + ) + res.columns = EXPLAIN_OUTPUT_COLUMNS + pd.testing.assert_frame_equal( + res[EXPLAIN_OUTPUT_COLUMNS], + expected_explanations_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=regr, prediction_assert_fns={ "predict": ( test_features, - lambda res: np.testing.assert_allclose( - res[PRED_OUTPUT_COLUMNS].values, regr.predict(test_features)[PRED_OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[PRED_OUTPUT_COLUMNS], + regr.predict(test_features)[PRED_OUTPUT_COLUMNS], + check_dtype=False, ), ), "explain": ( test_features, - lambda res: np.testing.assert_allclose( - res[EXPLAIN_OUTPUT_COLUMNS].values, - expected_explanations, - rtol=1e-5, - ), + _check_explain, ), }, options={"enable_explainability": True}, @@ -368,6 +480,56 @@ def test_snowml_model_deploy_lightgbm_explain_enabled( }, ) + @parameterized.product( # type: ignore[misc] + registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, + ) + def test_snowml_model_deploy_lightgbm_explain( + self, + registry_test_fn: str, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + PRED_OUTPUT_COLUMNS = "PREDICTED_TARGET" + EXPLAIN_OUTPUT_COLUMNS = [feature + "_explanation" for feature in INPUT_COLUMNS] + regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=PRED_OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + expected_explanations = shap.Explainer(regr.to_lightgbm())(test_features[INPUT_COLUMNS]).values + + def check_explain(res: pd.DataFrame) -> None: + expected_explanations_df = 
pd.DataFrame( + expected_explanations, + columns=EXPLAIN_OUTPUT_COLUMNS, + ) + res.columns = EXPLAIN_OUTPUT_COLUMNS + pd.testing.assert_frame_equal( + res, + expected_explanations_df, + check_dtype=False, + ) + + getattr(self, registry_test_fn)( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: pd.testing.assert_series_equal( + res[PRED_OUTPUT_COLUMNS], + regr.predict(test_features)[PRED_OUTPUT_COLUMNS], + check_dtype=False, + ), + ), + "explain": ( + test_features, + check_explain, + ), + }, + ) + @parameterized.product( # type: ignore[misc] registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST, use_pipeline=[False, True], @@ -419,12 +581,13 @@ def test_dataset_to_model_lineage( prediction_assert_fns={ "predict": ( iris_X, - lambda res: np.testing.assert_allclose( - res[OUTPUT_COLUMNS].values, regr.predict(iris_X)[OUTPUT_COLUMNS].values + lambda res: pd.testing.assert_series_equal( + res[OUTPUT_COLUMNS], + regr.predict(iris_X)[OUTPUT_COLUMNS], + check_dtype=False, ), ), }, - additional_dependencies=["fsspec", "aiohttp", "cryptography"], ) # Case 3 : Capture Lineage via sample_input of log_model of MANIFEST.yml file diff --git a/tests/integ/snowflake/ml/registry/model/registry_pytorch_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_pytorch_model_test.py index 41776df4..53f71ae5 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_pytorch_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_pytorch_model_test.py @@ -92,7 +92,7 @@ def test_torchscript_tensor_as_sample( ) -> None: model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model(torch.float32) x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - model_script = torch.jit.script(model) # type:ignore[attr-defined] + model_script = torch.jit.script(model) y_pred = model_script.forward(data_x).detach() getattr(self, registry_test_fn)( @@ -117,7 +117,7 @@ def test_torchscript_df_as_sample( ) -> None: model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model(torch.float64) x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) - model_script = torch.jit.script(model) # type:ignore[attr-defined] + model_script = torch.jit.script(model) y_pred = model_script.forward(data_x).detach() getattr(self, registry_test_fn)( @@ -143,7 +143,7 @@ def test_torchscript_sp( model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model(torch.float64) x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) x_df.columns = ["col_0"] - model_script = torch.jit.script(model) # type:ignore[attr-defined] + model_script = torch.jit.script(model) y_pred = model_script.forward(data_x) x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self.session, x_df) y_pred_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([y_pred]) diff --git a/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py index 0dde2940..caaeeab8 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py @@ -40,11 +40,19 @@ def test_skl_model( prediction_assert_fns={ "predict": ( iris_X, - lambda res: np.testing.assert_allclose(res["output_feature_0"].values, 
classifier.predict(iris_X)), + lambda res: pd.testing.assert_frame_equal( + res["output_feature_0"].to_frame("output_feature_0"), + pd.DataFrame(classifier.predict(iris_X), columns=["output_feature_0"]), + check_dtype=False, + ), ), "predict_proba": ( iris_X[:10], - lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(iris_X[:10])), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_proba(iris_X[:10]), columns=res.columns), + check_dtype=False, + ), ), }, function_type_assert={ @@ -68,23 +76,41 @@ def test_skl_model_explain( classifier.fit(iris_X_df, iris_y) expected_explanations = shap.Explainer(classifier, iris_X_df)(iris_X_df).values + def _check_explain(res: pd.DataFrame) -> None: + actual_explain_df = handlers_utils.convert_explanations_to_2D_df(classifier, expected_explanations) + rename_columns = { + old_col_name: new_col_name for old_col_name, new_col_name in zip(actual_explain_df.columns, res.columns) + } + actual_explain_df.rename(columns=rename_columns, inplace=True) + pd.testing.assert_frame_equal( + res, + actual_explain_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=classifier, sample_input_data=iris_X_df, prediction_assert_fns={ "predict": ( iris_X_df, - lambda res: np.testing.assert_allclose(res["output_feature_0"].values, classifier.predict(iris_X)), + lambda res: pd.testing.assert_frame_equal( + res["output_feature_0"].to_frame("output_feature_0"), + pd.DataFrame(classifier.predict(iris_X), columns=["output_feature_0"]), + check_dtype=False, + ), ), "predict_proba": ( iris_X_df.iloc[:10], - lambda res: np.testing.assert_allclose(res.values, classifier.predict_proba(iris_X[:10])), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(classifier.predict_proba(iris_X_df[:10]), columns=res.columns), + check_dtype=False, + ), ), "explain": ( iris_X_df, - lambda res: np.testing.assert_allclose( - dataframe_utils.convert2D_json_to_3D(res.values), expected_explanations - ), + _check_explain, ), }, options={"enable_explainability": True}, @@ -153,11 +179,19 @@ def test_skl_model_case_sensitive( prediction_assert_fns={ '"predict"': ( iris_X, - lambda res: np.testing.assert_allclose(res["output_feature_0"].values, regr.predict(iris_X)), + lambda res: pd.testing.assert_frame_equal( + res["output_feature_0"].to_frame("output_feature_0"), + pd.DataFrame(regr.predict(iris_X), columns=["output_feature_0"]), + check_dtype=False, + ), ), '"predict_proba"': ( iris_X[:10], - lambda res: np.testing.assert_allclose(res.values, regr.predict_proba(iris_X[:10])), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(regr.predict_proba(iris_X[:10]), columns=res.columns), + check_dtype=False, + ), ), }, ) @@ -180,7 +214,11 @@ def test_skl_multiple_output_model( prediction_assert_fns={ "predict": ( iris_X[-10:], - lambda res: np.testing.assert_allclose(res.to_numpy(), model.predict(iris_X[-10:])), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(model.predict(iris_X[-10:]), columns=res.columns), + check_dtype=False, + ), ), "predict_proba": ( iris_X[-10:], @@ -218,7 +256,11 @@ def test_skl_unsupported_explain( ) res = mv.run(iris_X[-10:], function_name="predict") - np.testing.assert_allclose(res.to_numpy(), model.predict(iris_X[-10:])) + pd.testing.assert_frame_equal( + res, + pd.DataFrame(model.predict(iris_X[-10:]), columns=res.columns), + check_dtype=False, + ) res = mv.run(iris_X[-10:], function_name="predict_proba") np.testing.assert_allclose( @@ -252,19 +294,33 @@ def 
test_skl_model_with_signature_and_sample_data( "predict": model_signature.infer_signature(iris_X_df, y_pred), } + def _check_explain(res: pd.DataFrame) -> None: + actual_explain_df = handlers_utils.convert_explanations_to_2D_df(classifier, expected_explanations) + rename_columns = { + old_col_name: new_col_name for old_col_name, new_col_name in zip(actual_explain_df.columns, res.columns) + } + actual_explain_df.rename(columns=rename_columns, inplace=True) + pd.testing.assert_frame_equal( + res, + actual_explain_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=classifier, sample_input_data=iris_X_df, prediction_assert_fns={ "predict": ( iris_X_df, - lambda res: np.testing.assert_allclose(res["output_feature_0"].values, classifier.predict(iris_X)), + lambda res: pd.testing.assert_frame_equal( + res["output_feature_0"].to_frame("output_feature_0"), + pd.DataFrame(classifier.predict(iris_X), columns=["output_feature_0"]), + check_dtype=False, + ), ), "explain": ( iris_X_df, - lambda res: np.testing.assert_allclose( - dataframe_utils.convert2D_json_to_3D(res.values), expected_explanations - ), + _check_explain, ), }, options={"enable_explainability": True}, @@ -306,6 +362,12 @@ def test_skl_model_with_categorical_dtype_columns( ] ) pipeline.fit(df.drop("target", axis=1), df["target"]) + expected_signatures = { + "predict": model_signature.infer_signature( + df[input_features], + df["target"].rename("output_feature_0"), + ), + } getattr(self, registry_test_fn)( model=pipeline, @@ -313,13 +375,16 @@ def test_skl_model_with_categorical_dtype_columns( prediction_assert_fns={ "predict": ( df[input_features], - lambda res: np.testing.assert_allclose( - res["output_feature_0"].values, pipeline.predict(df[input_features]) + lambda res: pd.testing.assert_series_equal( + res["output_feature_0"], + pd.Series(pipeline.predict(df[input_features]), name="output_feature_0"), + check_dtype=False, ), ), }, # TODO(SNOW-1677301): Add support for explainability for categorical columns options={"enable_explainability": False}, + signatures=expected_signatures, ) diff --git a/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py index 8263100f..bd63f71d 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py @@ -58,16 +58,25 @@ def test_tf_tensor_as_sample( x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) y_pred = model(data_x) + def assert_fn(res): + y_pred_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df( + tf.transpose(tf.expand_dims(y_pred, axis=0)), + ensure_serializable=False, + ) + y_pred_df.columns = res.columns + pd.testing.assert_frame_equal( + res, + y_pred_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=model, sample_input_data=[data_x], prediction_assert_fns={ "": ( x_df, - lambda res: np.testing.assert_allclose( - tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred.numpy(), - ), + assert_fn, ), }, ) @@ -83,16 +92,25 @@ def test_tf_df_as_sample( x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) y_pred = model(data_x) + def assert_fn(res): + y_pred_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df( + tf.transpose(tf.expand_dims(y_pred, axis=0)), + ensure_serializable=False, + 
) + y_pred_df.columns = res.columns + pd.testing.assert_frame_equal( + res, + y_pred_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=model, sample_input_data=x_df, prediction_assert_fns={ "": ( x_df, - lambda res: np.testing.assert_allclose( - tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred.numpy(), - ), + assert_fn, ), }, ) @@ -137,17 +155,28 @@ def test_keras_tensor_as_sample( model, data_x, data_y = prepare_keras_model() x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) y_pred = model.predict(data_x) + + def assert_fn(res): + y_pred_df = pd.DataFrame(y_pred) + y_pred_df.columns = res.columns + # res's shape: (num_rows, 1, 1) + # y_pred_df's shape: (num_rows, 1) + # convert list to scalar value before comparing + for col in res.columns: + res[col] = res[col].apply(lambda x: x[0]) + pd.testing.assert_frame_equal( + res, + y_pred_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=model, sample_input_data=[data_x], prediction_assert_fns={ "": ( x_df, - lambda res: np.testing.assert_allclose( - tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred, - atol=1e-6, - ), + assert_fn, ), }, ) @@ -162,17 +191,28 @@ def test_keras_df_as_sample( model, data_x, data_y = prepare_keras_model() x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) y_pred = model.predict(data_x) + + def assert_fn(res): + y_pred_df = pd.DataFrame(y_pred) + y_pred_df.columns = res.columns + # res's shape: (num_rows, 1, 1) + # y_pred_df's shape: (num_rows, 1) + # convert list to scalar value before comparing + for col in res.columns: + res[col] = res[col].apply(lambda x: x[0]) + pd.testing.assert_frame_equal( + res, + y_pred_df, + check_dtype=False, + ) + getattr(self, registry_test_fn)( model=model, sample_input_data=x_df, prediction_assert_fns={ "": ( x_df, - lambda res: np.testing.assert_allclose( - tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), - y_pred, - atol=1e-6, - ), + assert_fn, ), }, ) diff --git a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py index 1085c1a7..376f3b0a 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py @@ -4,7 +4,13 @@ import shap import xgboost from absl.testing import absltest, parameterized -from sklearn import datasets, model_selection +from sklearn import ( + compose, + datasets, + model_selection, + pipeline as SK_pipeline, + preprocessing, +) from snowflake.ml.model import model_signature from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema @@ -31,7 +37,11 @@ def test_xgb_manual_shap_override(self, registry_test_fn: str) -> None: prediction_assert_fns={ "explain": ( cal_X_test, - lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-3), + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(expected_explanations, columns=res.columns), + check_dtype=False, + ), ), }, # pin version of shap for tests @@ -53,14 +63,58 @@ def test_xgb_no_explain( cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) regressor.fit(cal_X_train, cal_y_train) 
+
+        def _check_predict_fn(res: pd.DataFrame) -> None:
+            pd.testing.assert_frame_equal(
+                res,
+                pd.DataFrame(regressor.predict(cal_X_test), columns=res.columns),
+                check_dtype=False,
+            )
+
         getattr(self, registry_test_fn)(
             model=regressor,
             sample_input_data=cal_X_test,
             prediction_assert_fns={
                 "predict": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(
-                        res.values, np.expand_dims(regressor.predict(cal_X_test), axis=1)
+                    _check_predict_fn,
+                ),
+            },
+            options={"enable_explainability": False},
+        )
+
+    @parameterized.product(  # type: ignore[misc]
+        registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST,
+    )
+    def test_xgb_pipeline_no_explain(
+        self,
+        registry_test_fn: str,
+    ) -> None:
+        cal_data = datasets.load_breast_cancer(as_frame=True)
+        cal_X = cal_data.data
+        cal_y = cal_data.target
+        cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns]
+        cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y)
+        regressor = SK_pipeline.Pipeline(
+            steps=[
+                ("regressor", xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3)),
+            ]
+        )
+
+        regressor.fit(cal_X_train, cal_y_train)
+        getattr(self, registry_test_fn)(
+            model=regressor,
+            sample_input_data=cal_X_test,
+            prediction_assert_fns={
+                "predict": (
+                    cal_X_test,
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(
+                            regressor.predict(cal_X_test),
+                            columns=res.columns,
+                        ),
+                        check_dtype=False,
                     ),
                 ),
             },
@@ -88,7 +142,11 @@ def test_xgb_explain_by_default(
             prediction_assert_fns={
                 "explain": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-4),
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(expected_explanations, columns=res.columns),
+                        check_dtype=False,
+                    ),
                 ),
             },
             function_type_assert={"explain": model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION},
@@ -115,7 +173,11 @@ def test_xgb_explain_explicitly_enabled(
             prediction_assert_fns={
                 "explain": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-4),
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(expected_explanations, columns=res.columns),
+                        check_dtype=False,
+                    ),
                 ),
             },
             options={"enable_explainability": True},
@@ -212,13 +274,21 @@ def test_xgb_booster_no_explain(
         params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic")
         regressor = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train))
         y_pred = regressor.predict(xgboost.DMatrix(data=cal_X_test))
+
+        def _check_predict_fn(res: pd.DataFrame) -> None:
+            pd.testing.assert_frame_equal(
+                res,
+                pd.DataFrame(y_pred, columns=res.columns),
+                check_dtype=False,
+            )
+
         getattr(self, registry_test_fn)(
             model=regressor,
             sample_input_data=cal_X_test,
             prediction_assert_fns={
                 "predict": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(res.values, np.expand_dims(y_pred, axis=1), rtol=1e-6),
+                    _check_predict_fn,
                 ),
             },
             options={"enable_explainability": False},
@@ -245,7 +315,11 @@ def test_xgb_booster_explain(
             prediction_assert_fns={
                 "explain": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-4),
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(expected_explanations, columns=res.columns),
+                        check_dtype=False,
+                    ),
                 ),
             },
             function_type_assert={"explain": model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION},
@@ -360,13 +434,123 @@ def test_xgb_booster_with_signature_and_sample_data(
             prediction_assert_fns={
                 "explain": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(res.values, expected_explanations, rtol=1e-4),
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(expected_explanations, columns=res.columns),
+                        check_dtype=False,
+                    ),
                 ),
             },
             options={"enable_explainability": True},
             signatures=sig,
         )
 
+    @parameterized.product(  # type: ignore[misc]
+        registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST,
+    )
+    def test_xgb_model_with_categorical_dtype_columns(
+        self,
+        registry_test_fn: str,
+    ) -> None:
+        data = {
+            "color": ["red", "blue", "green", "red"],
+            "size": [1, 2, 2, 4],
+            "price": [10, 15, 20, 25],
+            "target": [0, 1, 1, 0],
+        }
+        input_features = ["color", "size", "price"]
+
+        df = pd.DataFrame(data)
+        df["color"] = df["color"].astype("category")
+        df["size"] = df["size"].astype("category")
+
+        # Define categorical columns
+        categorical_columns = ["color", "size"]
+
+        # Create a column transformer
+        preprocessor = compose.ColumnTransformer(
+            transformers=[
+                ("cat", preprocessing.OneHotEncoder(), categorical_columns),
+            ],
+            remainder="passthrough",
+        )
+
+        pipeline = SK_pipeline.Pipeline(
+            [
+                ("preprocessor", preprocessor),
+                ("classifier", xgboost.XGBClassifier(tree_method="hist")),
+            ]
+        )
+        pipeline.fit(df[input_features], df["target"])
+
+        def _check_predict_fn(res) -> None:
+            pd.testing.assert_frame_equal(
+                res["output_feature_0"].to_frame(),
+                pd.DataFrame(pipeline.predict(df[input_features]), columns=["output_feature_0"]),
+                check_dtype=False,
+            )
+
+        getattr(self, registry_test_fn)(
+            model=pipeline,
+            sample_input_data=df[input_features],
+            prediction_assert_fns={
+                "predict": (
+                    df[input_features],
+                    _check_predict_fn,
+                ),
+            },
+            # TODO(SNOW-1677301): Add support for explainability for categorical columns
+            options={"enable_explainability": False},
+        )
+
+    @parameterized.product(  # type: ignore[misc]
+        registry_test_fn=registry_model_test_base.RegistryModelTestBase.REGISTRY_TEST_FN_LIST,
+    )
+    @absltest.skip("SNOW-1752904")
+    def test_xgb_model_with_native_categorical_dtype_columns(
+        self,
+        registry_test_fn: str,
+    ) -> None:
+        data = {
+            "color": ["red", "blue", "green", "red"],
+            "size": [1, 2, 2, 4],
+            "price": [10, 15, 20, 25],
+            "target": [0, 1, 1, 0],
+        }
+        input_features = ["color", "size", "price"]
+
+        df = pd.DataFrame(data)
+        df["color"] = df["color"].astype("category")
+        df["size"] = df["size"].astype("category")
+
+        # Define categorical columns
+        # categorical_columns = ["color", "size"]
+
+        classifier = xgboost.XGBClassifier(tree_method="hist", enable_categorical=True)
+        classifier.fit(df[input_features], df["target"])
+
+        getattr(self, registry_test_fn)(
+            model=classifier,
+            sample_input_data=df[input_features],
+            prediction_assert_fns={
+                "predict": (
+                    df[input_features],
+                    lambda res: np.testing.assert_allclose(
+                        res["output_feature_0"].values, classifier.predict(df[input_features])
+                    ),
+                ),
+            },
+            # TODO(SNOW-1677301): Add support for explainability for categorical columns
+            options={"enable_explainability": False},
+        )
+
+    # TODO(SNOW-1752904):
+    # The inference fails with message
+    # ValueError: DataFrame.dtypes for data must be int, float, bool or category.
+    # When categorical type is supplied, The experimental DMatrix parameter`enable_categorical`
+    # must be set to `True`. Invalid columns:color: object
+    # in function PREDICT with handler predict.infer
+
 
 if __name__ == "__main__":
     absltest.main()
diff --git a/tests/integ/snowflake/ml/registry/services/BUILD.bazel b/tests/integ/snowflake/ml/registry/services/BUILD.bazel
index 4c3547c2..663dbde4 100644
--- a/tests/integ/snowflake/ml/registry/services/BUILD.bazel
+++ b/tests/integ/snowflake/ml/registry/services/BUILD.bazel
@@ -25,10 +25,10 @@ py_library(
 )
 
 py_test(
-    name = "registry_xgboost_model_deployment_test",
+    name = "registry_xgboost_model_deployment_pip_test",
     timeout = "eternal",
-    srcs = ["registry_xgboost_model_deployment_test.py"],
-    shard_count = 4,
+    srcs = ["registry_xgboost_model_deployment_pip_test.py"],
+    shard_count = 2,
     deps = [
         ":registry_model_deployment_test_base",
     ],
@@ -38,7 +38,17 @@ py_test(
     name = "registry_sentence_transformers_model_deployment_test",
     timeout = "eternal",
     srcs = ["registry_sentence_transformers_model_deployment_test.py"],
-    shard_count = 4,
+    shard_count = 2,
+    deps = [
+        ":registry_model_deployment_test_base",
+    ],
+)
+
+py_test(
+    name = "registry_sentence_transformers_model_deployment_gpu_test",
+    timeout = "eternal",
+    srcs = ["registry_sentence_transformers_model_deployment_gpu_test.py"],
+    shard_count = 2,
     deps = [
         ":registry_model_deployment_test_base",
     ],
 )
@@ -48,7 +58,17 @@ py_test(
     name = "registry_huggingface_pipeline_model_deployment_test",
     timeout = "eternal",
     srcs = ["registry_huggingface_pipeline_model_deployment_test.py"],
-    shard_count = 4,
+    shard_count = 2,
+    deps = [
+        ":registry_model_deployment_test_base",
+    ],
+)
+
+py_test(
+    name = "registry_huggingface_pipeline_model_deployment_gpu_test",
+    timeout = "eternal",
+    srcs = ["registry_huggingface_pipeline_model_deployment_gpu_test.py"],
+    shard_count = 2,
     deps = [
         ":registry_model_deployment_test_base",
     ],
 )
@@ -56,7 +76,7 @@ py_test(
 
 py_test(
     name = "registry_sklearn_model_deployment_test",
-    timeout = "long",
+    timeout = "eternal",
     srcs = ["registry_sklearn_model_deployment_test.py"],
     shard_count = 2,
     deps = [
@@ -66,9 +86,9 @@ py_test(
 
 py_test(
     name = "registry_custom_model_deployment_test",
-    timeout = "long",
+    timeout = "eternal",
    srcs = ["registry_custom_model_deployment_test.py"],
-    shard_count = 2,
+    shard_count = 1,
     deps = [
         ":registry_model_deployment_test_base",
     ],
@@ -76,7 +96,7 @@ py_test(
 
 py_test(
     name = "registry_model_deployment_test",
-    timeout = "long",
+    timeout = "eternal",
     srcs = ["registry_model_deployment_test.py"],
     shard_count = 2,
     deps = [
diff --git a/tests/integ/snowflake/ml/registry/services/registry_custom_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_custom_model_deployment_test.py
index 60cdaea6..7beb4401 100644
--- a/tests/integ/snowflake/ml/registry/services/registry_custom_model_deployment_test.py
+++ b/tests/integ/snowflake/ml/registry/services/registry_custom_model_deployment_test.py
@@ -2,7 +2,6 @@
 import tempfile
 
 import inflection
-import numpy as np
 import pandas as pd
 import xgboost
 from absl.testing import absltest
@@ -57,11 +56,15 @@ def test_custom_model(
             prediction_assert_fns={
                 "predict": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(
-                        res.values, np.expand_dims(my_custom_model.predict(cal_X_test), axis=1), rtol=1e-3
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        my_custom_model.predict(cal_X_test),
+                        rtol=1e-3,
+                        check_dtype=False,
                     ),
                 ),
             },
+            options={"enable_explainability": False},
         )
diff --git a/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_gpu_test.py b/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_gpu_test.py
new file mode 100644
index 00000000..ec9d6290
--- /dev/null
+++ b/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_gpu_test.py
@@ -0,0 +1,72 @@
+import json
+import os
+import tempfile
+from typing import List, Optional
+
+import pandas as pd
+from absl.testing import absltest, parameterized
+
+from tests.integ.snowflake.ml.registry.services import (
+    registry_model_deployment_test_base,
+)
+
+
+class TestRegistryHuggingFacePipelineDeploymentModelInteg(
+    registry_model_deployment_test_base.RegistryModelDeploymentTestBase
+):
+    @classmethod
+    def setUpClass(self) -> None:
+        self.cache_dir = tempfile.TemporaryDirectory()
+        self._original_cache_dir = os.getenv("TRANSFORMERS_CACHE", None)
+        os.environ["TRANSFORMERS_CACHE"] = self.cache_dir.name
+
+    @classmethod
+    def tearDownClass(self) -> None:
+        if self._original_cache_dir:
+            os.environ["TRANSFORMERS_CACHE"] = self._original_cache_dir
+        self.cache_dir.cleanup()
+
+    @parameterized.product(  # type: ignore[misc]
+        pip_requirements=[None, ["transformers"]],
+    )
+    def test_text_generation(
+        self,
+        pip_requirements: Optional[List[str]],
+    ) -> None:
+        import transformers
+
+        model = transformers.pipeline(
+            task="text-generation",
+            model="openai-community/gpt2",
+        )
+
+        x_df = pd.DataFrame(
+            [['A descendant of the Lost City of Atlantis, who swam to Earth while saying, "']],
+        )
+
+        def check_res(res: pd.DataFrame) -> None:
+            pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"]))
+
+            for row in res["outputs"]:
+                self.assertIsInstance(row, str)
+                resp = json.loads(row)
+                self.assertIsInstance(resp, list)
+                self.assertIn("generated_text", resp[0])
+
+        self._test_registry_model_deployment(
+            model=model,
+            prediction_assert_fns={
+                "__call__": (
+                    x_df,
+                    check_res,
+                ),
+            },
+            options={"cuda_version": "11.8"},
+            additional_dependencies=["pytorch==2.1.0"],
+            gpu_requests="1",
+            pip_requirements=pip_requirements,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_test.py
index 34e7450b..f41f939b 100644
--- a/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_test.py
+++ b/tests/integ/snowflake/ml/registry/services/registry_huggingface_pipeline_model_deployment_test.py
@@ -27,12 +27,10 @@ def tearDownClass(self) -> None:
         self.cache_dir.cleanup()
 
     @parameterized.product(  # type: ignore[misc]
-        gpu_requests=[None, "1"],
         pip_requirements=[None, ["transformers"]],
     )
     def test_text_generation(
         self,
-        gpu_requests: str,
         pip_requirements: Optional[List[str]],
     ) -> None:
         import transformers
@@ -63,9 +61,8 @@ def check_res(res: pd.DataFrame) -> None:
                     check_res,
                 ),
             },
-            options={"cuda_version": "11.8"} if gpu_requests else {},
+            options={},
             additional_dependencies=["pytorch==2.1.0"],
-            gpu_requests=gpu_requests,
             pip_requirements=pip_requirements,
         )
 
diff --git a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test.py
index 47a33bec..46b2510c 100644
--- a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test.py
+++ b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test.py
@@ -1,7 +1,7 @@
 import inflection
-import numpy as np
+import pandas as pd
 import xgboost
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 from sklearn import datasets, model_selection
 
 from tests.integ.snowflake.ml.registry.services import (
@@ -10,8 +10,12 @@
 
 
 class TestRegistryModelDeploymentInteg(registry_model_deployment_test_base.RegistryModelDeploymentTestBase):
+    @parameterized.product(  # type: ignore[misc]
+        gpu_requests=[None, "1"],
+    )
     def test_end_to_end_pipeline(
         self,
+        gpu_requests: str,
     ) -> None:
         cal_data = datasets.load_breast_cancer(as_frame=True)
         cal_X = cal_data.data
@@ -26,23 +30,31 @@ def test_end_to_end_pipeline(
             prediction_assert_fns={
                 "predict": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(
-                        res.values, np.expand_dims(regressor.predict(cal_X_test), axis=1), rtol=1e-3
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(regressor.predict(cal_X_test), columns=res.columns),
+                        rtol=1e-3,
+                        check_dtype=False,
                     ),
                 ),
             },
+            options=(
+                {"cuda_version": "11.8", "enable_explainability": False}
+                if gpu_requests
+                else {"enable_explainability": False}
+            ),
+            gpu_requests=gpu_requests,
         )
 
         services_df = mv.list_services()
-        services = services_df["service_name"]
+        services = services_df["name"]
         self.assertLen(services, 1)
 
         for service in services:
             mv.delete_service(service)
 
         services_df = mv.list_services()
-        services = services_df["service_name"]
-        self.assertEmpty(services)
+        self.assertLen(services_df, 0)
 
 
 if __name__ == "__main__":
diff --git a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py
index 7ba0d6e0..af1cd966 100644
--- a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py
+++ b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py
@@ -116,7 +116,7 @@ def _deploy_model_with_image_override(
             memory=None,
             gpu=gpu_requests,
             force_rebuild=force_rebuild,
-            external_access_integration=sql_identifier.SqlIdentifier(self._SPCS_EAI),
+            external_access_integrations=[sql_identifier.SqlIdentifier(self._SPCS_EAI)],
         )
 
         with (mv._service_ops.workspace_path / deploy_spec_file_rel_path).open("r", encoding="utf-8") as f:
@@ -241,7 +241,7 @@ def _test_registry_model_deployment(
             num_workers=num_workers,
             max_instances=max_instances,
             max_batch_rows=max_batch_rows,
-            build_external_access_integration=self._SPCS_EAI,
+            build_external_access_integrations=[self._SPCS_EAI],
         )
 
         for target_method, (test_input, check_func) in prediction_assert_fns.items():
diff --git a/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_gpu_test.py b/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_gpu_test.py
new file mode 100644
index 00000000..83f9bca0
--- /dev/null
+++ b/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_gpu_test.py
@@ -0,0 +1,79 @@
+import os
+import random
+import tempfile
+from typing import List, Optional
+
+import pandas as pd
+from absl.testing import absltest, parameterized
+
+from tests.integ.snowflake.ml.registry.services import (
+    registry_model_deployment_test_base,
+)
+
+MODEL_NAMES = ["intfloat/e5-base-v2"]  # cant load models in parallel
+SENTENCE_TRANSFORMERS_CACHE_DIR = "SENTENCE_TRANSFORMERS_HOME"
+
+
+class TestRegistrySentenceTransformerDeploymentModelInteg(
+    registry_model_deployment_test_base.RegistryModelDeploymentTestBase
+):
+    @classmethod
+    def setUpClass(self) -> None:
+        self.cache_dir = tempfile.TemporaryDirectory()
+        self._original_cache_dir = os.getenv(SENTENCE_TRANSFORMERS_CACHE_DIR, None)
+        os.environ[SENTENCE_TRANSFORMERS_CACHE_DIR] = self.cache_dir.name
+
+    @classmethod
+    def tearDownClass(self) -> None:
+        if self._original_cache_dir:
+            os.environ[SENTENCE_TRANSFORMERS_CACHE_DIR] = self._original_cache_dir
+        self.cache_dir.cleanup()
+
+    @parameterized.product(  # type: ignore[misc]
+        pip_requirements=[None, ["sentence-transformers"]],
+    )
+    def test_sentence_transformers(
+        self,
+        pip_requirements: Optional[List[str]],
+    ) -> None:
+        import sentence_transformers
+
+        # Sample Data
+        sentences = pd.DataFrame(
+            {
+                "SENTENCES": [
+                    "Why don’t scientists trust atoms? Because they make up everything.",
+                    "I told my wife she should embrace her mistakes. She gave me a hug.",
+                    "Im reading a book on anti-gravity. Its impossible to put down!",
+                    "Did you hear about the mathematician who’s afraid of negative numbers?",
+                    "Parallel lines have so much in common. It’s a shame they’ll never meet.",
+                ]
+            }
+        )
+        model = sentence_transformers.SentenceTransformer(random.choice(MODEL_NAMES))
+        embeddings = pd.DataFrame(model.encode(sentences["SENTENCES"].to_list(), batch_size=sentences.shape[0]))
+
+        self._test_registry_model_deployment(
+            model=model,
+            sample_input_data=sentences,
+            prediction_assert_fns={
+                "encode": (
+                    sentences,
+                    lambda res: pd.testing.assert_frame_equal(
+                        pd.DataFrame(res["output_feature_0"].to_list()),
+                        embeddings,
+                        rtol=1e-2,
+                        atol=1e-3,
+                        check_dtype=False,
+                    ),
+                ),
+            },
+            options={"cuda_version": "11.8"},
+            gpu_requests="1",
+            additional_dependencies=["pytorch==2.1.0", "huggingface_hub<0.26"],
+            pip_requirements=pip_requirements,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_test.py
index 10ee0479..5d9e9bf7 100644
--- a/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_test.py
+++ b/tests/integ/snowflake/ml/registry/services/registry_sentence_transformers_model_deployment_test.py
@@ -30,12 +30,10 @@ def tearDownClass(self) -> None:
         self.cache_dir.cleanup()
 
     @parameterized.product(  # type: ignore[misc]
-        gpu_requests=[None, "1"],
         pip_requirements=[None, ["sentence-transformers"]],
     )
     def test_sentence_transformers(
         self,
-        gpu_requests: str,
         pip_requirements: Optional[List[str]],
     ) -> None:
         import sentence_transformers
@@ -70,9 +68,8 @@ def test_sentence_transformers(
                     ),
                 ),
             },
-            options={"cuda_version": "11.8"} if gpu_requests else {},
-            gpu_requests=gpu_requests,
-            additional_dependencies=["pytorch==2.1.0"],
+            options={"cuda_version": "11.8"},
+            additional_dependencies=["pytorch==2.1.0", "huggingface_hub<0.26"],
             pip_requirements=pip_requirements,
         )
 
diff --git a/tests/integ/snowflake/ml/registry/services/registry_sklearn_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_sklearn_model_deployment_test.py
index 9a258a5e..d7789b95 100644
--- a/tests/integ/snowflake/ml/registry/services/registry_sklearn_model_deployment_test.py
+++ b/tests/integ/snowflake/ml/registry/services/registry_sklearn_model_deployment_test.py
@@ -1,6 +1,6 @@
 from typing import List, Optional
 
-import numpy as np
+import pandas as pd
 from absl.testing import absltest, parameterized
 from sklearn import datasets, svm
 
@@ -22,12 +22,16 @@ def test_sklearn(self, pip_requirements: Optional[List[str]]) -> None:
             prediction_assert_fns={
                 "predict": (
                     iris_X,
-                    lambda res: np.testing.assert_allclose(
-                        res.values, np.expand_dims(svc.predict(iris_X), axis=1), rtol=1e-3
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(svc.predict(iris_X), columns=res.columns),
+                        rtol=1e-3,
+                        check_dtype=False,
                     ),
                 ),
             },
             pip_requirements=pip_requirements,
+            options={"enable_explainability": False},
         )
 
diff --git a/tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_test.py b/tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_pip_test.py
similarity index 71%
rename from tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_test.py
rename to tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_pip_test.py
index 9519cb38..bc2c248c 100644
--- a/tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_test.py
+++ b/tests/integ/snowflake/ml/registry/services/registry_xgboost_model_deployment_pip_test.py
@@ -1,7 +1,5 @@
-from typing import List
-
 import inflection
-import numpy as np
+import pandas as pd
 import xgboost
 from absl.testing import absltest, parameterized
 from sklearn import datasets, model_selection
@@ -14,12 +12,10 @@ class TestRegistryXGBoostModelDeploymentInteg(registry_model_deployment_test_base.RegistryModelDeploymentTestBase):
     @parameterized.product(  # type: ignore[misc]
         gpu_requests=[None, "1"],
-        pip_requirements=[None, ["xgboost"]],
     )
     def test_xgb(
         self,
         gpu_requests: str,
-        pip_requirements: List[str],
     ) -> None:
         cal_data = datasets.load_breast_cancer(as_frame=True)
         cal_X = cal_data.data
@@ -34,14 +30,21 @@ def test_xgb(
             prediction_assert_fns={
                 "predict": (
                     cal_X_test,
-                    lambda res: np.testing.assert_allclose(
-                        res.values, np.expand_dims(regressor.predict(cal_X_test), axis=1), rtol=1e-3
+                    lambda res: pd.testing.assert_frame_equal(
+                        res,
+                        pd.DataFrame(regressor.predict(cal_X_test), columns=res.columns),
+                        rtol=1e-3,
+                        check_dtype=False,
                     ),
                 ),
             },
-            options={"cuda_version": "11.8"} if gpu_requests else {},
+            options=(
+                {"cuda_version": "11.8", "enable_explainability": False}
+                if gpu_requests
+                else {"enable_explainability": False}
+            ),
             gpu_requests=gpu_requests,
-            pip_requirements=pip_requirements,
+            pip_requirements=[f"xgboost=={xgboost.__version__}"],
         )
 
diff --git a/tests/integ/snowflake/ml/test_utils/BUILD.bazel b/tests/integ/snowflake/ml/test_utils/BUILD.bazel
index c8116a0a..561268eb 100644
--- a/tests/integ/snowflake/ml/test_utils/BUILD.bazel
+++ b/tests/integ/snowflake/ml/test_utils/BUILD.bazel
@@ -1,6 +1,7 @@
 load("//bazel:py_rules.bzl", "py_genrule", "py_library", "py_test")
 
 package(default_visibility = [
+    "//tests/integ/snowflake/cortex:__subpackages__",
     "//tests/integ/snowflake/ml:__subpackages__",
     "//tests/perf:__subpackages__",
 ])
diff --git a/tests/integ/snowflake/ml/test_utils/common_test_base.py b/tests/integ/snowflake/ml/test_utils/common_test_base.py
index 3744a125..4156fe3d 100644
--- a/tests/integ/snowflake/ml/test_utils/common_test_base.py
+++ b/tests/integ/snowflake/ml/test_utils/common_test_base.py
@@ -137,7 +137,7 @@ def _in_sproc_test(execute_as: Literal["owner", "caller"] = "owner") -> None:
             imports = [snowml_zip_module_filename, tests_zip_module_filename]
             packages = additional_packages or []
 
-            for req_str in _snowml_requirements.REQUIREMENTS:
+            for req_str in _snowml_requirements.ALL_REQUIREMENTS:
                 req = requirements.Requirement(req_str)
                 # Remove "_" not in req once Snowpark 1.11.0 available, it is a workaround for their bug.
                 if any(offending in req.name for offending in ["snowflake-connector-python", "pyarrow"]):
diff --git a/tests/integ/snowflake/ml/test_utils/db_manager.py b/tests/integ/snowflake/ml/test_utils/db_manager.py
index 6b8c58d7..6297075b 100644
--- a/tests/integ/snowflake/ml/test_utils/db_manager.py
+++ b/tests/integ/snowflake/ml/test_utils/db_manager.py
@@ -25,13 +25,14 @@ def create_database(
         self,
         db_name: str,
         creation_mode: sql_client.CreationMode = _default_creation_mode,
+        data_retention_time_in_days: int = 0,
     ) -> str:
         actual_db_name = identifier.get_inferred_name(db_name)
         ddl_phrases = creation_mode.get_ddl_phrases()
         self._session.sql(
             f"CREATE{ddl_phrases[sql_client.CreationOption.OR_REPLACE]} DATABASE"
             f"{ddl_phrases[sql_client.CreationOption.CREATE_IF_NOT_EXIST]} "
-            f"{actual_db_name} DATA_RETENTION_TIME_IN_DAYS = 0"
+            f"{actual_db_name} DATA_RETENTION_TIME_IN_DAYS = {data_retention_time_in_days}"
         ).collect()
         return actual_db_name