From 27431b2dfc46f249be28f166ec8670d8ffc9e976 Mon Sep 17 00:00:00 2001 From: Sumit Das Date: Tue, 12 Mar 2024 13:59:45 -0700 Subject: [PATCH] Project import generated by Copybara. (#92) GitOrigin-RevId: cd1cf14167a03d4d572a86fb6162ba2d9d9e8457 Co-authored-by: Snowflake Authors --- CHANGELOG.md | 55 +- CONTRIBUTING.md | 4 +- ci/conda_recipe/meta.yaml | 4 +- codegen/codegen_rules.bzl | 2 + codegen/sklearn_wrapper_generator.py | 54 +- codegen/sklearn_wrapper_template.py_template | 581 ++++++++++-------- ...nsformer_autogen_test_template.py_template | 8 +- requirements.yml | 2 +- snowflake/ml/_internal/env_utils.py | 17 + snowflake/ml/_internal/env_utils_test.py | 46 ++ .../ml/_internal/exceptions/error_codes.py | 2 + snowflake/ml/feature_store/feature_store.py | 29 +- snowflake/ml/feature_store/feature_view.py | 3 +- snowflake/ml/fileset/BUILD.bazel | 1 + snowflake/ml/fileset/sfcfs.py | 73 ++- snowflake/ml/fileset/sfcfs_test.py | 40 ++ snowflake/ml/model/_api.py | 4 +- .../model/_client/model/model_version_impl.py | 4 + .../_client/model/model_version_impl_test.py | 30 + snowflake/ml/model/_client/ops/model_ops.py | 11 +- .../ml/model/_client/ops/model_ops_test.py | 46 +- .../ml/model/_client/sql/model_version.py | 16 +- .../templates/dockerfile_template | 41 +- .../test_fixtures/dockerfile_test_fixture | 23 +- .../dockerfile_test_fixture_with_CUDA | 23 +- .../dockerfile_test_fixture_with_model | 27 +- .../_deploy_client/warehouse/deploy_test.py | 5 +- .../model_manifest/BUILD.bazel | 1 + .../model_manifest/fixtures/MANIFEST_4.yml | 73 +++ .../model_manifest/model_manifest_schema.py | 6 +- .../model_manifest/model_manifest_test.py | 47 ++ .../_model_composer/model_method/BUILD.bazel | 3 + .../model_method/fixtures/function_3.py | 76 +++ .../model_method/function_generator.py | 14 +- .../model_method/function_generator_test.py | 31 +- .../infer_table_function.py_template | 76 +++ .../model_method/model_method.py | 41 +- .../model_method/model_method_test.py | 40 ++ .../model_runtime/model_runtime.py | 5 +- .../model_runtime/model_runtime_test.py | 41 +- .../model_handlers_test/custom_test.py | 1 + .../model/_packager/model_meta/model_meta.py | 40 +- .../_packager/model_meta/model_meta_test.py | 82 ++- .../ml/model/_packager/model_packager_test.py | 39 -- .../ml/model/_signatures/snowpark_handler.py | 8 +- .../ml/model/_signatures/snowpark_test.py | 27 +- snowflake/ml/model/model_signature.py | 71 ++- snowflake/ml/model/model_signature_test.py | 20 +- snowflake/ml/model/type_hints.py | 1 + snowflake/ml/modeling/_internal/BUILD.bazel | 44 +- snowflake/ml/modeling/_internal/constants.py | 1 + .../modeling/_internal/estimator_protocols.py | 45 -- .../_internal/estimator_protocols_test.py | 17 - .../local_implementations/BUILD.bazel | 9 + .../local_implementations/pandas_handlers.py | 226 +++++++ .../local_implementations/pandas_trainer.py | 30 +- .../ml_runtime_implementations/BUILD.bazel | 36 ++ .../ml_runtime_handlers.py | 131 ++++ .../ml_runtime_handlers_test.py | 41 ++ .../ml_runtime_trainer.py | 66 ++ .../ml_runtime_trainer_test.py | 46 ++ .../ml/modeling/_internal/model_trainer.py | 13 +- .../_internal/model_trainer_builder.py | 57 +- .../_internal/model_trainer_builder_test.py | 18 + .../_internal/model_transformer_builder.py | 85 +++ .../model_transformer_builder_test.py | 73 +++ .../snowpark_handlers.py | 144 ++--- .../snowpark_trainer.py | 188 +++++- .../_internal/transformer_protocols.py | 191 ++++++ snowflake/ml/modeling/framework/BUILD.bazel | 3 +- 
.../ml/modeling/model_selection/BUILD.bazel | 2 + .../model_selection/grid_search_cv.py | 445 +++++++------- .../model_selection/randomized_search_cv.py | 442 ++++++------- snowflake/ml/registry/model_registry.py | 12 + snowflake/ml/registry/registry.py | 2 +- snowflake/ml/version.bzl | 2 +- .../ml/_internal/snowpark_handlers_test.py | 48 +- .../snowflake/ml/extra_tests/BUILD.bazel | 36 +- .../ml/extra_tests/fit_predict_test.py | 64 -- .../snowflake/ml/feature_store/BUILD.bazel | 14 + .../ml/feature_store/access_utils.py | 110 ++++ .../ml/feature_store/common_utils.py | 32 +- .../feature_store_access_test.py | 434 +++++++++++++ .../ml/feature_store/feature_store_test.py | 119 +++- .../snowflake/ml/fileset/sfcfs_integ_test.py | 11 + .../ml/model/_client/model/BUILD.bazel | 14 + .../model/input_validation_integ_test.py | 119 ++++ .../ml/model/model_badcase_integ_test.py | 4 +- .../model/warehouse_model_integ_test_utils.py | 2 +- .../snowflake/ml/modeling/framework/utils.py | 3 +- .../snowflake/ml/modeling/metrics/BUILD.bazel | 8 + .../metrics/d2_absolute_error_score_test.py | 170 ++--- .../modeling/metrics/d2_pinball_score_test.py | 241 ++++---- .../ml/modeling/metrics/f1_score_test.py | 314 +++++----- .../ml/modeling/metrics/fbeta_score_test.py | 413 +++++++------ .../ml/modeling/metrics/generator.py | 3 +- .../ml/modeling/metrics/log_loss_test.py | 346 ++++++----- .../precision_recall_fscore_support_test.py | 414 ++++++++----- .../modeling/metrics/precision_score_test.py | 314 +++++----- .../ml/modeling/metrics/recall_score_test.py | 314 +++++----- .../check_sklearn_inference_test.py | 41 +- .../snowflake/ml/registry/model/BUILD.bazel | 19 + .../registry/model/additional_import_test.py | 129 ++++ .../ml/registry/model/my_module/__init__.py | 3 + .../ml/registry/model/my_module/utils.py | 2 + .../model/registry_model_test_base.py | 2 +- .../ml/registry/model_registry_integ_test.py | 20 +- .../integ/snowflake/ml/test_utils/BUILD.bazel | 1 + 108 files changed, 5671 insertions(+), 2356 deletions(-) create mode 100644 snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_4.yml create mode 100644 snowflake/ml/model/_model_composer/model_method/fixtures/function_3.py create mode 100644 snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template create mode 100644 snowflake/ml/modeling/_internal/constants.py delete mode 100644 snowflake/ml/modeling/_internal/estimator_protocols.py delete mode 100644 snowflake/ml/modeling/_internal/estimator_protocols_test.py create mode 100644 snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py create mode 100644 snowflake/ml/modeling/_internal/ml_runtime_implementations/BUILD.bazel create mode 100644 snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py create mode 100644 snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers_test.py create mode 100644 snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py create mode 100644 snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer_test.py create mode 100644 snowflake/ml/modeling/_internal/model_transformer_builder.py create mode 100644 snowflake/ml/modeling/_internal/model_transformer_builder_test.py create mode 100644 snowflake/ml/modeling/_internal/transformer_protocols.py delete mode 100644 tests/integ/snowflake/ml/extra_tests/fit_predict_test.py create mode 100644 tests/integ/snowflake/ml/feature_store/access_utils.py create mode 100644 
tests/integ/snowflake/ml/feature_store/feature_store_access_test.py
 create mode 100644 tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py
 create mode 100644 tests/integ/snowflake/ml/registry/model/additional_import_test.py
 create mode 100644 tests/integ/snowflake/ml/registry/model/my_module/__init__.py
 create mode 100644 tests/integ/snowflake/ml/registry/model/my_module/utils.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4e4da97a..4b16e381 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,33 @@
 # Release History

-## 1.2.3
+## 1.3.0
+
+### Bug Fixes
+
+- Registry: Fix a bug where modules in `code_paths` could not be correctly imported when calling `log_model`.
+- Registry: Fix incorrect error message when validating an input Snowpark DataFrame with an array feature.
+- Model Registry: Fix an issue where some files did not have proper permissions when deploying a model to SPCS.
+- Model Development: Relax package versions for all inference methods if the installed version
+  is not available in the Snowflake conda channel.
+
+### Behavior Changes
+
+- Registry: When running a method of a model, the value-range-based input validation that guards against overflow
+  is now optional rather than enforced; this should improve performance and should not cause problems for most
+  kinds of models. To enable this check as before, specify `strict_input_validation=True` when
+  calling `run`.
+- Registry: `relax_version` now defaults to `True` when logging a model, instead of pinning the specific local
+  dependency versions. This improves dependency versioning by using versions available in Snowflake. To switch back
+  to the previous behavior and use specific local dependency versions, specify `relax_version=False` when calling `log_model`.
+- Model Development: The behavior of `fit_predict` has changed for all estimators.
+  First, it is now available on every estimator that implements the method;
+  second, it returns a pandas or Snowpark DataFrame rather than a NumPy array.
+
+### New Features
+
+- FileSet: `snowflake.ml.fileset.sfcfs.SFFileSystem` can now be serialized with `pickle`.
+
+## 1.2.3 (2024-02-26)

 ### Bug Fixes

@@ -23,11 +50,7 @@
   GridSearchCV, RandomizedSearchCV, PCA, IsolationForest, ...
 - Registry: Support deleting a version of a model.

-## 1.2.2
-
-### Bug Fixes
-
-### Behavior Changes
+## 1.2.2 (2024-02-13)

 ### New Features

@@ -38,14 +61,14 @@
   `snowflake.ml.model.models.huggingface_pipeline.HuggingFacePipelineModel` object, the following endpoints are required
   to be allowed: huggingface.com:80, huggingface.com:443, huggingface.co:80, huggingface.co:443.

-## 1.2.1
+## 1.2.1 (2024-01-25)

 ### New Features

 - Model Development: Infers output column data type for transformers when possible.
 - Registry: `relax_version` option is available in the `options` argument when logging the model.

-## 1.2.0
+## 1.2.0 (2024-01-11)

 ### Bug Fixes

@@ -53,8 +76,6 @@
   XGBoost models deployed to SPCS.
 - Model Registry: Fix model deployment to SPCS on Windows machines.

-### Behavior Changes
-
 ### New Features

 - Model Development: Introduced XGBoost external memory training feature. This feature enables training XGBoost models
@@ -72,7 +93,7 @@
   `snowflake.ml.registry.Registry`, except when specifically required. The old model registry will be removed once all
   its primary functionalities are fully integrated into the new registry.

-## 1.1.2
+## 1.1.2 (2023-12-18)

 ### Bug Fixes

@@ -90,7 +111,7 @@
 its primary functionalities are fully integrated into the new registry.
- Model Development: SQL implementation of binary `precision_score` metric. -## 1.1.1 +## 1.1.1 (2023-12-05) ### Bug Fixes @@ -103,7 +124,7 @@ its primary functionalities are fully integrated into the new registry. requiring automatic input_cols inference, but need to avoid using specific columns, like index columns, during training or inference. -## 1.1.0 +## 1.1.0 (2023-12-01) ### Bug Fixes @@ -111,8 +132,6 @@ its primary functionalities are fully integrated into the new registry. - Model Development: OrdinalEncoder and LabelEncoder output_columns do not need to be valid snowflake identifiers. They would previously be excluded if the normalized name did not match the name specified in output_columns. -### Behavior Changes - ### New Features - Model Registry: Add support for invoking public endpoint on SPCS service, by providing a "enable_ingress" SPCS @@ -120,7 +139,7 @@ its primary functionalities are fully integrated into the new registry. - Model Development: Add support for distributed HPO - GridSearchCV and RandomizedSearchCV execution will be distributed on multi-node warehouses. -## 1.0.12 +## 1.0.12 (2023-11-13) ### Bug Fixes @@ -145,7 +164,7 @@ its primary functionalities are fully integrated into the new registry. - Model Registry: Enable best-effort SPCS job/service log streaming when logging level is set to INFO. -## 1.0.11 +## 1.0.11 (2023-10-27) ### New Features @@ -164,7 +183,7 @@ its primary functionalities are fully integrated into the new registry. - Model Development: Fix metrics compatibility with Snowpark Dataframes that use Snowflake identifiers - Model Registry: Resolve 'delete_deployment' not deleting the SPCS service in certain cases. -## 1.0.10 +## 1.0.10 (2023-10-13) ### Behavior Changes diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7df6ff5a..71a2d770 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -46,13 +46,13 @@ Note: You may need to configure your editor to run this on save. To build the package, run: ```shell -> bazel build //snowflake/ml:wheel +> bazel build //:wheel ``` `bazel` can be run from anywhere under the monorepo and it can accept absolute path or a relative path. 
For example,

```sh
-snowflake/ml> bazel build :wheel
+snowml> bazel build :wheel
```

You can build an entire sub-tree as:

diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml
index 486baec4..7bec5f62 100644
--- a/ci/conda_recipe/meta.yaml
+++ b/ci/conda_recipe/meta.yaml
@@ -17,7 +17,7 @@ build:
   noarch: python
 package:
   name: snowflake-ml-python
-  version: 1.2.3
+  version: 1.3.0
 requirements:
   build:
     - python
@@ -42,7 +42,7 @@ requirements:
     - scikit-learn>=1.2.1,<1.4
     - scipy>=1.9,<2
     - snowflake-connector-python>=3.0.4,<4
-    - snowflake-snowpark-python>=1.8.0,<2
+    - snowflake-snowpark-python>=1.8.0,<2,!=1.12.0
     - sqlparse>=0.4,<1
     - typing-extensions>=4.1.0,<5
     - xgboost>=1.7.3,<2
diff --git a/codegen/codegen_rules.bzl b/codegen/codegen_rules.bzl
index 01518ba6..bcfcf533 100644
--- a/codegen/codegen_rules.bzl
+++ b/codegen/codegen_rules.bzl
@@ -94,6 +94,8 @@ def autogen_estimators(module, estimator_info_list):
             "//snowflake/ml/modeling/_internal:estimator_utils",
             "//snowflake/ml/modeling/_internal:model_trainer",
             "//snowflake/ml/modeling/_internal:model_trainer_builder",
+            "//snowflake/ml/modeling/_internal:transformer_protocols",
+            "//snowflake/ml/modeling/_internal:model_transformer_builder",
         ],
     )
diff --git a/codegen/sklearn_wrapper_generator.py b/codegen/sklearn_wrapper_generator.py
index 793aaf37..a5931b47 100644
--- a/codegen/sklearn_wrapper_generator.py
+++ b/codegen/sklearn_wrapper_generator.py
@@ -154,18 +154,6 @@ def _is_classifier_obj(class_object: Tuple[str, type]) -> bool:
         """
         return WrapperGeneratorFactory._is_class_of_type(class_object[1], "ClassifierMixin")

-    @staticmethod
-    def _is_cluster_obj(class_object: Tuple[str, type]) -> bool:
-        """Check if the given estimator object can cluster features and conduct fit_predict methods.
-
-        Args:
-            class_object: Meta class object which needs to be checked.
-
-        Returns:
-            True if the class inherits from ClusterMixin, otherwise False.
-        """
-        return WrapperGeneratorFactory._is_class_of_type(class_object[1], "ClusterMixin")
-
     @staticmethod
     def _is_meta_estimator_obj(class_object: Tuple[str, type]) -> bool:
         """Check if the given estimator object requires an `estimator` parameter.
@@ -277,6 +265,33 @@ def _is_xgboost(module_name: str) -> bool:
         """
         return module_name.split(".")[0] == "xgboost"

+    @staticmethod
+    def _is_deterministic(class_object: Tuple[str, type]) -> bool:
+        """Check if the given estimator class is deterministic.
+
+        Args:
+            class_object: Meta class object which needs to be checked.
+
+        Returns:
+            True if the class is deterministic, otherwise False.
+        """
+        return not (
+            WrapperGeneratorFactory._is_class_of_type(class_object[1], "LinearDiscriminantAnalysis")
+            or WrapperGeneratorFactory._is_class_of_type(class_object[1], "BernoulliRBM")
+        )
+
+    @staticmethod
+    def _is_deterministic_cross_platform(class_object: Tuple[str, type]) -> bool:
+        """Check if the given estimator class is deterministic across different platforms.
+
+        Args:
+            class_object: Meta class object which needs to be checked.
+
+        Returns:
+            True if the class is deterministic across different platforms, otherwise False.
+        """
+        return not (WrapperGeneratorFactory._is_class_of_type(class_object[1], "Isomap"))
+
     @staticmethod
     def _is_lightgbm(module_name: str) -> bool:
         """Checks if the given module belongs to LightGBM package.
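The determinism checks above delegate to `WrapperGeneratorFactory._is_class_of_type`, whose implementation is not part of this patch. A minimal sketch of how such a name-based check might work, assuming it simply matches a class name anywhere in the MRO (an illustration, not the repository's actual code):

```python
# Hypothetical sketch of _is_class_of_type; the real helper in
# sklearn_wrapper_generator.py may differ.
def _is_class_of_type(klass: type, expected_type_name: str) -> bool:
    # Compare by name across the method resolution order so subclasses of
    # e.g. "BernoulliRBM" or "ClassifierMixin" are detected without importing
    # the corresponding sklearn modules here.
    return any(base.__name__ == expected_type_name for base in klass.__mro__)
```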
@@ -604,7 +619,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: self.test_estimator_imports_list: List[str] = [] # Optional function support - self.fit_predict_cluster_function_support = False self.fit_transform_manifold_function_support = False # Dependencies @@ -654,7 +668,6 @@ def _populate_flags(self) -> None: self._is_multioutput_estimator = WrapperGeneratorFactory._is_multioutput_estimator_obj(self.class_object) self._is_k_neighbors = WrapperGeneratorFactory._is_k_neighbors_obj(self.class_object) self._is_heterogeneous_ensemble = WrapperGeneratorFactory._is_heterogeneous_ensemble_obj(self.class_object) - self._is_cluster = WrapperGeneratorFactory._is_cluster_obj(self.class_object) self._is_stacking_ensemble = WrapperGeneratorFactory._is_stacking_ensemble_obj(self.class_object) self._is_voting_ensemble = WrapperGeneratorFactory._is_voting_ensemble_obj(self.class_object) self._is_chain_multioutput = WrapperGeneratorFactory._is_chain_multioutput_obj(self.class_object) @@ -668,6 +681,10 @@ def _populate_flags(self) -> None: self._is_randomized_search_cv = WrapperGeneratorFactory._is_randomized_search_cv(self.class_object) self._is_iterative_imputer = WrapperGeneratorFactory._is_iterative_imputer(self.class_object) self._is_xgboost = WrapperGeneratorFactory._is_xgboost(self.module_name) + self._is_deterministic = WrapperGeneratorFactory._is_deterministic(self.class_object) + self._is_deterministic_cross_platform = WrapperGeneratorFactory._is_deterministic_cross_platform( + self.class_object + ) def _populate_import_statements(self) -> None: self.estimator_imports_list.append("import numpy") @@ -984,11 +1001,6 @@ def generate(self) -> "SklearnWrapperGenerator": ] self.test_estimator_input_args_list.append(f"dictionary={dictionary}") - if self._is_cluster: - self.fit_predict_cluster_function_support = True - if self._is_manifold: - self.fit_transform_manifold_function_support = True - if self._is_manifold: self.fit_transform_manifold_function_support = True @@ -998,12 +1010,10 @@ def generate(self) -> "SklearnWrapperGenerator": if "n_components" in self.original_init_signature.parameters.keys(): if WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "SpectralBiclustering"): - # For spectral bi clustering, set number of sigular vertors to consider to number of input cols and + # For spectral bi clustering, set number of singular vectors to consider to number of input cols and # num best vector to select to half the number of input cols. 
self.test_estimator_input_args_list.append("n_components=len(cols)") self.test_estimator_input_args_list.append("n_best=int(len(cols)/2)") - else: - self.test_estimator_input_args_list.append("n_components=1") if self._is_heterogeneous_ensemble: if self._is_regressor: diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index 068a7646..c5d8835b 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -10,6 +10,7 @@ import pandas as pd import numpy as np from numpy import typing as npt + {transform.estimator_imports} from sklearn.utils.metaestimators import available_if @@ -20,16 +21,21 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV from snowflake.ml._internal.utils import pkg_version_utils, identifier from snowflake.snowpark import DataFrame, Session from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type -from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder -from snowflake.ml.modeling._internal.model_trainer import ModelTrainer +from snowflake.ml.modeling._internal.transformer_protocols import ( + ModelTransformHandlers, + BatchInferenceKwargsTypedDict, + ScoreKwargsTypedDict +) + +from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder + from snowflake.ml.modeling._internal.estimator_utils import ( gather_dependencies, original_estimator_has_callable, transform_snowml_obj_to_sklearn_obj, validate_sklearn_args, ) -from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers from snowflake.ml.model.model_signature import ( DataType, @@ -47,17 +53,12 @@ _PROJECT = "ModelDevelopment" # e.g. sklearn.linear_model -> LinearModel. 
_SUBPROJECT = "".join([s.capitalize() for s in "{transform.root_module_name}".replace("sklearn.", "").split("_")]) +DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame] -def _is_fit_predict_method_enabled() -> Callable[[Any], bool]: - def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]: - return {transform.fit_predict_cluster_function_support} and callable(getattr(self._sklearn_object, "fit_predict", None)) - return check - - -def _is_fit_transform_method_enabled() -> Callable[[Any], bool]: - def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]: - return {transform.fit_transform_manifold_function_support} and callable(getattr(self._sklearn_object, "fit_transform", None)) - return check +def _is_fit_transform_method_enabled() -> Callable[[Any], bool]: + def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]: + return {transform.fit_transform_manifold_function_support} and callable(getattr(self._sklearn_object, "fit_transform", None)) + return check class {transform.original_class_name}(BaseTransformer): @@ -85,8 +86,10 @@ class {transform.original_class_name}(BaseTransformer): self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols self._snowpark_cols: Optional[List[str]] = self.input_cols - self._handlers: TransformerHandlers = HandlersImpl(class_name={transform.original_class_name}.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True) self._autogenerated = True + self._class_name={transform.original_class_name}.__class__.__name__ + self._subproject = _SUBPROJECT + def _get_rand_id(self) -> str: """ @@ -177,16 +180,12 @@ class {transform.original_class_name}(BaseTransformer): else: return list(set(dataset.columns) - set(self.output_cols)) - def _batch_inference( + def _batch_inference_validate_snowpark( self, dataset: DataFrame, inference_method: str, - expected_output_cols_list: List[str], - expected_output_cols_type: str = "", - *args: Any, - **kwargs: Any, - ) -> DataFrame: - """Util method to create UDF and run batch inference. + ) -> List[str]: + """Util method to run validate that batch inference can be run on a snowpark dataframe. """ if not self._is_fitted: raise exceptions.SnowflakeMLException( @@ -205,153 +204,9 @@ class {transform.original_class_name}(BaseTransformer): ), ) # Validate that key package version in user workspace are supported in snowflake conda channel - pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT) - return self._handlers.batch_inference( - dataset, - session, - self._sklearn_object, - self._get_dependencies(), - inference_method, - self.input_cols, - self._get_pass_through_columns(dataset), - expected_output_cols_list, - expected_output_cols_type, - *args, - **kwargs, - ) - - - def _sklearn_inference( - self, - dataset: pd.DataFrame, - inference_method: str, - expected_output_cols_list: List[str], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: - output_cols = expected_output_cols_list.copy() - - # Model expects exact same columns names in the input df for predict call. 
- # Given the scenario that user use snowpark DataFrame in fit call, but pandas DataFrame in predict call - # input cols need to match unquoted / quoted - input_cols = self.input_cols - assert self._snowpark_cols is not None # Keep mypy happy - _snowpark_input_cols: List[str] = self._snowpark_cols - - estimator = self._sklearn_object - - if hasattr(estimator, "feature_names_in_"): - features_required_by_estimator = getattr(estimator, "feature_names_in_") - else: - features_required_by_estimator = _snowpark_input_cols - missing_features = [] - features_in_dataset = set(dataset.columns) - - columns_to_select = [] - for i, f in enumerate(features_required_by_estimator): - if ( - i >= len(input_cols) - or (input_cols[i] != f and _snowpark_input_cols[i] != f) - or (input_cols[i] not in features_in_dataset and _snowpark_input_cols[i] not in features_in_dataset) - ): - missing_features.append(f) - elif input_cols[i] in features_in_dataset: - columns_to_select.append(input_cols[i]) - elif _snowpark_input_cols[i] in features_in_dataset: - columns_to_select.append(_snowpark_input_cols[i]) - - if len(missing_features) > 0: - raise exceptions.SnowflakeMLException( - error_code=error_codes.NOT_FOUND, - original_exception=ValueError( - "The feature names should match with those that were passed during fit.\n" - f"Features seen during fit call but not present in the input: {{missing_features}}\n" - f"Features in the input dataframe : {{input_cols}}\n" - ), - ) - input_df = dataset[columns_to_select] - input_df.columns = features_required_by_estimator - - inference_res = getattr(estimator, inference_method)(input_df, *args, **kwargs) - - if ( - isinstance(inference_res, list) - and len(inference_res) > 0 - and isinstance(inference_res[0], np.ndarray) - ): - # In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of - # ndarrays. We need to concatenate them. - - # First compute output column names - if len(output_cols) == len(inference_res): - actual_output_cols = [] - for idx, np_arr in enumerate(inference_res): - for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]): - actual_output_cols.append(f"{{output_cols[idx]}}_{{i}}") - output_cols = actual_output_cols - - # Concatenate np arrays - transformed_numpy_array = np.concatenate(inference_res, axis=1) - elif ( - isinstance(inference_res, tuple) - and len(inference_res) > 0 - and isinstance(inference_res[0], np.ndarray) - ): - # In case of kneighbors, functions return a tuple of ndarrays. - transformed_numpy_array = np.stack(inference_res, axis=1) - else: - transformed_numpy_array = inference_res - - if (len(transformed_numpy_array.shape) == 3) and inference_method != "kneighbors": - # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes) - # when voting = "soft" and flatten_transform = False. We can't handle unflatten transforms, - # so we ignore flatten_transform flag and flatten the results. 
- transformed_numpy_array = np.hstack(transformed_numpy_array) # type: ignore[call-overload] - - if len(transformed_numpy_array.shape) == 1: - transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1)) - - shape = transformed_numpy_array.shape - if shape[1] != len(output_cols): - if len(output_cols) != 1: - raise exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=TypeError( - "expected_output_cols_list must be same length as transformed array or " - "should be of length 1" - ), - ) - actual_output_cols = [] - for i in range(shape[1]): - actual_output_cols.append(f"{{output_cols[0]}}_{{i}}") - output_cols = actual_output_cols - - if inference_method == "kneighbors": - if (len(transformed_numpy_array.shape) == 3): # return_distance=True - shape = transformed_numpy_array.shape - data = [transformed_numpy_array[:, i, :].tolist() for i in range(shape[1])] - kneighbors_df = pd.DataFrame({{output_cols[i]: data[i] for i in range(shape[1])}}) - else: # return_distance=False - kneighbors_df = pd.DataFrame( - {{output_cols[0]: [ - transformed_numpy_array[i, :].tolist() for i in range(transformed_numpy_array.shape[0]) - ]}} - ) - - if self._drop_input_cols: - dataset = kneighbors_df - else: - dataset = pd.concat([dataset, kneighbors_df], axis=1) - else: - if self._drop_input_cols: - dataset = pd.DataFrame(data=transformed_numpy_array, columns=output_cols) - else: - dataset = dataset.copy() - dataset[output_cols] = transformed_numpy_array - return dataset - @available_if(original_estimator_has_callable("predict")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( project=_PROJECT, @@ -365,6 +220,12 @@ class {transform.original_class_name}(BaseTransformer): Transformed dataset. """ super()._check_dataset_type(dataset) + inference_method = "predict" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + if isinstance(dataset, DataFrame): expected_type_inferred = "{transform.udf_datatype}" # when it is classifier, infer the datatype from label columns @@ -378,21 +239,41 @@ class {transform.original_class_name}(BaseTransformer): error_code=error_codes.INVALID_ATTRIBUTE, original_exception=ValueError(error_str), ) + expected_type_inferred = convert_sp_to_sf_type( label_cols_signatures[0].as_snowpark_type() ) - output_df = self._batch_inference( - dataset=dataset, - inference_method="predict", - expected_output_cols_list=self.output_cols, - expected_output_cols_type=expected_type_inferred, + self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() + + transform_kwargs = dict( + session = dataset._session, + dependencies = self._deps, + pass_through_cols = self._get_pass_through_columns(dataset), + expected_output_cols_type = expected_type_inferred, ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict", - expected_output_cols_list=self.output_cols,) + transform_kwargs = dict( + snowpark_input_cols = self._snowpark_cols, + drop_input_cols = self._drop_input_cols + ) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols= self.output_cols, + **transform_kwargs + ) return output_df @@ -409,6 +290,11 @@ class {transform.original_class_name}(BaseTransformer): Transformed dataset. """ super()._check_dataset_type(dataset) + inference_method="transform" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() if isinstance(dataset, DataFrame): expected_dtype = "{transform.udf_datatype}" if {transform._is_heterogeneous_ensemble}: # is child of _BaseHeterogeneousEnsemble @@ -434,30 +320,64 @@ class {transform.original_class_name}(BaseTransformer): if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols): expected_dtype = convert_sp_to_sf_type(output_types[0]) - output_df = self._batch_inference( + self._deps = self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="transform", - expected_output_cols_list=self.output_cols, - expected_output_cols_type=expected_dtype, + inference_method=inference_method, ) - elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="transform", - expected_output_cols_list=self.output_cols, + assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() + + transform_kwargs = dict( + session = dataset._session, + dependencies = self._deps, + pass_through_cols = self._get_pass_through_columns(dataset), + expected_output_cols_type = expected_dtype, ) + elif isinstance(dataset, pd.DataFrame): + transform_kwargs = dict( + snowpark_input_cols = self._snowpark_cols, + drop_input_cols = self._drop_input_cols + ) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self.output_cols, + **transform_kwargs + ) return output_df - @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc] - def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]: + @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc] + def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]: """ {transform.fit_predict_docstring} + output_cols_prefix: Prefix for the response columns Returns: Predicted dataset. """ - self.fit(dataset) - assert self._sklearn_object is not None - return self._sklearn_object.labels_ + self._infer_input_output_cols(dataset) + super()._check_dataset_type(dataset) + model_trainer = ModelTrainerBuilder.build_fit_predict( + estimator=self._sklearn_object, + dataset=dataset, + input_cols=self.input_cols, + autogenerated=self._autogenerated, + subproject=_SUBPROJECT, + ) + output_result, fitted_estimator = model_trainer.train_fit_predict( + pass_through_columns=self._get_pass_through_columns(dataset), + expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix), + ) + self._sklearn_object = fitted_estimator + self._is_fitted = True + return output_result @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc] @@ -524,20 +444,44 @@ class {transform.original_class_name}(BaseTransformer): Output dataset with probability of the sample for each class in the model. """ super()._check_dataset_type(dataset) + inference_method = "predict_proba" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( + self._deps = self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="predict_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - expected_output_cols_type="float" + inference_method=inference_method, ) - elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._deps, + pass_through_cols=self._get_pass_through_columns(dataset), + expected_output_cols_type="float", ) - + + elif isinstance(dataset, pd.DataFrame): + transform_kwargs = dict( + snowpark_input_cols = self._snowpark_cols, + drop_input_cols = self._drop_input_cols + ) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs + ) return output_df @available_if(original_estimator_has_callable("predict_log_proba")) # type: ignore[misc] @@ -557,22 +501,47 @@ class {transform.original_class_name}(BaseTransformer): Output dataset with log probability of the sample for each class in the model. """ super()._check_dataset_type(dataset) + inference_method="predict_log_proba" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( + self._deps = self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="predict_log_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - expected_output_cols_type="float" + inference_method=inference_method, ) - elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict_log_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._deps, + pass_through_cols=self._get_pass_through_columns(dataset), + expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): + transform_kwargs = dict( + snowpark_input_cols = self._snowpark_cols, + drop_input_cols = self._drop_input_cols + ) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs + ) return output_df + @available_if(original_estimator_has_callable("decision_function")) # type: ignore[misc] def decision_function( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_" @@ -585,20 +554,45 @@ class {transform.original_class_name}(BaseTransformer): Output dataset with results of the decision function for the samples in input dataset. """ super()._check_dataset_type(dataset) + inference_method="decision_function" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( + self._deps = self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="decision_function", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - expected_output_cols_type="float" + inference_method=inference_method, ) - elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="decision_function", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._deps, + pass_through_cols=self._get_pass_through_columns(dataset), + expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): + transform_kwargs = dict( + snowpark_input_cols = self._snowpark_cols, + drop_input_cols = self._drop_input_cols + ) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs + ) return output_df @available_if(original_estimator_has_callable("score_samples")) # type: ignore[misc] @@ -617,20 +611,45 @@ class {transform.original_class_name}(BaseTransformer): Output dataset with probability of the sample for each class in the model. """ super()._check_dataset_type(dataset) + inference_method="score_samples" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( + self._deps = self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="score_samples", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - expected_output_cols_type="float" + inference_method=inference_method, ) - elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="score_samples", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._deps, + pass_through_cols=self._get_pass_through_columns(dataset), + expected_output_cols_type="float", ) - + + elif isinstance(dataset, pd.DataFrame): + transform_kwargs = dict( + snowpark_input_cols = self._snowpark_cols, + drop_input_cols = self._drop_input_cols + ) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs + ) return output_df @available_if(original_estimator_has_callable("score")) # type: ignore[misc] @@ -647,40 +666,42 @@ class {transform.original_class_name}(BaseTransformer): """ self._infer_input_output_cols(dataset) super()._check_dataset_type(dataset) - if isinstance(dataset, pd.DataFrame): - output_score = self._handlers.score_pandas( - dataset, - self._sklearn_object, - self.input_cols, - self.label_cols, - self.sample_weight_col + # This dictionary contains optional kwargs for scoring. These kwargs + # are specific to the type of dataset used. + transform_kwargs: ScoreKwargsTypedDict = dict() + + if isinstance(dataset, DataFrame): + selected_cols = self._get_active_columns() + if len(selected_cols) > 0: + dataset = dataset.select(selected_cols) + assert isinstance(dataset._session, Session) # keep mypy happy + transform_kwargs = dict( + session=dataset._session, + dependencies=["snowflake-snowpark-python"] + self._get_dependencies(), + score_sproc_imports={transform.score_sproc_imports}, ) - elif isinstance(dataset, DataFrame): - output_score = self._score_snowpark(dataset) - return output_score + elif isinstance(dataset, pd.DataFrame): + # pandas_handler.score() does not require any extra kwargs. 
+ transform_kwargs = dict() - def _score_snowpark(self, dataset: DataFrame) -> float: - # Specify input columns so column pruing will be enforced - selected_cols = self._get_active_columns() - if len(selected_cols) > 0: - dataset = dataset.select(selected_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) - session = dataset._session - assert session is not None # keep mypy happy - - score = self._handlers.score_snowpark( - dataset, - session, - self._sklearn_object, - ["snowflake-snowpark-python"] + self._get_dependencies(), - {transform.score_sproc_imports}, - self.input_cols, - self.label_cols, - self.sample_weight_col, + output_score = transform_handlers.score( + input_cols=self.input_cols, + label_cols=self.label_cols, + sample_weight_col=self.sample_weight_col, + **transform_kwargs ) - return score + return output_score + @available_if(original_estimator_has_callable("kneighbors")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -703,30 +724,56 @@ class {transform.original_class_name}(BaseTransformer): Output dataset with results of the K-neighbors for the samples in input dataset. """ super()._check_dataset_type(dataset) + inference_method="kneighbors" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. + transform_kwargs: BatchInferenceKwargsTypedDict = dict() output_cols = ["neigh_ind"] if return_distance: output_cols.insert(0, "neigh_dist") + if isinstance(dataset, DataFrame): # TODO: Solve inconsistent neigh_ind with sklearn due to different precisions in case of close distances. - output_df = self._batch_inference( + + self._deps = self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="kneighbors", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix, output_cols), - expected_output_cols_type="array", - n_neighbors=n_neighbors, - return_distance=return_distance, + inference_method=inference_method, + + ) + assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session = dataset._session, + dependencies = self._deps, + pass_through_cols = self._get_pass_through_columns(dataset), + expected_output_cols_type = "array", + n_neighbors = n_neighbors, + return_distance = return_distance ) elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="kneighbors", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix, output_cols), - n_neighbors=n_neighbors, - return_distance=return_distance, + transform_kwargs = dict( + n_neighbors = n_neighbors, + return_distance = return_distance, + snowpark_input_cols = self._snowpark_cols ) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix, output_cols), + **transform_kwargs + ) return output_df + def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None: self._model_signature_dict = dict() diff --git 
a/codegen/transformer_autogen_test_template.py_template b/codegen/transformer_autogen_test_template.py_template index 94f8d3a0..65174ee9 100644 --- a/codegen/transformer_autogen_test_template.py_template +++ b/codegen/transformer_autogen_test_template.py_template @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import json import random +import platform from typing import Optional, Any, Tuple, List from absl.testing.absltest import TestCase, main @@ -124,7 +125,7 @@ class {transform.test_class_name}(TestCase): sklearn_reg.fit(**args) - inference_methods = ["transform", "predict"] + inference_methods = ["transform", "predict", "fit_predict"] for m in inference_methods: if callable(getattr(sklearn_reg, m, None)): if m == 'predict': @@ -180,9 +181,12 @@ class {transform.test_class_name}(TestCase): num_diffs = (~np.isclose(actual_arr, sklearn_numpy_arr)).sum() num_example = sklearn_numpy_arr.shape[0] assert num_diffs < 0.1 * num_example + elif (not {transform._is_deterministic}) or (not {transform._is_deterministic_cross_platform} and platform.system() == 'Windows'): + assert actual_arr.shape == sklearn_numpy_arr.shape else: np.testing.assert_allclose(actual_arr, sklearn_numpy_arr, rtol=1.e-1, atol=1.e-2) + expected_methods = ["predict_proba", "predict_log_proba", "decision_function", "kneighbors", "score_samples"] for m in expected_methods: assert not ( @@ -236,7 +240,7 @@ class {transform.test_class_name}(TestCase): ) elif ( m == "score_samples" - and reg.__class__.__name__ == 'BernoulliRBM' + and not {transform._is_deterministic} ): # score_samples is not deterministic for BernoulliRBM: # it computes a quantity called the free energy on X, diff --git a/requirements.yml b/requirements.yml index 80eafd96..0ed2881d 100644 --- a/requirements.yml +++ b/requirements.yml @@ -236,7 +236,7 @@ version_requirements: '>=3.0.4,<4' - name: snowflake-snowpark-python dev_version: 1.8.0 - version_requirements: '>=1.8.0,<2' + version_requirements: '>=1.8.0,<2,!=1.12.0' tags: - deployment_core - udf_inference diff --git a/snowflake/ml/_internal/env_utils.py b/snowflake/ml/_internal/env_utils.py index 8b14094a..00242fb5 100644 --- a/snowflake/ml/_internal/env_utils.py +++ b/snowflake/ml/_internal/env_utils.py @@ -35,6 +35,7 @@ class CONDA_OS(Enum): _NODEFAULTS = "nodefaults" _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {} _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {} +_SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"] DEFAULT_CHANNEL_NAME = "" SNOWML_SPROC_ENV = "IN_SNOWML_SPROC" @@ -236,6 +237,22 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement return new_pip_req +def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> requirements.Requirement: + """Get the package spec with supported ops only including ==, >=, <=, > and < + + Args: + req: A requirements.Requirement object showing the requirement. + + Returns: + A requirements.Requirement object with supported ops only + """ + new_req = copy.deepcopy(req) + new_req.specifier = specifiers.SpecifierSet( + specifiers=",".join([str(spec) for spec in req.specifier if spec.operator in _SUPPORTED_PACKAGE_SPEC_OPS]) + ) + return new_req + + def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement: """Relax version specifier from a requirement. 
It detects any ==x.y.z in specifiers and replaces it with >=x.y, <(x+1)

diff --git a/snowflake/ml/_internal/env_utils_test.py b/snowflake/ml/_internal/env_utils_test.py
index 07068605..10095ea2 100644
--- a/snowflake/ml/_internal/env_utils_test.py
+++ b/snowflake/ml/_internal/env_utils_test.py
@@ -264,6 +264,52 @@ def test_get_local_installed_version_of_pip_package(self) -> None:
             requirements.Requirement(f"pip!={importlib_metadata.version('pip')}")
         )

+    def test_get_package_spec_with_supported_ops_only(self) -> None:
+        r = requirements.Requirement("python-package==1.0.1")
+        self.assertEqual(env_utils.get_package_spec_with_supported_ops_only(r), r)
+
+        r = requirements.Requirement("python-package==1.0.*")
+        self.assertEqual(env_utils.get_package_spec_with_supported_ops_only(r), r)
+
+        r = requirements.Requirement("python-package>=1.0")
+        self.assertEqual(env_utils.get_package_spec_with_supported_ops_only(r), r)
+
+        r = requirements.Requirement("python-package<=1.0")
+        self.assertEqual(env_utils.get_package_spec_with_supported_ops_only(r), r)
+
+        r = requirements.Requirement("python-package>1.0")
+        self.assertEqual(env_utils.get_package_spec_with_supported_ops_only(r), r)
+
+        r = requirements.Requirement("python-package<1.0")
+        self.assertEqual(env_utils.get_package_spec_with_supported_ops_only(r), r)
+
+        r = requirements.Requirement("python-package>=1.0,<2,!=1.1")
+        self.assertEqual(
+            env_utils.get_package_spec_with_supported_ops_only(r), requirements.Requirement("python-package>=1.0,<2")
+        )
+
+        r = requirements.Requirement("python-package==1.0.1, !=1.0.2")
+        self.assertEqual(
+            env_utils.get_package_spec_with_supported_ops_only(r), requirements.Requirement("python-package==1.0.1")
+        )
+
+        r = requirements.Requirement("python-package[extra]>=1.0,<2,!=1.1")
+        self.assertEqual(
+            env_utils.get_package_spec_with_supported_ops_only(r),
+            requirements.Requirement("python-package[extra]>=1.0,<2"),
+        )
+
+        r = requirements.Requirement("python-package!=1.0.1")
+        self.assertEqual(
+            env_utils.get_package_spec_with_supported_ops_only(r), requirements.Requirement("python-package")
+        )
+
+        r = requirements.Requirement("python-package")
+        self.assertEqual(
+            env_utils.get_package_spec_with_supported_ops_only(r), requirements.Requirement("python-package")
+        )
+        self.assertIsNot(env_utils.get_package_spec_with_supported_ops_only(r), r)
+
     def test_relax_requirement_version(self) -> None:
         r = requirements.Requirement("python-package==1.0.1")
         self.assertEqual(env_utils.relax_requirement_version(r), requirements.Requirement("python-package>=1.0,<2"))
diff --git a/snowflake/ml/_internal/exceptions/error_codes.py b/snowflake/ml/_internal/exceptions/error_codes.py
index d2cc7407..02c276a7 100644
--- a/snowflake/ml/_internal/exceptions/error_codes.py
+++ b/snowflake/ml/_internal/exceptions/error_codes.py
@@ -90,6 +90,8 @@ SNOWML_INVALID_STAGE = "2210"
 # Invalid query caused by syntax error, invalid source, etc.
 SNOWML_INVALID_QUERY = "2211"
+# Indicates that an error was encountered while attempting to deserialize a resource (SFFileSystem).
+SNOWML_DESERIALIZATION_FAILED = "2212"

 # Invalid Snowpark Session (Missing information) in Snowpark Session that is required.
INVALID_SNOWPARK_SESSION = "2301" diff --git a/snowflake/ml/feature_store/feature_store.py b/snowflake/ml/feature_store/feature_store.py index 2517dce3..5745be16 100644 --- a/snowflake/ml/feature_store/feature_store.py +++ b/snowflake/ml/feature_store/feature_store.py @@ -633,45 +633,20 @@ def resume_feature_view(self, feature_view: FeatureView) -> FeatureView: Returns: A new feature view with updated status. - - Raises: - SnowflakeMLException: [ValueError] FeatureView is not in suspended status. - SnowflakeMLException: [RuntimeError] Failed to update feature view status. """ - if feature_view.status != FeatureViewStatus.SUSPENDED: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.SNOWML_UPDATE_FAILED, - original_exception=ValueError( - f"FeatureView {feature_view.name}/{feature_view.version} is not in suspended status. " - f"Actual status: {feature_view.status}" - ), - ) - return self._update_feature_view_status(feature_view, "RESUME") @dispatch_decorator(prpr_version="1.0.8") def suspend_feature_view(self, feature_view: FeatureView) -> FeatureView: """ - Suspend a running FeatureView. + Suspend an active FeatureView. Args: feature_view: FeatureView to suspend. Returns: A new feature view with updated status. - - Raises: - SnowflakeMLException: [ValueError] FeatureView is not in running status. - SnowflakeMLException: [RuntimeError] Failed to update feature view status. """ - if feature_view.status != FeatureViewStatus.RUNNING: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.SNOWML_UPDATE_FAILED, - original_exception=ValueError( - f"FeatureView {feature_view.name}/{feature_view.version} is not in running status. " - f"Actual status: {feature_view.status}" - ), - ) return self._update_feature_view_status(feature_view, "SUSPEND") @dispatch_decorator(prpr_version="1.0.8") @@ -1272,7 +1247,7 @@ def _join_features( query = f""" SELECT l_{layer}.*, - r_{layer}.* EXCLUDE {join_keys_str} + r_{layer}.* EXCLUDE ({join_keys_str}) FROM ({query}) l_{layer} LEFT JOIN ( SELECT {join_keys_str}, {', '.join(cols)} diff --git a/snowflake/ml/feature_store/feature_view.py b/snowflake/ml/feature_store/feature_view.py index 1f4b8b67..33a8f513 100644 --- a/snowflake/ml/feature_store/feature_view.py +++ b/snowflake/ml/feature_store/feature_view.py @@ -50,8 +50,9 @@ def __init__(self, version: str) -> None: class FeatureViewStatus(Enum): DRAFT = "DRAFT" STATIC = "STATIC" - RUNNING = "RUNNING" + RUNNING = "RUNNING" # This can be deprecated after BCR 2024_02 gets fully deployed SUSPENDED = "SUSPENDED" + ACTIVE = "ACTIVE" @dataclass(frozen=True) diff --git a/snowflake/ml/fileset/BUILD.bazel b/snowflake/ml/fileset/BUILD.bazel index ac5e77e5..99c5340a 100644 --- a/snowflake/ml/fileset/BUILD.bazel +++ b/snowflake/ml/fileset/BUILD.bazel @@ -30,6 +30,7 @@ py_library( ":stage_fs", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/utils:connection_params", ], ) diff --git a/snowflake/ml/fileset/sfcfs.py b/snowflake/ml/fileset/sfcfs.py index bcdee678..08001e16 100644 --- a/snowflake/ml/fileset/sfcfs.py +++ b/snowflake/ml/fileset/sfcfs.py @@ -1,6 +1,7 @@ import collections import logging -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast import fsspec @@ -13,12 +14,16 @@ ) from snowflake.ml._internal.utils import identifier from snowflake.ml.fileset import stage_fs +from snowflake.ml.utils import 
connection_params PROTOCOL_NAME = "sfc" _SFFilePath = collections.namedtuple("_SFFilePath", ["database", "schema", "stage", "filepath"]) _PROJECT = "FileSet" +_FILESYSTEM_KWARGS_KEY = "kwargs" +_RECREATE_FROM_SERIALIZED = "recreate_from_serialized" + class SFFileSystem(fsspec.AbstractFileSystem): """A filesystem that allows user to access Snowflake stages and stage files with valid Snowflake locations. @@ -73,7 +78,17 @@ def __init__( Raises: ValueError: An error occurred when not exactly one of sf_connection and snowpark_session is given. + SnowflakeMLException: A failure was encountered while recreating the SFFileSystem from a serialized state. """ + if kwargs.get(_RECREATE_FROM_SERIALIZED): + try: + snowpark_session = self._create_default_session() + except Exception as e: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.SNOWML_DESERIALIZATION_FAILED, + original_exception=ValueError("Unable to deserialize SFFileSystem."), + ) from e + if sf_connection: self._conn = sf_connection elif snowpark_session: @@ -85,6 +100,62 @@ def __init__( super().__init__(**kwargs) + def _create_default_session(self) -> snowpark.Session: + """Create a Snowpark Session from default login options. + + Returns: + An active Snowpark Session. + + Raises: + ValueError: Snowflake Login Options could not be retrieved from default locations. + ValueError: Snowflake Connection could not be created. + + """ + try: + snowflake_config = connection_params.SnowflakeLoginOptions() + except Exception as e: + raise ValueError("Unable to retrieve Snowflake Login Options.") from e + + try: + session = snowpark.Session.builder.configs(snowflake_config).create() + except Exception as e: + raise ValueError("Unable to create Snowflake connection.") from e + + assert isinstance(session, snowpark.Session) + return session + + def __reduce__(self) -> Tuple[Callable[[], Type["SFFileSystem"]], Tuple[()], Dict[str, Any]]: + """Returns the reconstruction tuple used by pickle for serialization. + + Returns: + A tuple that is used for recreating the SFFileSystem. For more information, refer to + https://docs.python.org/3/library/pickle.html#object.__reduce__ + A `partial` is used to generate a callable that accepts kwargs. + """ + state_dictionary = {_FILESYSTEM_KWARGS_KEY: self._kwargs} + + return partial(self.__class__, **{_RECREATE_FROM_SERIALIZED: True}), (), state_dictionary + + def __setstate__(self, state_dict: Dict[str, Any]) -> None: + """Sets the dictionary state at deserialization time, and rebuilds a Snowflake connection. + + Args: + state_dict: State dictionary saved at serialization time. + + Raises: + KeyError: The kwargs key is not present in the state dictionary. + ValueError: The value corresponding to the kwargs key is not a dictionary. + + """ + if _FILESYSTEM_KWARGS_KEY not in state_dict: + raise KeyError(f"Serialized state dictionary missing key {_FILESYSTEM_KWARGS_KEY}.") + + kwargs_dict = state_dict.get(_FILESYSTEM_KWARGS_KEY) + if not isinstance(kwargs_dict, dict): + raise ValueError(f"The value corresponding to {_FILESYSTEM_KWARGS_KEY} is not a dictionary.") + + self._kwargs = kwargs_dict + def _get_stage_fs(self, sf_file_path: _SFFilePath) -> stage_fs.SFStageFileSystem: """Get the stage file system for the given snowflake location.
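Taken together, these hooks make SFFileSystem picklable: `__reduce__` persists only the constructor kwargs (never the live connection), and on unpickle `__init__` runs with `recreate_from_serialized=True` and rebuilds a session via `_create_default_session()`. A minimal round-trip sketch, assuming default Snowflake login options are resolvable by `connection_params.SnowflakeLoginOptions()` and that `session` is an existing `snowpark.Session` created elsewhere:

import pickle

from snowflake.ml.fileset import sfcfs

fs = sfcfs.SFFileSystem(snowpark_session=session)  # `session` is assumed to exist

# __reduce__ stores only the constructor kwargs; the connection itself is not pickled.
payload = pickle.dumps(fs)

# On load, __init__ receives recreate_from_serialized=True and builds a fresh session
# from default login options; __setstate__ then restores the saved kwargs.
restored = pickle.loads(payload)
assert restored._conn is not None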
diff --git a/snowflake/ml/fileset/sfcfs_test.py b/snowflake/ml/fileset/sfcfs_test.py index 45642b01..14282706 100644 --- a/snowflake/ml/fileset/sfcfs_test.py +++ b/snowflake/ml/fileset/sfcfs_test.py @@ -1,3 +1,5 @@ +import pickle + import fsspec from absl.testing import absltest @@ -238,6 +240,44 @@ def test_open(self) -> None: "nytrain/1.txt", mode="rb", block_size=None, autocommit=True, cache_options=None ) + def test_fs_serializability(self) -> None: + """Test if an object of Snowflake FS can be serialized using pickle.""" + + kwargs_dict = {"key1": "val1", "key2": "val2"} + sffs = sfcfs.SFFileSystem(sf_connection=self.mock_connection, snowpark_session=None, **kwargs_dict) + + pickled_data = pickle.dumps(sffs) + sffs_deserialized = pickle.loads(pickled_data) + assert sffs_deserialized._conn is not None + assert sffs_deserialized._kwargs == kwargs_dict + + def test_create_default_session_exceptions(self) -> None: + """Tests that correct exceptions are raised when the function fails to create a session. + Mocks the two session creation functions called by _create_default_session individually. + """ + sffs = sfcfs.SFFileSystem(sf_connection=self.mock_connection) + with self.assertRaises(ValueError): + with absltest.mock.patch( + "snowflake.ml.fileset.sfcfs.connection_params.SnowflakeLoginOptions", + side_effect=Exception("Error message"), + ): + sffs._create_default_session() + + with self.assertRaises(ValueError): + with absltest.mock.patch( + "snowflake.snowpark.Session.SessionBuilder.create", side_effect=Exception("Error message") + ): + sffs._create_default_session() + + def test_set_state_bad_state_dict(self) -> None: + """When deserializing, the state dictionary requires a kwargs key that corresponds to a dictionary.""" + sffs = sfcfs.SFFileSystem(sf_connection=self.mock_connection) + with self.assertRaises(KeyError): + sffs.__setstate__(state_dict={"bad_key": 2}) + + with self.assertRaises(ValueError): + sffs.__setstate__(state_dict={"kwargs": "not_a_dict"}) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_api.py b/snowflake/ml/model/_api.py index f69bfb34..28da6d3e 100644 --- a/snowflake/ml/model/_api.py +++ b/snowflake/ml/model/_api.py @@ -491,7 +491,9 @@ def predict( keep_order = True output_with_input_features = False df = model_signature._convert_and_validate_local_data(X, sig.inputs) - s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(session, df, keep_order=keep_order) + s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( + session, df, keep_order=keep_order, features=sig.inputs + ) else: keep_order = False output_with_input_features = True diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index 9e38a900..2b9035bb 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -283,6 +283,7 @@ def run( X: Union[pd.DataFrame, dataframe.DataFrame], *, function_name: Optional[str] = None, + strict_input_validation: bool = False, ) -> Union[pd.DataFrame, dataframe.DataFrame]: """Invoke a method in a model version object. @@ -290,6 +291,8 @@ def run( X: The input data, which could be a pandas DataFrame or Snowpark DataFrame. function_name: The function name to run. It is the name used to call a function in SQL. Defaults to None. It can only be None if there is only 1 method. + strict_input_validation: Enable stricter validation for the input data.
This will result in value range based + type validation to make sure your input data won't overflow when provided to the model. Raises: ValueError: When no method with the corresponding name is available. @@ -331,5 +334,6 @@ def run( X=X, model_name=self._model_name, version_name=self._version_name, + strict_input_validation=strict_input_validation, statement_params=statement_params, ) diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index 100d8921..ddf9cb31 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -334,6 +334,7 @@ def test_run(self) -> None: X=m_df, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + strict_input_validation=False, statement_params=mock.ANY, ) @@ -350,6 +351,7 @@ def test_run(self) -> None: X=m_df, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + strict_input_validation=False, statement_params=mock.ANY, ) @@ -376,6 +378,34 @@ def test_run_without_method_name(self) -> None: X=m_df, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + strict_input_validation=False, + statement_params=mock.ANY, + ) + + def test_run_strict(self) -> None: + m_df = mock_data_frame.MockDataFrame() + m_methods = [ + { + "name": '"predict"', + "target_method": "predict", + "signature": _DUMMY_SIG["predict"], + }, + ] + + with mock.patch.object( + self.m_mv, "show_functions", return_value=m_methods + ) as mock_list_methods, mock.patch.object( + self.m_mv._model_ops, "invoke_method", return_value=m_df + ) as mock_invoke_method: + self.m_mv.run(m_df, strict_input_validation=True) + mock_list_methods.assert_called_once_with() + mock_invoke_method.assert_called_once_with( + method_name='"predict"', + signature=_DUMMY_SIG["predict"], + X=m_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + strict_input_validation=True, + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=mock.ANY, ) diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index f06dde43..7d569239 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -382,6 +382,7 @@ def invoke_method( X: Union[type_hints.SupportedDataType, dataframe.DataFrame], model_name: sql_identifier.SqlIdentifier, version_name: sql_identifier.SqlIdentifier, + strict_input_validation: bool = False, statement_params: Optional[Dict[str, str]] = None, ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]: identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED @@ -390,12 +391,16 @@ def invoke_method( if not isinstance(X, dataframe.DataFrame): keep_order = True output_with_input_features = False - df = model_signature._convert_and_validate_local_data(X, signature.inputs) - s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, df, keep_order=keep_order) + df = model_signature._convert_and_validate_local_data(X, signature.inputs, strict=strict_input_validation) + s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( + self._session, df, keep_order=keep_order, features=signature.inputs + ) else: keep_order = False output_with_input_features = True - identifier_rule =
model_signature._validate_snowpark_data(X, signature.inputs) + identifier_rule = model_signature._validate_snowpark_data( + X, signature.inputs, strict=strict_input_validation + ) s_df = X original_cols = s_df.columns diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index 2c7709e2..7cf886c3 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -658,7 +658,9 @@ def test_invoke_method_1(self) -> None: version_name=sql_identifier.SqlIdentifier("V1"), statement_params=self.m_statement_params, ) - mock_convert_from_df.assert_called_once_with(self.c_session, mock.ANY, keep_order=True) + mock_convert_from_df.assert_called_once_with( + self.c_session, mock.ANY, keep_order=True, features=m_sig.inputs + ) mock_invoke_method.assert_called_once_with( method_name=sql_identifier.SqlIdentifier("PREDICT"), input_df=m_df, @@ -692,7 +694,9 @@ def test_invoke_method_1_no_drop(self) -> None: version_name=sql_identifier.SqlIdentifier("V1"), statement_params=self.m_statement_params, ) - mock_convert_from_df.assert_called_once_with(self.c_session, mock.ANY, keep_order=True) + mock_convert_from_df.assert_called_once_with( + self.c_session, mock.ANY, keep_order=True, features=m_sig.inputs + ) mock_invoke_method.assert_called_once_with( method_name=sql_identifier.SqlIdentifier("PREDICT"), input_df=m_df, @@ -726,7 +730,43 @@ def test_invoke_method_2(self) -> None: statement_params=self.m_statement_params, ) mock_convert_from_df.assert_not_called() - mock_validate_snowpark_data.assert_called_once_with(m_df, m_sig.inputs) + mock_validate_snowpark_data.assert_called_once_with(m_df, m_sig.inputs, strict=False) + + mock_invoke_method.assert_called_once_with( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=m_df, + input_args=["INPUT"], + returns=[("output", spt.FloatType(), "OUTPUT")], + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_to_df.assert_not_called() + + def test_invoke_method_3(self) -> None: + m_sig = _DUMMY_SIG["predict"] + m_df = mock_data_frame.MockDataFrame() + m_df.__setattr__("columns", ["COL1", "COL2"]) + with mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_from_df" + ) as mock_convert_from_df, mock.patch.object( + model_signature, "_validate_snowpark_data", return_value=model_signature.SnowparkIdentifierRule.NORMALIZED + ) as mock_validate_snowpark_data, mock.patch.object( + self.m_ops._model_version_client, "invoke_method", return_value=m_df + ) as mock_invoke_method, mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_to_df" + ) as mock_convert_to_df: + self.m_ops.invoke_method( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + signature=m_sig, + X=cast(DataFrame, m_df), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + strict_input_validation=True, + statement_params=self.m_statement_params, + ) + mock_convert_from_df.assert_not_called() + mock_validate_snowpark_data.assert_called_once_with(m_df, m_sig.inputs, strict=True) mock_invoke_method.assert_called_once_with( method_name=sql_identifier.SqlIdentifier("PREDICT"), diff --git a/snowflake/ml/model/_client/sql/model_version.py b/snowflake/ml/model/_client/sql/model_version.py index 6e64d07c..a25ff9e2 100644 --- a/snowflake/ml/model/_client/sql/model_version.py +++ 
b/snowflake/ml/model/_client/sql/model_version.py @@ -111,11 +111,17 @@ def get_file( local_location = target_path.resolve().as_posix() local_location_url = f"file://{local_location}" - query_result_checker.SqlResultValidator( - self._session, - f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}", - statement_params=statement_params, - ).has_dimensions(expected_rows=1).validate() + if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] + options = {"parallel": 10} + cursor = self._session._conn._cursor + cursor._download(stage_location_url, str(target_path), options) # type: ignore[attr-defined] + cursor.fetchall() + else: + query_result_checker.SqlResultValidator( + self._session, + f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}", + statement_params=statement_params, + ).has_dimensions(expected_rows=1).validate() return target_path / file_path.name def set_comment( diff --git a/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template b/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template index cafb6d89..32c7ba64 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +++ b/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template @@ -3,27 +3,8 @@ FROM ${base_image} as build COPY ${model_env_folder}/conda.yml conda.yml COPY ${model_env_folder}/requirements.txt requirements.txt - -# Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time. -ARG MAMBA_DOCKERFILE_ACTIVATE=1 - -# Bitsandbytes uses this ENVVAR to determine CUDA library location -ENV CONDA_PREFIX=/opt/conda - -# The micromamba image comes with an empty environment named base. -# CONDA_OVERRIDE_CUDA ref https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-virtual.html -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="${cuda_override_env}" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy - COPY ${inference_server_dir} ./${inference_server_dir} COPY ${entrypoint_script} ./${entrypoint_script} -${copy_model_statement} - -${extra_env_statement} - USER root RUN if id mambauser >/dev/null 2>&1; then \ @@ -40,9 +21,31 @@ RUN if id mambauser >/dev/null 2>&1; then \ --home $HOME \ $USER; \ fi + +RUN chmod +rx conda.yml +RUN chmod +rx requirements.txt RUN chmod +x ./${entrypoint_script} + USER mambauser +# Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time. +ARG MAMBA_DOCKERFILE_ACTIVATE=1 + +# Bitsandbytes uses this ENVVAR to determine CUDA library location +ENV CONDA_PREFIX=/opt/conda + +# The micromamba image comes with an empty environment named base. +# CONDA_OVERRIDE_CUDA ref https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-virtual.html +RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="${cuda_override_env}" \ + micromamba install -y -n base -f conda.yml && \ + python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ + python -m pip install -r requirements.txt && \ + micromamba clean -afy + +${copy_model_statement} + +${extra_env_statement} + # Expose the port on which the Starlette app will run. 
EXPOSE 5000 diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture index 49cef93d..ab6a9ecd 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture +++ b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture @@ -3,21 +3,9 @@ FROM mambaorg/micromamba:1.4.3 as build COPY env/conda.yml conda.yml COPY env/requirements.txt requirements.txt -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -ENV CONDA_PREFIX=/opt/conda -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy - COPY inference_server ./inference_server COPY gunicorn_run.sh ./gunicorn_run.sh - - - - USER root RUN if id mambauser >/dev/null 2>&1; then \ echo "mambauser already exists."; \ @@ -32,8 +20,19 @@ RUN if id mambauser >/dev/null 2>&1; then \ --home $HOME \ $USER; \ fi + +RUN chmod +rx conda.yml +RUN chmod +rx requirements.txt RUN chmod +x ./gunicorn_run.sh + USER mambauser +ARG MAMBA_DOCKERFILE_ACTIVATE=1 +ENV CONDA_PREFIX=/opt/conda +RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="" \ + micromamba install -y -n base -f conda.yml && \ + python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ + python -m pip install -r requirements.txt && \ + micromamba clean -afy EXPOSE 5000 CMD ["./gunicorn_run.sh"] diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA index c17d8329..7eb9358f 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA +++ b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA @@ -3,21 +3,9 @@ FROM mambaorg/micromamba:1.4.3 as build COPY env/conda.yml conda.yml COPY env/requirements.txt requirements.txt -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -ENV CONDA_PREFIX=/opt/conda -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy - COPY inference_server ./inference_server COPY gunicorn_run.sh ./gunicorn_run.sh - - - - USER root RUN if id mambauser >/dev/null 2>&1; then \ echo "mambauser already exists."; \ @@ -32,8 +20,19 @@ RUN if id mambauser >/dev/null 2>&1; then \ --home $HOME \ $USER; \ fi + +RUN chmod +rx conda.yml +RUN chmod +rx requirements.txt RUN chmod +x ./gunicorn_run.sh + USER mambauser +ARG MAMBA_DOCKERFILE_ACTIVATE=1 +ENV CONDA_PREFIX=/opt/conda +RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ + micromamba install -y -n base -f conda.yml && \ + python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ + python -m pip install -r requirements.txt && \ + micromamba clean -afy EXPOSE 5000 CMD ["./gunicorn_run.sh"] diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model index 1ace1d46..a73c5173 100644 --- 
a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model +++ b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model @@ -3,20 +3,8 @@ FROM mambaorg/micromamba:1.4.3 as build COPY env/conda.yml conda.yml COPY env/requirements.txt requirements.txt -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -ENV CONDA_PREFIX=/opt/conda -RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ - micromamba install -y -n base -f conda.yml && \ - python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ - python -m pip install -r requirements.txt && \ - micromamba clean -afy - COPY inference_server ./inference_server COPY gunicorn_run.sh ./gunicorn_run.sh -COPY model.zip ./model_repo/model.zip - -ENV MODEL_ZIP_STAGE_PATH=model_repo/model.zip - USER root RUN if id mambauser >/dev/null 2>&1; then \ @@ -32,8 +20,23 @@ RUN if id mambauser >/dev/null 2>&1; then \ --home $HOME \ $USER; \ fi + +RUN chmod +rx conda.yml +RUN chmod +rx requirements.txt RUN chmod +x ./gunicorn_run.sh + USER mambauser +ARG MAMBA_DOCKERFILE_ACTIVATE=1 +ENV CONDA_PREFIX=/opt/conda +RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ + micromamba install -y -n base -f conda.yml && \ + python -m pip install "uvicorn[standard]" gunicorn starlette==0.30.0 && \ + python -m pip install -r requirements.txt && \ + micromamba clean -afy + +COPY model.zip ./model_repo/model.zip + +ENV MODEL_ZIP_STAGE_PATH=model_repo/model.zip EXPOSE 5000 CMD ["./gunicorn_run.sh"] diff --git a/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py b/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py index 9406ddae..f9f3164c 100644 --- a/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py +++ b/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py @@ -91,7 +91,10 @@ def test_get_model_final_packages(self) -> None: c_session = cast(session.Session, self.m_session) final_packages = deploy._get_model_final_packages(meta, c_session) - self.assertListEqual(final_packages, list(map(str, _BASIC_DEPENDENCIES_FINAL_PACKAGES))) + self.assertListEqual( + final_packages, + list(map(str, map(env_utils.relax_requirement_version, _BASIC_DEPENDENCIES_FINAL_PACKAGES))), + ) def test_get_model_final_packages_no_relax(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: diff --git a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel index bef4b9c7..1bdb7151 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel @@ -9,6 +9,7 @@ filegroup( "fixtures/MANIFEST_1.yml", "fixtures/MANIFEST_2.yml", "fixtures/MANIFEST_3.yml", + "fixtures/MANIFEST_4.yml", ], ) diff --git a/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_4.yml b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_4.yml new file mode 100644 index 00000000..37941023 --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_4.yml @@ -0,0 +1,73 @@ +manifest_version: '1.0' +methods: +- handler: functions.predict.infer + inputs: + - name: INPUT_1 + type: FLOAT + - name: INPUT_2 + type: ARRAY + - name: INPUT_3 + type: ARRAY + - name: INPUT_4 + type: ARRAY + name: PREDICT + outputs: + - name: OUTPUT_1 + type: FLOAT + - name: OUTPUT_2 + type: ARRAY + - name: OUTPUT_3 + type: ARRAY + - name: OUTPUT_4 + type: ARRAY + runtime: python_runtime + type: 
TABLE_FUNCTION +runtimes: + python_runtime: + dependencies: + conda: runtimes/python_runtime/env/conda.yml + imports: + - model.zip + - runtimes/python_runtime/snowflake-ml-python.zip + language: PYTHON + version: '3.8' +user_data: + snowpark_ml_data: + functions: + - name: PREDICT + signature: + inputs: + - name: input_1 + type: FLOAT + - name: input_2 + shape: + - -1 + type: FLOAT + - name: input_3 + shape: + - -1 + type: FLOAT + - name: input_4 + shape: + - -1 + type: FLOAT + outputs: + - name: output_1 + type: FLOAT + - name: output_2 + shape: + - 2 + - 2 + type: FLOAT + - name: output_3 + shape: + - 2 + - 2 + type: FLOAT + - name: output_4 + shape: + - 2 + - 2 + type: FLOAT + target_method: predict + schema_version: '2024-02-01' diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py index 8c303be0..df8ea8c4 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py @@ -1,6 +1,6 @@ # This files contains schema definition of what will be written into MANIFEST.yml -from typing import Any, Dict, List, Literal, TypedDict +from typing import Any, Dict, List, Literal, TypedDict, Union from typing_extensions import NotRequired, Required @@ -34,10 +34,10 @@ class ModelMethodSignatureFieldWithName(ModelMethodSignatureField): class ModelFunctionMethodDict(TypedDict): name: Required[str] runtime: Required[str] - type: Required[Literal["FUNCTION"]] + type: Required[str] handler: Required[str] inputs: Required[List[ModelMethodSignatureFieldWithName]] - outputs: Required[List[ModelMethodSignatureField]] + outputs: Required[Union[List[ModelMethodSignatureField], List[ModelMethodSignatureFieldWithName]]] ModelMethodDict = ModelFunctionMethodDict diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py index 24409fe1..008a540d 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py @@ -208,6 +208,53 @@ def test_model_manifest_bad(self) -> None: pathlib.PurePosixPath("model.zip"), ) + def test_model_manifest_table_function(self) -> None: + with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: + mm = model_manifest.ModelManifest(pathlib.Path(workspace)) + with model_meta.create_model_metadata( + model_dir_path=tmpdir, + name="model1", + model_type="custom", + signatures={"predict": _DUMMY_SIG["predict"]}, + python_version="3.8", + ) as meta: + meta.models["model1"] = _DUMMY_BLOB + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): + mm.save( + self.m_session, + meta, + pathlib.PurePosixPath("model.zip"), + options=type_hints.BaseModelSaveOption( + method_options={ + "predict": type_hints.ModelMethodSaveOptions(function_type="TABLE_FUNCTION") + } + ), + ) + with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_manifest") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("MANIFEST_4.yml") + .read_text() + ), + f.read(), + ) + with open(pathlib.Path(workspace, "functions", "predict.py"), encoding="utf-8") as f: + 
self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_method") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("function_3.py") + .read_text() + ), + f.read(), + ) + def test_load(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: with open(os.path.join(tmpdir, "MANIFEST.yml"), "w", encoding="utf-8") as f: diff --git a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel index e5bcc0b2..a7e58813 100644 --- a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel @@ -7,6 +7,7 @@ filegroup( srcs = [ "fixtures/function_1.py", "fixtures/function_2.py", + "fixtures/function_3.py", ], ) @@ -15,6 +16,7 @@ py_library( srcs = ["function_generator.py"], data = [ "infer_function.py_template", + "infer_table_function.py_template", ], deps = [ "//snowflake/ml/model:type_hints", @@ -29,6 +31,7 @@ py_test( ], deps = [ ":function_generator", + ":model_method", ], ) diff --git a/snowflake/ml/model/_model_composer/model_method/fixtures/function_3.py b/snowflake/ml/model/_model_composer/model_method/fixtures/function_3.py new file mode 100644 index 00000000..69cbe908 --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_method/fixtures/function_3.py @@ -0,0 +1,76 @@ +import fcntl +import functools +import inspect +import os +import sys +import threading +import zipfile +from types import TracebackType +from typing import Optional, Type + +import anyio +import pandas as pd +from _snowflake import vectorized + +from snowflake.ml.model._packager import model_packager + + +class FileLock: + def __enter__(self) -> None: + self._lock = threading.Lock() + self._lock.acquire() + self._fd = open("/tmp/lockfile.LOCK", "w+") + fcntl.lockf(self._fd, fcntl.LOCK_EX) + + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> None: + self._fd.close() + self._lock.release() + + +# User-defined parameters +MODEL_FILE_NAME = "model.zip" +TARGET_METHOD = "predict" +MAX_BATCH_SIZE = None + + +# Retrieve the model +IMPORT_DIRECTORY_NAME = "snowflake_import_directory" +import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME] + +model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0] +zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME) +extracted = "/tmp/models" +extracted_model_dir_path = os.path.join(extracted, model_dir_name) + +with FileLock(): + if not os.path.isdir(extracted_model_dir_path): + with zipfile.ZipFile(zip_model_path, "r") as myzip: + myzip.extractall(extracted_model_dir_path) + +# Load the model +pk = model_packager.ModelPackager(extracted_model_dir_path) +pk.load(as_custom_model=True) +assert pk.model, "model is not loaded" +assert pk.meta, "model metadata is not loaded" + +# Determine the actual runner +model = pk.model +meta = pk.meta +func = getattr(model, TARGET_METHOD) +if inspect.iscoroutinefunction(func): + runner = functools.partial(anyio.run, func) +else: + runner = functools.partial(func) + +# Determine preprocess parameters +features = meta.signatures[TARGET_METHOD].inputs +input_cols = [feature.name for feature in features] +dtype_map = {feature.name: feature.as_dtype() for feature in features} + + +# Actual table function +class infer: + def end_partition(df: pd.DataFrame) -> pd.DataFrame: + return runner(df) diff --git a/snowflake/ml/model/_model_composer/model_method/function_generator.py 
b/snowflake/ml/model/_model_composer/model_method/function_generator.py index 1cc23e22..b75cf7f9 100644 --- a/snowflake/ml/model/_model_composer/model_method/function_generator.py +++ b/snowflake/ml/model/_model_composer/model_method/function_generator.py @@ -8,13 +8,17 @@ class FunctionGenerateOptions(TypedDict): max_batch_size: NotRequired[Optional[int]] + function_type: NotRequired[str] def get_function_generate_options_from_options( options: type_hints.ModelSaveOption, target_method: str ) -> FunctionGenerateOptions: - method_option = options.get("method_options", {}).get(target_method, {}) - return FunctionGenerateOptions(max_batch_size=method_option.get("max_batch_size", None)) + method_options = options.get("method_options", {}).get(target_method, {}) + return FunctionGenerateOptions( + max_batch_size=method_options.get("max_batch_size", None), + function_type=method_options.get("function_type", "function"), + ) class FunctionGenerator: @@ -30,15 +34,19 @@ def generate( self, function_file_path: pathlib.Path, target_method: str, + function_type: str, options: Optional[FunctionGenerateOptions] = None, ) -> None: import importlib_resources if options is None: options = {} + + template_filename = f"infer_{function_type.lower()}.py_template" + function_template = ( importlib_resources.files("snowflake.ml.model._model_composer.model_method") - .joinpath("infer_function.py_template") # type: ignore[no-untyped-call] + .joinpath(template_filename) # type: ignore[no-untyped-call] .read_text() ) diff --git a/snowflake/ml/model/_model_composer/model_method/function_generator_test.py b/snowflake/ml/model/_model_composer/model_method/function_generator_test.py index 1776963b..6f63b4f9 100644 --- a/snowflake/ml/model/_model_composer/model_method/function_generator_test.py +++ b/snowflake/ml/model/_model_composer/model_method/function_generator_test.py @@ -4,16 +4,21 @@ import importlib_resources from absl.testing import absltest -from snowflake.ml.model._model_composer.model_method import function_generator +from snowflake.ml.model._model_composer.model_method import ( + function_generator, + model_method, +) class FunctionGeneratorTest(absltest.TestCase): def test_function_generator(self) -> None: fg = function_generator.FunctionGenerator(pathlib.PurePosixPath("@a.b.c/abc/model.zip")) with tempfile.TemporaryDirectory() as tmpdir: + # Generate standard function. fg.generate( pathlib.Path(tmpdir, "handler.py"), "predict", + model_method.ModelMethodFunctionTypes.FUNCTION.value, ) with open(pathlib.Path(tmpdir, "handler.py"), encoding="utf-8") as f: self.assertEqual( @@ -25,10 +30,15 @@ def test_function_generator(self) -> None: ), f.read(), ) + + # Generate function with `__call__` and `max_batch_size`. fg.generate( pathlib.Path(tmpdir, "another_handler.py"), "__call__", - options=function_generator.FunctionGenerateOptions(max_batch_size=10), + model_method.ModelMethodFunctionTypes.FUNCTION.value, + options=function_generator.FunctionGenerateOptions( + max_batch_size=10, + ), ) with open(pathlib.Path(tmpdir, "another_handler.py"), encoding="utf-8") as f: self.assertEqual( @@ -41,6 +51,23 @@ def test_function_generator(self) -> None: f.read(), ) + # Generate table function. 
+ fg.generate( + pathlib.Path(tmpdir, "table_function_handler.py"), + "predict", + model_method.ModelMethodFunctionTypes.TABLE_FUNCTION.value, + ) + with open(pathlib.Path(tmpdir, "table_function_handler.py"), encoding="utf-8") as f: + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_method") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("function_3.py") + .read_text() + ), + f.read(), + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template b/snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template new file mode 100644 index 00000000..6cf00bba --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template @@ -0,0 +1,76 @@ +import fcntl +import functools +import inspect +import os +import sys +import threading +import zipfile +from types import TracebackType +from typing import Optional, Type + +import anyio +import pandas as pd +from _snowflake import vectorized + +from snowflake.ml.model._packager import model_packager + + +class FileLock: + def __enter__(self) -> None: + self._lock = threading.Lock() + self._lock.acquire() + self._fd = open("/tmp/lockfile.LOCK", "w+") + fcntl.lockf(self._fd, fcntl.LOCK_EX) + + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> None: + self._fd.close() + self._lock.release() + + +# User-defined parameters +MODEL_FILE_NAME = "{model_file_name}" +TARGET_METHOD = "{target_method}" +MAX_BATCH_SIZE = {max_batch_size} + + +# Retrieve the model +IMPORT_DIRECTORY_NAME = "snowflake_import_directory" +import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME] + +model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0] +zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME) +extracted = "/tmp/models" +extracted_model_dir_path = os.path.join(extracted, model_dir_name) + +with FileLock(): + if not os.path.isdir(extracted_model_dir_path): + with zipfile.ZipFile(zip_model_path, "r") as myzip: + myzip.extractall(extracted_model_dir_path) + +# Load the model +pk = model_packager.ModelPackager(extracted_model_dir_path) +pk.load(as_custom_model=True) +assert pk.model, "model is not loaded" +assert pk.meta, "model metadata is not loaded" + +# Determine the actual runner +model = pk.model +meta = pk.meta +func = getattr(model, TARGET_METHOD) +if inspect.iscoroutinefunction(func): + runner = functools.partial(anyio.run, func) +else: + runner = functools.partial(func) + +# Determine preprocess parameters +features = meta.signatures[TARGET_METHOD].inputs +input_cols = [feature.name for feature in features] +dtype_map = {{feature.name: feature.as_dtype() for feature in features}} + + +# Actual table function +class {function_name}: + def end_partition(df: pd.DataFrame) -> pd.DataFrame: + return runner(df) diff --git a/snowflake/ml/model/_model_composer/model_method/model_method.py b/snowflake/ml/model/_model_composer/model_method/model_method.py index 6d5a9b16..5f15f09a 100644 --- a/snowflake/ml/model/_model_composer/model_method/model_method.py +++ b/snowflake/ml/model/_model_composer/model_method/model_method.py @@ -1,6 +1,7 @@ import collections +import enum import pathlib -from typing import Optional, TypedDict +from typing import List, Optional, TypedDict, Union from typing_extensions import NotRequired @@ -12,20 +13,37 @@ from snowflake.snowpark._internal import type_utils +class 
ModelMethodFunctionTypes(enum.Enum): + FUNCTION = "FUNCTION" + TABLE_FUNCTION = "TABLE_FUNCTION" + + class ModelMethodOptions(TypedDict): """Options when creating model method. case_sensitive: Specify when the name of the method should be considered as case sensitive when registered to SQL. + function_type: One of `ModelMethodFunctionTypes` specifying function type. """ case_sensitive: NotRequired[bool] + function_type: NotRequired[str] def get_model_method_options_from_options( options: type_hints.ModelSaveOption, target_method: str ) -> ModelMethodOptions: method_option = options.get("method_options", {}).get(target_method, {}) - return ModelMethodOptions(case_sensitive=method_option.get("case_sensitive", False)) + + function_type = method_option.get("function_type", ModelMethodFunctionTypes.FUNCTION.value) + if function_type not in [function_type.value for function_type in ModelMethodFunctionTypes]: + raise NotImplementedError + + # TODO(TH): enforce minimum snowflake version + + return ModelMethodOptions( + case_sensitive=method_option.get("case_sensitive", False), + function_type=function_type, + ) class ModelMethod: @@ -71,6 +89,8 @@ def __init__( if self.target_method not in self.model_meta.signatures.keys(): raise ValueError(f"Target method {self.target_method} is not available in the signatures of the model.") + self.function_type = self.options.get("function_type", ModelMethodFunctionTypes.FUNCTION.value) + @staticmethod def _get_method_arg_from_feature( feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False @@ -94,6 +114,7 @@ def save( self.function_generator.generate( workspace_path / ModelMethod.FUNCTIONS_DIR_REL_PATH / f"{self.target_method}.py", self.target_method, + self.function_type, options=options, ) input_list = [ @@ -109,13 +130,25 @@ def save( "In this case, set case_sensitive as True for those methods to distinguish them." 
) + outputs: Union[ + List[model_manifest_schema.ModelMethodSignatureField], + List[model_manifest_schema.ModelMethodSignatureFieldWithName], + ] + if self.function_type == ModelMethodFunctionTypes.TABLE_FUNCTION.value: + outputs = [ + ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False)) + for ft in self.model_meta.signatures[self.target_method].outputs + ] + else: + outputs = [model_manifest_schema.ModelMethodSignatureField(type="OBJECT")] + return model_manifest_schema.ModelFunctionMethodDict( name=self.method_name.resolved(), runtime=self.runtime_name, - type="FUNCTION", + type=self.function_type, handler=".".join( [ModelMethod.FUNCTIONS_DIR_REL_PATH, self.target_method, self.function_generator.FUNCTION_NAME] ), inputs=input_list, - outputs=[model_manifest_schema.ModelMethodSignatureField(type="OBJECT")], + outputs=outputs, ) diff --git a/snowflake/ml/model/_model_composer/model_method/model_method_test.py b/snowflake/ml/model/_model_composer/model_method/model_method_test.py index c6b6e45c..fc8b8af4 100644 --- a/snowflake/ml/model/_model_composer/model_method/model_method_test.py +++ b/snowflake/ml/model/_model_composer/model_method/model_method_test.py @@ -169,6 +169,46 @@ def test_model_method(self) -> None: }, ) + with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: + with model_meta.create_model_metadata( + model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG + ) as meta: + meta.models["model1"] = _DUMMY_BLOB + + mm = model_method.ModelMethod( + meta, + "predict", + "python_runtime", + fg, + options=model_method.ModelMethodOptions( + function_type=model_method.ModelMethodFunctionTypes.TABLE_FUNCTION.value + ), + ) + method_dict = mm.save( + pathlib.Path(workspace), + ) + with open(pathlib.Path(workspace, "functions", "predict.py"), encoding="utf-8") as f: + self.assertEqual( + ( + importlib_resources.files(model_method_pkg) + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("function_3.py") + .read_text() + ), + f.read(), + ) + self.assertDictEqual( + method_dict, + { + "name": "PREDICT", + "runtime": "python_runtime", + "type": "TABLE_FUNCTION", + "handler": "functions.predict.infer", + "inputs": [{"name": "INPUT", "type": "FLOAT"}, {"name": "NAME", "type": "STRING"}], + "outputs": [{"name": "OUTPUT", "type": "FLOAT"}], + }, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py b/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py index 0f02ea73..6dce86b3 100644 --- a/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +++ b/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py @@ -11,7 +11,10 @@ from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api from snowflake.snowpark import session -_UDF_INFERENCE_DEPENDENCIES = _runtime_requirements.REQUIREMENTS +_UDF_INFERENCE_DEPENDENCIES = [ + str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r))) + for r in _runtime_requirements.REQUIREMENTS +] class ModelRuntime: diff --git a/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py b/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py index cda06a09..65135595 100644 --- a/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py +++ b/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py @@ -31,6 +31,15 @@ ) ) 
+_BASIC_DEPENDENCIES_TARGET_RELAXED = list( + sorted( + map( + lambda x: str(env_utils.relax_requirement_version(requirements.Requirement(x))), + model_runtime._UDF_INFERENCE_DEPENDENCIES, + ) + ) +) + _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML = list( sorted( list(map(lambda x: str(requirements.Requirement(x)), model_runtime._UDF_INFERENCE_DEPENDENCIES)) @@ -44,6 +53,22 @@ ) ) +_BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED = list( + sorted( + map( + lambda x: str(env_utils.relax_requirement_version(requirements.Requirement(x))), + model_runtime._UDF_INFERENCE_DEPENDENCIES + + [ + str( + env_utils.get_local_installed_version_of_pip_package( + requirements.Requirement(env_utils.SNOWPARK_ML_PKG_NAME) + ) + ) + ], + ) + ) +) + class ModelRuntimeTest(absltest.TestCase): def setUp(self) -> None: @@ -78,7 +103,9 @@ def test_model_runtime(self) -> None: with open(os.path.join(workspace, "runtimes/python_runtime/env/conda.yml"), encoding="utf-8") as f: dependencies = yaml.safe_load(f) - self.assertContainsSubset(_BASIC_DEPENDENCIES_TARGET_WITH_SNOWML, dependencies["dependencies"]) + self.assertContainsSubset( + _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED, dependencies["dependencies"] + ) def test_model_runtime_local_snowml(self) -> None: with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as workspace: @@ -121,7 +148,7 @@ def test_model_runtime_dup_basic_dep(self) -> None: conda_dependencies=["packaging"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.remove(next(filter(lambda x: x.startswith("packaging"), dep_target))) dep_target.append("packaging") dep_target.sort() @@ -151,7 +178,7 @@ def test_model_runtime_dup_basic_dep_other_channel(self) -> None: conda_dependencies=["conda-forge::packaging"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.remove(next(filter(lambda x: x.startswith("packaging"), dep_target))) dep_target.append("conda-forge::packaging") dep_target.sort() @@ -181,7 +208,7 @@ def test_model_runtime_dup_basic_dep_pip(self) -> None: pip_requirements=["packaging"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.remove(next(filter(lambda x: x.startswith("packaging"), dep_target))) dep_target.sort() @@ -210,7 +237,7 @@ def test_model_runtime_additional_conda_dep(self) -> None: conda_dependencies=["pytorch"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.append("pytorch") dep_target.sort() @@ -239,7 +266,7 @@ def test_model_runtime_additional_pip_dep(self) -> None: pip_requirements=["torch"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.sort() with mock.patch.object( @@ -268,7 +295,7 @@ def test_model_runtime_additional_dep_both(self) -> None: pip_requirements=["torch"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.append("pytorch") dep_target.sort() diff --git 
a/snowflake/ml/model/_packager/model_handlers_test/custom_test.py b/snowflake/ml/model/_packager/model_handlers_test/custom_test.py index c5716eaf..3bfa6fd7 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/custom_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/custom_test.py @@ -126,6 +126,7 @@ def test_custom_model_with_multiple_artifacts(self) -> None: model=lm, sample_input=d, metadata={"author": "halu", "version": "1"}, + options={"relax_version": False}, ) pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) diff --git a/snowflake/ml/model/_packager/model_meta/model_meta.py b/snowflake/ml/model/_packager/model_meta/model_meta.py index b2be8567..5d0aeb34 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta.py @@ -1,8 +1,8 @@ -import importlib import os import pathlib import sys import tempfile +import warnings import zipfile from contextlib import contextmanager from datetime import datetime @@ -27,8 +27,14 @@ MODEL_METADATA_FILE = "model.yaml" MODEL_CODE_DIR = "code" -_PACKAGING_CORE_DEPENDENCIES = _core_requirements.REQUIREMENTS # Legacy Model only -_PACKAGING_REQUIREMENTS = _packaging_requirements.REQUIREMENTS # New Model only +_PACKAGING_CORE_DEPENDENCIES = [ + str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r))) + for r in _core_requirements.REQUIREMENTS +] # Legacy Model only +_PACKAGING_REQUIREMENTS = [ + str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r))) + for r in _packaging_requirements.REQUIREMENTS +] # New Model only _SNOWFLAKE_PKG_NAME = "snowflake" _SNOWFLAKE_ML_PKG_NAME = f"{_SNOWFLAKE_PKG_NAME}.ml" @@ -75,7 +81,17 @@ def create_model_metadata( model_dir_path = os.path.normpath(model_dir_path) embed_local_ml_library = kwargs.pop("embed_local_ml_library", False) legacy_save = kwargs.pop("_legacy_save", False) - relax_version = kwargs.pop("relax_version", False) + if "relax_version" not in kwargs: + warnings.warn( + ( + "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed " + "from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, " + "reproducibility, etc., set `options={'relax_version': False}` when logging the model." + ), + category=UserWarning, + stacklevel=2, + ) + relax_version = kwargs.pop("relax_version", True) if embed_local_ml_library: # Use the last one which is loaded first, that is mean, it is loaded from site-packages. @@ -200,22 +216,6 @@ def load_code_path(model_dir_path: str) -> None: if code_path in sys.path: sys.path.remove(code_path) sys.path.insert(0, code_path) - module_names = file_utils.get_all_modules(code_path) - # If the module_name starts with snowflake, then do not replace it. - # When deploying, we would add them beforehand. - # When in the local, they should not be added. We already prevent user from overwriting us. 
- module_names = [ - module_name - for module_name in module_names - if not (module_name.startswith(f"{_SNOWFLAKE_PKG_NAME}.") or module_name == _SNOWFLAKE_PKG_NAME) - ] - for module_name in module_names: - actual_module = sys.modules.pop(module_name, None) - if actual_module is not None: - sys.modules[module_name] = importlib.import_module(module_name) - - assert code_path in sys.path - sys.path.remove(code_path) class ModelMetadata: diff --git a/snowflake/ml/model/_packager/model_meta/model_meta_test.py b/snowflake/ml/model/_packager/model_meta/model_meta_test.py index 30ecbbe7..e6f1cabd 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta_test.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta_test.py @@ -32,6 +32,19 @@ ) ) +_BASIC_DEPENDENCIES_TARGET_RELAXED = list( + sorted( + map( + lambda x: str( + env_utils.relax_requirement_version( + env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) + ) + ), + model_meta._PACKAGING_CORE_DEPENDENCIES, + ) + ) +) + _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML = list( sorted( map( @@ -41,6 +54,19 @@ ) ) +_BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED = list( + sorted( + map( + lambda x: str( + env_utils.relax_requirement_version( + env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) + ) + ), + model_meta._PACKAGING_CORE_DEPENDENCIES + [env_utils.SNOWPARK_ML_PKG_NAME], + ) + ) +) + _PACKAGING_REQUIREMENTS_TARGET = list( sorted( map( @@ -50,6 +76,19 @@ ) ) +_PACKAGING_REQUIREMENTS_TARGET_RELAXED = list( + sorted( + map( + lambda x: str( + env_utils.relax_requirement_version( + env_utils.get_local_installed_version_of_pip_package(requirements.Requirement(x)) + ) + ), + model_meta._PACKAGING_REQUIREMENTS, + ) + ) +) + _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML = list( sorted( map( @@ -81,13 +120,13 @@ def test_model_meta_dependencies_no_packages(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML) + self.assertListEqual(meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED) self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) loaded_meta = model_meta.ModelMetadata.load(tmpdir) self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML) + self.assertListEqual(loaded_meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED) self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) def test_model_meta_dependencies_no_packages_embedded_snowml(self) -> None: @@ -102,13 +141,13 @@ def test_model_meta_dependencies_no_packages_embedded_snowml(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET) + self.assertListEqual(meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_RELAXED) self.assertIsNotNone(meta.env._snowpark_ml_version.local) loaded_meta = model_meta.ModelMetadata.load(tmpdir) self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET) + self.assertListEqual(loaded_meta.env.conda_dependencies, _BASIC_DEPENDENCIES_TARGET_RELAXED) self.assertIsNotNone(meta.env._snowpark_ml_version.local) def test_model_meta_dependencies_dup_basic_dep(self) -> 
None: @@ -119,6 +158,7 @@ def test_model_meta_dependencies_dup_basic_dep(self) -> None: model_type="custom", signatures=_DUMMY_SIG, conda_dependencies=["cloudpickle"], + relax_version=False, _legacy_save=True, ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -144,6 +184,7 @@ def test_model_meta_dependencies_dup_basic_dep_other_channel(self) -> None: model_type="custom", signatures=_DUMMY_SIG, conda_dependencies=["conda-forge::cloudpickle"], + relax_version=False, _legacy_save=True, ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -170,6 +211,7 @@ def test_model_meta_dependencies_dup_basic_dep_pip(self) -> None: model_type="custom", signatures=_DUMMY_SIG, pip_requirements=["cloudpickle"], + relax_version=False, _legacy_save=True, ) as meta: meta.models["model1"] = _DUMMY_BLOB @@ -197,7 +239,7 @@ def test_model_meta_dependencies_conda(self) -> None: _legacy_save=True, ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.append("pytorch") dep_target.sort() @@ -220,7 +262,7 @@ def test_model_meta_dependencies_pip(self) -> None: _legacy_save=True, ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.sort() self.assertListEqual(meta.env.pip_requirements, ["torch"]) @@ -243,7 +285,7 @@ def test_model_meta_dependencies_both(self) -> None: _legacy_save=True, ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] + dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML_RELAXED[:] dep_target.append("pytorch") dep_target.sort() @@ -264,13 +306,13 @@ def test_model_meta_dependencies_no_packages(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB self.assertListEqual(meta.env.pip_requirements, []) - self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML) + self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED) self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) loaded_meta = model_meta.ModelMetadata.load(tmpdir) self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML) + self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED) self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) def test_model_meta_dependencies_relax_version(self) -> None: @@ -289,6 +331,15 @@ def test_model_meta_dependencies_relax_version(self) -> None: self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED) self.assertEqual(meta.env.snowpark_ml_version, snowml_env.VERSION) + with self.assertWarnsRegex(UserWarning, "`relax_version` is not set and therefore defaulted to True."): + with model_meta.create_model_metadata( + model_dir_path=tmpdir, + name="model1", + model_type="custom", + signatures=_DUMMY_SIG, + ) as meta: + meta.models["model1"] = _DUMMY_BLOB + def test_model_meta_dependencies_no_packages_embedded_snowml(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: with model_meta.create_model_metadata( @@ -300,13 +351,13 @@ def test_model_meta_dependencies_no_packages_embedded_snowml(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB self.assertListEqual(meta.env.pip_requirements, []) - 
self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET) + self.assertListEqual(meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_RELAXED) self.assertIsNotNone(meta.env._snowpark_ml_version.local) loaded_meta = model_meta.ModelMetadata.load(tmpdir) self.assertListEqual(loaded_meta.env.pip_requirements, []) - self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET) + self.assertListEqual(loaded_meta.env.conda_dependencies, _PACKAGING_REQUIREMENTS_TARGET_RELAXED) self.assertIsNotNone(meta.env._snowpark_ml_version.local) def test_model_meta_dependencies_dup_basic_dep(self) -> None: @@ -317,6 +368,7 @@ def test_model_meta_dependencies_dup_basic_dep(self) -> None: model_type="custom", signatures=_DUMMY_SIG, conda_dependencies=["cloudpickle"], + relax_version=False, ) as meta: meta.models["model1"] = _DUMMY_BLOB dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] @@ -341,6 +393,7 @@ def test_model_meta_dependencies_dup_basic_dep_other_channel(self) -> None: model_type="custom", signatures=_DUMMY_SIG, conda_dependencies=["conda-forge::cloudpickle"], + relax_version=False, ) as meta: meta.models["model1"] = _DUMMY_BLOB dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] @@ -366,6 +419,7 @@ def test_model_meta_dependencies_dup_basic_dep_pip(self) -> None: model_type="custom", signatures=_DUMMY_SIG, pip_requirements=["cloudpickle"], + relax_version=False, ) as meta: meta.models["model1"] = _DUMMY_BLOB dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] @@ -391,7 +445,7 @@ def test_model_meta_dependencies_conda(self) -> None: conda_dependencies=["pytorch"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] + dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED[:] dep_target.append("pytorch") dep_target.sort() @@ -413,7 +467,7 @@ def test_model_meta_dependencies_pip(self) -> None: pip_requirements=["torch"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] + dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED[:] dep_target.sort() self.assertListEqual(meta.env.pip_requirements, ["torch"]) @@ -435,7 +489,7 @@ def test_model_meta_dependencies_both(self) -> None: pip_requirements=["torch"], ) as meta: meta.models["model1"] = _DUMMY_BLOB - dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML[:] + dep_target = _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML_RELAXED[:] dep_target.append("pytorch") dep_target.sort() diff --git a/snowflake/ml/model/_packager/model_packager_test.py b/snowflake/ml/model/_packager/model_packager_test.py index a4416f33..08c6a5d8 100644 --- a/snowflake/ml/model/_packager/model_packager_test.py +++ b/snowflake/ml/model/_packager/model_packager_test.py @@ -1,4 +1,3 @@ -import importlib import os import sys import tempfile @@ -49,44 +48,6 @@ def get_file(): class ModelLoadHygieneTest(absltest.TestCase): - def test_model_load_hygiene(self) -> None: - with tempfile.TemporaryDirectory() as workspace: - with tempfile.TemporaryDirectory() as src_path: - fake_mod_dirpath = os.path.join(src_path, "fake", "fake_module") - os.makedirs(fake_mod_dirpath) - - py_file_path = os.path.join(fake_mod_dirpath, "p.py") - with open(py_file_path, "w", encoding="utf-8") as f: - f.write(PY_SRC) - f.flush() - - sys.path.insert(0, src_path) - - from fake.fake_module import p - - self.assertEqual(p.__file__, py_file_path) - - lm = DemoModel(context=custom_model.ModelContext(models={}, 
artifacts={})) - arr = np.array([[1, 2, 3], [4, 2, 5]]) - d = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - - model_packager.ModelPackager(os.path.join(workspace, "model1")).save( - name="model1", - model=lm, - sample_input=d, - metadata={"author": "halu", "version": "1"}, - code_paths=[os.path.join(src_path, "fake")], - ) - - model_packager.ModelPackager(os.path.join(workspace, "model1")).load() - from fake.fake_module import p - - self.assertEqual(p.__file__, os.path.join(workspace, "model1", "code", "fake", "fake_module", "p.py")) - - importlib.reload(p) - self.assertEqual(p.__file__, py_file_path) - sys.path.remove(src_path) - def test_model_save_validation(self) -> None: with tempfile.TemporaryDirectory() as workspace: with tempfile.TemporaryDirectory() as src_path: diff --git a/snowflake/ml/model/_signatures/snowpark_handler.py b/snowflake/ml/model/_signatures/snowpark_handler.py index c7ee6392..fc05dd7b 100644 --- a/snowflake/ml/model/_signatures/snowpark_handler.py +++ b/snowflake/ml/model/_signatures/snowpark_handler.py @@ -87,7 +87,10 @@ def convert_to_df( @staticmethod def convert_from_df( - session: snowflake.snowpark.Session, df: pd.DataFrame, keep_order: bool = False + session: snowflake.snowpark.Session, + df: pd.DataFrame, + keep_order: bool = False, + features: Optional[Sequence[core.BaseFeatureSpec]] = None, ) -> snowflake.snowpark.DataFrame: # This method is necessary to create the Snowpark Dataframe in correct schema. # However, in this case, the order could not be preserved. Thus, a _ID column has to be added, @@ -101,7 +104,8 @@ def convert_from_df( error_code=error_codes.NOT_IMPLEMENTED, original_exception=ValueError("Cannot convert a Pandas DataFrame whose column index is not a string"), ) - features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input") + if not features: + features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input") # Role will be no effect on the column index. That is to say, the feature name is the actual column name. 
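+        # Illustrative usage (mirrors the new snowpark_test.py case further down): passing explicit
+        # specs skips inference, e.g.
+        #   convert_from_df(session, df, features=[FeatureSpec("col_0", DataType.INT64)])
+        # pins the resulting Snowpark column to INT64.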
sp_df = session.create_dataframe(df) column_names = [] diff --git a/snowflake/ml/model/_signatures/snowpark_test.py b/snowflake/ml/model/_signatures/snowpark_test.py index 242efb95..4d2779a7 100644 --- a/snowflake/ml/model/_signatures/snowpark_test.py +++ b/snowflake/ml/model/_signatures/snowpark_test.py @@ -133,7 +133,9 @@ def test_validate_data_with_features(self) -> None: with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, expected_regex="Feature type [^\\s]* is not met by column" ): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts, strict=True) + + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.INT8), @@ -143,7 +145,9 @@ def test_validate_data_with_features(self) -> None: with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, expected_regex="Feature type [^\\s]* is not met by column" ): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts, strict=True) + + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.INT16), @@ -185,7 +189,9 @@ def test_validate_data_with_features(self) -> None: with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, expected_regex="Feature type [^\\s]* is not met by column" ): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts, strict=True) + + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.UINT8), @@ -252,7 +258,9 @@ def test_validate_data_with_features(self) -> None: with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, expected_regex="Feature type [^\\s]* is not met by column" ): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts, strict=True) + + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.INT64), @@ -322,6 +330,17 @@ def test_convert_to_and_from_df(self) -> None: pd_df, snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sp_df), check_dtype=False ) + pd_df = pd.DataFrame([1, 2, 3, 4], columns=["col_0"]) + sp_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( + self._session, + pd_df, + keep_order=False, + features=[model_signature.FeatureSpec(name="col_0", dtype=model_signature.DataType.INT64)], + ) + pd.testing.assert_frame_equal( + pd_df, snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sp_df), check_dtype=False + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/model_signature.py b/snowflake/ml/model/model_signature.py index 81db40d2..c1bc9877 100644 --- a/snowflake/ml/model/model_signature.py +++ b/snowflake/ml/model/model_signature.py @@ -139,7 +139,9 @@ def _rename_signature_with_snowflake_identifiers( return signature -def _validate_numpy_array(arr: model_types._SupportedNumpyArray, feature_type: core.DataType) -> bool: +def _validate_numpy_array( + arr: model_types._SupportedNumpyArray, feature_type: core.DataType, strict: bool = False +) -> bool: if feature_type in [ core.DataType.INT8, core.DataType.INT16, @@ -152,11 +154,15 @@ def _validate_numpy_array(arr: model_types._SupportedNumpyArray, feature_type: c ]: if not (np.issubdtype(arr.dtype, np.integer)): return False + if not strict: + return True min_v, max_v = arr.min(), arr.max() return bool(max_v <= 
np.iinfo(feature_type._numpy_type).max and min_v >= np.iinfo(feature_type._numpy_type).min) elif feature_type in [core.DataType.FLOAT, core.DataType.DOUBLE]: if not (np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating)): return False + if not strict: + return True min_v, max_v = arr.min(), arr.max() return bool( max_v <= np.finfo(feature_type._numpy_type).max # type: ignore[arg-type] @@ -166,12 +172,13 @@ def _validate_numpy_array(arr: model_types._SupportedNumpyArray, feature_type: c return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no") -def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec]) -> None: +def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec], strict: bool = False) -> None: """It validates pandas dataframe with provided features. Args: data: A pandas dataframe to be validated. features: A sequence of feature specifications and feature group specifications, where the dataframe should fit. + strict: Enable strict validation; this includes value-range-based validation Raises: SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported. @@ -206,7 +213,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS ft_type = feature._dtype ft_shape = feature._shape if df_col_dtype != np.dtype("O"): - if not _validate_numpy_array(data_col.to_numpy(), ft_type): + if not _validate_numpy_array(data_col.to_numpy(), ft_type, strict=strict): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( @@ -235,7 +242,10 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in data_col] - if not all(_validate_numpy_array(converted_data, ft_type) for converted_data in converted_data_list): + if not all( + _validate_numpy_array(converted_data, ft_type, strict=strict) + for converted_data in converted_data_list + ): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( @@ -264,7 +274,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS ), ) - if not all(_validate_numpy_array(data_row, ft_type) for data_row in data_col): + if not all(_validate_numpy_array(data_row, ft_type, strict=strict) for data_row in data_col): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, original_exception=ValueError( @@ -375,13 +385,14 @@ def _get_dataframe_values_range( def _validate_snowpark_data( - data: snowflake.snowpark.DataFrame, features: Sequence[core.BaseFeatureSpec] + data: snowflake.snowpark.DataFrame, features: Sequence[core.BaseFeatureSpec], strict: bool = False ) -> SnowparkIdentifierRule: """Validate Snowpark DataFrame as input. It will try to map both normalized name or inferred name. Args: data: A snowpark dataframe to be validated. features: A sequence of feature specifications and feature group specifications, where the dataframe should fit. + strict: Enable strict validation; this includes value-range-based validation. Raises: SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported.
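A minimal sketch of the lenient-vs-strict behavior these hunks introduce (the INT8 spec and the value 257 mirror the model_signature_test.py cases below; treat the exact specs as illustrative):

import pandas as pd

from snowflake.ml.model import model_signature

fts = [
    model_signature.FeatureSpec("a", model_signature.DataType.INT8),
    model_signature.FeatureSpec("b", model_signature.DataType.INT64),
]
df = pd.DataFrame([[257, 5], [6, 8]], columns=["a", "b"])

# Lenient (default): only dtype compatibility is checked, so an integer
# column containing 257 still satisfies an INT8 spec.
model_signature._validate_pandas_df(df, fts)

# Strict: value ranges are checked as well; 257 overflows INT8.
try:
    model_signature._validate_pandas_df(df, fts, strict=True)
except Exception as e:  # SnowflakeMLException wrapping a ValueError
    print(f"strict validation rejected the data: {e}")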
@@ -398,7 +409,10 @@ def _validate_snowpark_data( SnowparkIdentifierRule.NORMALIZED: [], } schema = data.schema - values_range = _get_dataframe_values_range(data) + if strict: + values_range = _get_dataframe_values_range(data) + else: + values_range = {} for identifier_rule in errors.keys(): for feature in features: try: @@ -423,7 +437,7 @@ def _validate_snowpark_data( errors[identifier_rule].append( ValueError( f"Data Validation Error in feature {feature.name}: " - + f"Feature is an array feature, while {field.name} is not." + + f"Feature is a scalar feature, while {field.name} is not." ), ) warnings.warn( @@ -436,13 +450,13 @@ errors[identifier_rule].append( ValueError( f"Data Validation Error in feature {feature.name}: " - + f"Feature is a scalar feature, while {field.name} is not." + + f"Feature is an array feature, while {field.name} is not." ), ) continue try: _validate_snowpark_type_feature( - data, field, ft_type, feature.name, values_range.get(field.name, None) + data, field, ft_type, feature.name, values_range.get(field.name, None), strict=strict ) except snowml_exceptions.SnowflakeMLException as e: errors[identifier_rule].append(e.original_exception) @@ -479,6 +493,7 @@ def _validate_snowpark_type_feature( ft_type: DataType, ft_name: str, value_range: Optional[Union[Tuple[int, int], Tuple[float, float]]], + strict: bool = False, ) -> None: field_data_type = field.datatype col_name = identifier.get_unescaped_names(field.name) @@ -505,6 +520,8 @@ def _validate_snowpark_type_feature( f"because of its original type {field_data_type}" ), ) + if not strict: + return if value_range is None: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, @@ -521,7 +538,7 @@ original_exception=ValueError( f"Data Validation Error in feature {ft_name}: " f"Feature type {ft_type} is not met by column {col_name} " - f"because it overflows with min" + f"because it overflows with min or max" ), ) elif ft_type in [core.DataType.FLOAT, core.DataType.DOUBLE]: @@ -538,15 +555,6 @@ + f"Feature type {ft_type} is not met by column {col_name}." ), ) - if value_range is None: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_DATA, - original_exception=ValueError( - f"Data Validation Error in feature {ft_name}: " - f"Feature type {ft_type} is not met by column {col_name} " - f"because of its original type {field_data_type} is non-Numeric." - ), - ) if isinstance(field_data_type, spt.DecimalType) and field_data_type.scale > 0: warnings.warn( ( @@ -558,6 +566,18 @@ category=UserWarning, stacklevel=2, ) + + if not strict: + return + if value_range is None: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_DATA, + original_exception=ValueError( + f"Data Validation Error in feature {ft_name}: " + f"Feature type {ft_type} is not met by column {col_name} " + f"because its original type {field_data_type} is non-Numeric." + ), + ) min_v, max_v = value_range if ( max_v > np.finfo(ft_type._numpy_type).max # type: ignore[arg-type] @@ -567,7 +587,8 @@ error_code=error_codes.INVALID_DATA, original_exception=ValueError( f"Data Validation Error in feature {ft_name}: " - + f"Feature type {ft_type} is not met by column {col_name}." + f"Feature type {ft_type} is not met by column {col_name}."
+ f"because it overflows with min or max" ), ) else: @@ -576,7 +597,7 @@ def _validate_snowpark_type_feature( error_code=error_codes.INVALID_DATA, original_exception=ValueError( f"Data Validation Error in feature {ft_name}: " - + f"Feature type {ft_type} is not met by column {col_name}." + f"Feature type {ft_type} is not met by column {col_name}." ), ) @@ -609,21 +630,21 @@ def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.Da def _convert_and_validate_local_data( - data: model_types.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec] + data: model_types.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False ) -> pd.DataFrame: """Validate the data with features in model signature and convert to DataFrame Args: features: A list of feature specs that the data should follow. data: The provided data. + strict: Enable strict validation. Returns: The converted dataframe with renamed column index. """ df = _convert_local_data_to_df(data) df = utils.rename_pandas_df(df, features) - _validate_pandas_df(df, features) - df = pandas_handler.PandasDataFrameHandler.convert_to_df(df, ensure_serializable=True) + _validate_pandas_df(df, features, strict=strict) return df diff --git a/snowflake/ml/model/model_signature_test.py b/snowflake/ml/model/model_signature_test.py index ae27d7c0..28377ead 100644 --- a/snowflake/ml/model/model_signature_test.py +++ b/snowflake/ml/model/model_signature_test.py @@ -135,14 +135,18 @@ def test_validate_pandas_df(self) -> None: expected_original_error_type=ValueError, expected_regex="Feature type [^\\s]* is not met by all elements", ): - model_signature._validate_pandas_df(pd.DataFrame([[257, 5], [6, 8]], columns=["a", "b"]), fts) + model_signature._validate_pandas_df(pd.DataFrame([[257, 5], [6, 8]], columns=["a", "b"]), fts, strict=True) + + model_signature._validate_pandas_df(pd.DataFrame([[257, 5], [6, 8]], columns=["a", "b"]), fts) with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, expected_regex="Feature type [^\\s]* is not met by all elements", ): - model_signature._validate_pandas_df(pd.DataFrame([[2, -5], [6, 8]], columns=["a", "b"]), fts) + model_signature._validate_pandas_df(pd.DataFrame([[2, -5], [6, 8]], columns=["a", "b"]), fts, strict=True) + + model_signature._validate_pandas_df(pd.DataFrame([[2, -5], [6, 8]], columns=["a", "b"]), fts) fts = [ model_signature.FeatureSpec("a", model_signature.DataType.INT8), @@ -201,18 +205,26 @@ def test_validate_pandas_df(self) -> None: expected_regex="Feature type [^\\s]* is not met by all elements", ): model_signature._validate_pandas_df( - pd.DataFrame([[[1, 257], [2, 6]], [[2, 3], [2, 6]]], columns=["a", "b"]), fts + pd.DataFrame([[[1, 257], [2, 6]], [[2, 3], [2, 6]]], columns=["a", "b"]), fts, strict=True ) + model_signature._validate_pandas_df( + pd.DataFrame([[[1, 257], [2, 6]], [[2, 3], [2, 6]]], columns=["a", "b"]), fts + ) + with exception_utils.assert_snowml_exceptions( self, expected_original_error_type=ValueError, expected_regex="Feature type [^\\s]* is not met by all elements", ): model_signature._validate_pandas_df( - pd.DataFrame([[[1, 2], [2, -6]], [[2, 3], [2, 6]]], columns=["a", "b"]), fts + pd.DataFrame([[[1, 2], [2, -6]], [[2, 3], [2, 6]]], columns=["a", "b"]), fts, strict=True ) + model_signature._validate_pandas_df( + pd.DataFrame([[[1, 2], [2, -6]], [[2, 3], [2, 6]]], columns=["a", "b"]), fts + ) + fts = [ model_signature.FeatureSpec("a", model_signature.DataType.INT64), 
model_signature.FeatureSpec("b", model_signature.DataType.DOUBLE, shape=(2,)), diff --git a/snowflake/ml/model/type_hints.py b/snowflake/ml/model/type_hints.py index 74dc75fe..111e41a7 100644 --- a/snowflake/ml/model/type_hints.py +++ b/snowflake/ml/model/type_hints.py @@ -201,6 +201,7 @@ class SnowparkContainerServiceDeployOptions(DeployOptions): class ModelMethodSaveOptions(TypedDict): case_sensitive: NotRequired[bool] max_batch_size: NotRequired[int] + function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]] class BaseModelSaveOption(TypedDict): diff --git a/snowflake/ml/modeling/_internal/BUILD.bazel b/snowflake/ml/modeling/_internal/BUILD.bazel index 4d9832e8..9a6b6f14 100644 --- a/snowflake/ml/modeling/_internal/BUILD.bazel +++ b/snowflake/ml/modeling/_internal/BUILD.bazel @@ -3,24 +3,50 @@ load("//bazel:py_rules.bzl", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) py_library( - name = "estimator_protocols", - srcs = ["estimator_protocols.py"], + name = "transformer_protocols", + srcs = ["transformer_protocols.py"], ) py_library( - name = "estimator_utils", - srcs = ["estimator_utils.py"], + name = "constants", + srcs = ["constants.py"], + deps = [], +) + +py_library( + name = "model_transformer_builder", + srcs = ["model_transformer_builder.py"], deps = [ + ":constants", + ":transformer_protocols", "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/modeling/_internal/local_implementations:pandas_handlers", + "//snowflake/ml/modeling/_internal/ml_runtime_implementations:ml_runtime_handlers", + "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_handlers", "//snowflake/ml/modeling/framework", ], ) py_test( - name = "estimator_protocols_test", - srcs = ["estimator_protocols_test.py"], + name = "model_transformer_builder_test", + srcs = ["model_transformer_builder_test.py"], + deps = [ + ":constants", + ":model_transformer_builder", + ":transformer_protocols", + "//snowflake/ml/modeling/_internal/local_implementations:pandas_handlers", + "//snowflake/ml/modeling/_internal/ml_runtime_implementations:ml_runtime_handlers", + "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_handlers", + "//snowflake/ml/utils:connection_params", + ], +) + +py_library( + name = "estimator_utils", + srcs = ["estimator_utils.py"], deps = [ - ":estimator_protocols", + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/modeling/framework", ], ) @@ -61,9 +87,11 @@ py_library( name = "model_trainer_builder", srcs = ["model_trainer_builder.py"], deps = [ + ":constants", ":estimator_utils", ":model_trainer", "//snowflake/ml/modeling/_internal/local_implementations:pandas_trainer", + "//snowflake/ml/modeling/_internal/ml_runtime_implementations:ml_runtime_trainer", "//snowflake/ml/modeling/_internal/snowpark_implementations:distributed_hpo_trainer", "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_trainer", "//snowflake/ml/modeling/_internal/snowpark_implementations:xgboost_external_memory_trainer", @@ -74,9 +102,11 @@ py_test( name = "model_trainer_builder_test", srcs = ["model_trainer_builder_test.py"], deps = [ + ":constants", ":model_trainer", ":model_trainer_builder", "//snowflake/ml/modeling/_internal/local_implementations:pandas_trainer", + "//snowflake/ml/modeling/_internal/ml_runtime_implementations:ml_runtime_trainer", "//snowflake/ml/modeling/_internal/snowpark_implementations:distributed_hpo_trainer", "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_trainer", 
"//snowflake/ml/modeling/_internal/snowpark_implementations:xgboost_external_memory_trainer", diff --git a/snowflake/ml/modeling/_internal/constants.py b/snowflake/ml/modeling/_internal/constants.py new file mode 100644 index 00000000..c62b1143 --- /dev/null +++ b/snowflake/ml/modeling/_internal/constants.py @@ -0,0 +1 @@ +IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME" diff --git a/snowflake/ml/modeling/_internal/estimator_protocols.py b/snowflake/ml/modeling/_internal/estimator_protocols.py deleted file mode 100644 index 2b71e1ec..00000000 --- a/snowflake/ml/modeling/_internal/estimator_protocols.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import List, Optional, Protocol - -import pandas as pd - -from snowflake.snowpark import DataFrame, Session - - -# TODO: Add more specific entities to type hint estimators instead of using `object`. -class TransformerHandlers(Protocol): - def batch_inference( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - inference_method: str, - input_cols: List[str], - pass_through_columns: List[str], - expected_output_cols_list: List[str], - expected_output_cols_type: str = "", - ) -> DataFrame: - raise NotImplementedError - - def score_pandas( - self, - dataset: pd.DataFrame, - estimator: object, - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> float: - raise NotImplementedError - - def score_snowpark( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - score_sproc_imports: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> float: - raise NotImplementedError diff --git a/snowflake/ml/modeling/_internal/estimator_protocols_test.py b/snowflake/ml/modeling/_internal/estimator_protocols_test.py deleted file mode 100644 index 592d4962..00000000 --- a/snowflake/ml/modeling/_internal/estimator_protocols_test.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Protocol - -from absl.testing import absltest, parameterized - -from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers - - -class EstimatorProtocolsTest(parameterized.TestCase): - def test_fit_predict_handlers(self) -> None: - self.assertIsInstance(TransformerHandlers, Protocol) - - def test_cv_handlers(self) -> None: - self.assertIsInstance(TransformerHandlers, Protocol) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/modeling/_internal/local_implementations/BUILD.bazel b/snowflake/ml/modeling/_internal/local_implementations/BUILD.bazel index 4295fa33..ffbb6991 100644 --- a/snowflake/ml/modeling/_internal/local_implementations/BUILD.bazel +++ b/snowflake/ml/modeling/_internal/local_implementations/BUILD.bazel @@ -9,3 +9,12 @@ py_library( "//snowflake/ml/modeling/_internal:model_trainer", ], ) + +py_library( + name = "pandas_handlers", + srcs = ["pandas_handlers.py"], + deps = [ + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/modeling/_internal:transformer_protocols", + ], +) diff --git a/snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py b/snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py new file mode 100644 index 00000000..b461b6c8 --- /dev/null +++ b/snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py @@ -0,0 +1,226 @@ +import inspect +from typing import Any, List, Optional + +import numpy as np +import pandas as pd + +from snowflake.ml._internal.exceptions import error_codes, 
exceptions + + +class PandasTransformHandlers: + """Transform (inference and scoring) functions for a pandas dataset.""" + + def __init__( + self, + dataset: pd.DataFrame, + estimator: object, + class_name: str, + subproject: str, + autogenerated: Optional[bool] = False, + ) -> None: + """ + Args: + dataset: The dataset to run transform functions on. + estimator: The estimator used to run transforms. + class_name: class name to be used in telemetry. + subproject: subproject to be used in telemetry. + autogenerated: Whether the class was autogenerated from a template. + """ + self.dataset = dataset + self.estimator = estimator + self.class_name = class_name + self.subproject = subproject + self.autogenerated = autogenerated + + def batch_inference( + self, + inference_method: str, + input_cols: List[str], + expected_output_cols: List[str], + snowpark_input_cols: Optional[List[str]] = None, + drop_input_cols: Optional[bool] = False, + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + """Run batch inference on the given dataset. + + Args: + inference_method: the name of the method used by `estimator` to run inference. + input_cols: column names of the input dataset + expected_output_cols: column names (in order) of the output dataset. + snowpark_input_cols: list of snowpark columns. + Covers the situation where training happens in Snowpark but transform runs in pandas. + drop_input_cols: If set to True, the response will not contain input columns. + args: additional positional arguments. + kwargs: additional keyword args. + + Returns: + A new dataset of the same type as the input dataset. + + Raises: + SnowflakeMLException: Mismatches between expected feature names and provided feature names. + SnowflakeMLException: expected_output_cols list length does not match required length. + """ + + output_cols = expected_output_cols.copy() + dataset = self.dataset + # The model expects exactly the same column names in the input df for the predict call.
+ # Given the scenario that the user used a Snowpark DataFrame in the fit call but a pandas DataFrame in the predict call, + # input cols need to match both unquoted and quoted forms + + if snowpark_input_cols is None: + snowpark_input_cols = [] + + if hasattr(self.estimator, "feature_names_in_"): + features_required_by_estimator = self.estimator.feature_names_in_ + else: + features_required_by_estimator = snowpark_input_cols + + missing_features = [] + features_in_dataset = set(dataset.columns) + + columns_to_select = [] + + for i, f in enumerate(features_required_by_estimator): + if ( + i >= len(input_cols) + or (input_cols[i] != f and snowpark_input_cols[i] != f) + or (input_cols[i] not in features_in_dataset and snowpark_input_cols[i] not in features_in_dataset) + ): + missing_features.append(f) + elif input_cols[i] in features_in_dataset: + columns_to_select.append(input_cols[i]) + elif snowpark_input_cols[i] in features_in_dataset: + columns_to_select.append(snowpark_input_cols[i]) + + if len(missing_features) > 0: + raise exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=ValueError( + "The feature names should match with those that were passed during fit.\n" + f"Features seen during fit call but not present in the input: {missing_features}\n" + f"Features in the input dataframe: {input_cols}\n" + ), + ) + input_df = dataset[columns_to_select] + input_df.columns = features_required_by_estimator + + inference_res = getattr(self.estimator, inference_method)(input_df, *args, **kwargs) + + if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray): + # In case of multioutput estimators, functions such as predict_proba and decision_function return a list of + # ndarrays. We need to concatenate them. + + # First compute output column names + if len(output_cols) == len(inference_res): + actual_output_cols = [] + for idx, np_arr in enumerate(inference_res): + for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]): + actual_output_cols.append(f"{output_cols[idx]}_{i}") + output_cols = actual_output_cols + + # Concatenate np arrays + transformed_numpy_array = np.concatenate(inference_res, axis=1) + elif isinstance(inference_res, tuple) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray): + # In case of kneighbors, functions return a tuple of ndarrays. + transformed_numpy_array = np.stack(inference_res, axis=1) + else: + transformed_numpy_array = inference_res + + if (len(transformed_numpy_array.shape) == 3) and inference_method != "kneighbors": + # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes) + # when voting = "soft" and flatten_transform = False. We can't handle unflattened transforms, + # so we ignore the flatten_transform flag and flatten the results.
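+            # Shape walk-through (illustrative): with soft voting, 3 classifiers and
+            # 2 classes, the array arrives as (3, n_samples, 2); np.hstack below
+            # flattens it to (n_samples, 6).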
+ transformed_numpy_array = np.hstack(transformed_numpy_array) # type: ignore[call-overload] + + if len(transformed_numpy_array.shape) == 1: + transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1)) + + shape = transformed_numpy_array.shape + if shape[1] != len(output_cols): + if len(output_cols) != 1: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=TypeError( + "expected_output_cols must be the same length as the transformed array " "or of length 1" + ), + ) + actual_output_cols = [] + for i in range(shape[1]): + actual_output_cols.append(f"{output_cols[0]}_{i}") + output_cols = actual_output_cols + + if inference_method == "kneighbors": + if len(transformed_numpy_array.shape) == 3: # return_distance=True + shape = transformed_numpy_array.shape + data = [transformed_numpy_array[:, i, :].tolist() for i in range(shape[1])] + kneighbors_df = pd.DataFrame({output_cols[i]: data[i] for i in range(shape[1])}) + else: # return_distance=False + kneighbors_df = pd.DataFrame( + { + output_cols[0]: [ + transformed_numpy_array[i, :].tolist() for i in range(transformed_numpy_array.shape[0]) + ] + } + ) + + if drop_input_cols: + dataset = kneighbors_df + else: + dataset = pd.concat([dataset, kneighbors_df], axis=1) + else: + if drop_input_cols: + dataset = pd.DataFrame(data=transformed_numpy_array, columns=output_cols) + else: + dataset = dataset.copy() + dataset[output_cols] = transformed_numpy_array + return dataset + + def score( + self, + input_cols: List[str], + label_cols: List[str], + sample_weight_col: Optional[str], + *args: Any, + **kwargs: Any, + ) -> float: + """Score the given test dataset. + + Args: + input_cols: List of feature columns for scoring. + label_cols: List of label columns for scoring. + sample_weight_col: A column assigning relative weights to each row for scoring. + args: additional positional arguments. + kwargs: additional keyword args. + + Returns: + An accuracy score for the model on the given test data. + + Raises: + SnowflakeMLException: The estimator's score method accepts neither `X` nor `X_test`.
+ """ + assert hasattr(self.estimator, "score") # make type checker happy + argspec = inspect.getfullargspec(self.estimator.score) + if "X" in argspec.args: + score_args = {"X": self.dataset[input_cols]} + elif "X_test" in argspec.args: + score_args = {"X_test": self.dataset[input_cols]} + else: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ATTRIBUTE, + original_exception=RuntimeError("Neither 'X' or 'X_test' exist in argument"), + ) + + if len(label_cols) > 0: + label_arg_name = "Y" if "Y" in argspec.args else "y" + score_args[label_arg_name] = self.dataset[label_cols].squeeze() + + if sample_weight_col is not None and "sample_weight" in argspec.args: + score_args["sample_weight"] = self.dataset[sample_weight_col].squeeze() + + score = self.estimator.score(**score_args) + assert isinstance(score, float) # make type checker happy + + return score diff --git a/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py b/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py index 6a2d726e..784566bd 100644 --- a/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +++ b/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional +from typing import List, Optional, Tuple import pandas as pd @@ -52,3 +52,31 @@ def train(self) -> object: args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze() return self.estimator.fit(**args) + + def train_fit_predict( + self, + pass_through_columns: List[str], + expected_output_cols_list: List[str], + ) -> Tuple[pd.DataFrame, object]: + """Trains the model using specified features and target columns from the dataset. + This API is different from fit itself because it would also provide the predict + output. + + Args: + pass_through_columns (List[str]): The column names that would + display in the returned dataset. + expected_output_cols_list (List[str]): The output columns + name as a list. Defaults to None. 
+ + Returns: + Tuple[pd.DataFrame, object]: [predicted dataset, estimator] + """ + assert hasattr(self.estimator, "fit_predict") # make type checker happy + args = {"X": self.dataset[self.input_cols]} + result = self.estimator.fit_predict(**args) + result_df = pd.DataFrame(data=result, columns=expected_output_cols_list) + if len(pass_through_columns) > 0: + result_df = pd.concat([self.dataset, result_df], axis=1) + return (result_df, self.estimator) diff --git a/snowflake/ml/modeling/_internal/ml_runtime_implementations/BUILD.bazel b/snowflake/ml/modeling/_internal/ml_runtime_implementations/BUILD.bazel new file mode 100644 index 00000000..907a7c9b --- /dev/null +++ b/snowflake/ml/modeling/_internal/ml_runtime_implementations/BUILD.bazel @@ -0,0 +1,36 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "ml_runtime_trainer", + srcs = ["ml_runtime_trainer.py"], + deps = [ + "//snowflake/ml/modeling/_internal:model_trainer", + ], +) + +py_library( + name = "ml_runtime_handlers", + srcs = ["ml_runtime_handlers.py"], + deps = [ + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/modeling/_internal:transformer_protocols", + ], +) + +py_test( + name = "ml_runtime_handlers_test", + srcs = ["ml_runtime_handlers_test.py"], + deps = [ + ":ml_runtime_handlers", + ], +) + +py_test( + name = "ml_runtime_trainer_test", + srcs = ["ml_runtime_trainer_test.py"], + deps = [ + ":ml_runtime_trainer", + ], +) diff --git a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py new file mode 100644 index 00000000..fb868dab --- /dev/null +++ b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py @@ -0,0 +1,131 @@ +from typing import Any, List, Optional + +from snowflake.snowpark import DataFrame, Session + + +class MLRuntimeTransformHandlers: + def __init__( + self, + dataset: DataFrame, + estimator: object, + class_name: str, + subproject: str, + autogenerated: Optional[bool] = False, + ) -> None: + """ + Args: + dataset: The dataset to run transform functions on. + estimator: The estimator used to run transforms. + class_name: class name to be used in telemetry. + subproject: subproject to be used in telemetry. + autogenerated: Whether the class was autogenerated from a template. + + Raises: + ModuleNotFoundError: The mlruntimes_client module is not available. + """ + try: + from snowflake.ml.runtime import MLRuntimeClient + except ModuleNotFoundError as e: + # This is an internal exception, not a user-facing one. The snowflake.ml.runtime module should + # always be present when this class is instantiated. + raise e + + self.client = MLRuntimeClient() + self.dataset = dataset + self.estimator = estimator + self._class_name = class_name + self._subproject = subproject + self._autogenerated = autogenerated + + def batch_inference( + self, + inference_method: str, + input_cols: List[str], + expected_output_cols: List[str], + pass_through_cols: List[str], + session: Session, + dependencies: List[str], + expected_output_cols_type: Optional[str] = "", + *args: Any, + **kwargs: Any, + ) -> DataFrame: + """Run batch inference on the given dataset. + + Args: + inference_method: the name of the method used by `estimator` to run inference. + input_cols: List of feature columns for inference. + session: An active Snowpark Session.
+ dependencies: List of dependencies for the transformer. + expected_output_cols: column names (in order) of the output dataset. + pass_through_cols: columns in the dataset not used in inference. + expected_output_cols_type: Expected type of the output columns. + args: additional positional arguments. + kwargs: additional keyword args. + + Returns: + A new dataset of the same type as the input dataset. + + Raises: + TypeError: The ML Runtimes client returned a non-DataFrame result. + """ + output_df = self.client.batch_inference( + inference_method=inference_method, + dataset=self.dataset, + estimator=self.estimator, + input_cols=input_cols, + expected_output_cols=expected_output_cols, + pass_through_cols=pass_through_cols, + session=session, + dependencies=dependencies, + expected_output_cols_type=expected_output_cols_type, + *args, + **kwargs, + ) + if not isinstance(output_df, DataFrame): + raise TypeError( + f"The ML Runtimes Client did not return a DataFrame. Returned type: {type(output_df)}" + ) + return output_df + + def score( + self, + input_cols: List[str], + label_cols: List[str], + session: Session, + dependencies: List[str], + score_sproc_imports: List[str], + sample_weight_col: Optional[str] = None, + *args: Any, + **kwargs: Any, + ) -> float: + """Score the given test dataset. + + Args: + session: An active Snowpark Session. + dependencies: score function dependencies. + score_sproc_imports: imports for score stored procedure. + input_cols: List of feature columns for inference. + label_cols: List of label columns for scoring. + sample_weight_col: A column assigning relative weights to each row for scoring. + args: additional positional arguments. + kwargs: additional keyword args. + + + Returns: + An accuracy score for the model on the given test data.
+ + Raises: + TypeError: The ML Runtimes client returned a non-float result. + """ + output_score = self.client.score( + estimator=self.estimator, + dataset=self.dataset, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + ) + if not isinstance(output_score, float): + raise TypeError( + f"The ML Runtimes Client returned a non-float value {output_score} of type {type(output_score)}" + ) + return output_score diff --git a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers_test.py b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers_test.py new file mode 100644 index 00000000..74f1d53a --- /dev/null +++ b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers_test.py @@ -0,0 +1,41 @@ +import sys + +from absl.testing import absltest +from sklearn.linear_model import LinearRegression + +from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_handlers import ( + MLRuntimeTransformHandlers, +) +from snowflake.snowpark import DataFrame + + +class MLRuntimeTransformHandlersTest(absltest.TestCase): + def setUp(self) -> None: + self.dataset = absltest.mock.MagicMock(spec=DataFrame) + self.estimator = absltest.mock.MagicMock(spec=LinearRegression) + + def test_exception_client_package_available(self) -> None: + with absltest.mock.patch.dict(sys.modules, {"snowflake.ml.runtime": absltest.mock.Mock()}): + MLRuntimeTransformHandlers( + dataset=self.dataset, + estimator=self.estimator, + class_name="", + subproject="", + ) + + def test_exception_client_package_unavailable(self) -> None: + + with absltest.mock.patch.dict( + sys.modules, {key: value for key, value in sys.modules.items() if key != "snowflake.ml.runtime"} + ): + with self.assertRaises(ModuleNotFoundError): + MLRuntimeTransformHandlers( + dataset=self.dataset, + estimator=self.estimator, + class_name="", + subproject="", + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py new file mode 100644 index 00000000..52014e81 --- /dev/null +++ b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py @@ -0,0 +1,66 @@ +from typing import List, Optional + +from snowflake.snowpark import DataFrame, Session + + +class MLRuntimeModelTrainer: + """ML model training using the ML Runtimes client.""" + + def __init__( + self, + estimator: object, + dataset: DataFrame, + session: Session, + input_cols: List[str], + label_cols: Optional[List[str]], + sample_weight_col: Optional[str], + autogenerated: bool = False, + subproject: str = "", + ) -> None: + """ + Initializes the MLRuntimeModelTrainer with a model, a Snowpark DataFrame, feature, and label column names. + + Args: + estimator: SKLearn compatible estimator or transformer object. + dataset: The dataset used for training the model. + session: Snowflake session object to be used for training. + input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training. + label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn. + sample_weight_col: The column name representing the weight of training examples. + autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not. + subproject: subproject name to be used in telemetry.
+ + Raises: + ModuleNotFoundError: The mlruntimes_client module is not available. + """ + + try: + from snowflake.ml.runtime import MLRuntimeClient + except ModuleNotFoundError as e: + # This is an internal exception, not a user-facing one. The snowflake.ml.runtime module should + # always be present when this class is instantiated. + raise e + + self.client = MLRuntimeClient() + + self.estimator = estimator + self.dataset = dataset + self.session = session + self.input_cols = input_cols + self.label_cols = label_cols + self.sample_weight_col = sample_weight_col + self._autogenerated = autogenerated + self._subproject = subproject + self._class_name = estimator.__class__.__name__ + + def train(self) -> object: + """ + Trains the model by pushing down the compute into SPCS ML Runtime + """ + return self.client.train( + estimator=self.estimator, + dataset=self.dataset, + input_cols=self.input_cols, + label_cols=self.label_cols, + sample_weight_col=self.sample_weight_col, + ) diff --git a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer_test.py b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer_test.py new file mode 100644 index 00000000..a9cc538b --- /dev/null +++ b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer_test.py @@ -0,0 +1,46 @@ +import sys + +from absl.testing import absltest +from sklearn.linear_model import LinearRegression + +from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_trainer import ( + MLRuntimeModelTrainer, +) +from snowflake.snowpark import DataFrame + + +class MLRuntimeModelTrainerTest(absltest.TestCase): + def setUp(self) -> None: + self.dataset = absltest.mock.MagicMock(spec=DataFrame) + self.dataset._session = absltest.mock.Mock() + self.estimator = absltest.mock.MagicMock(spec=LinearRegression) + + def test_exception_client_package_available(self) -> None: + with absltest.mock.patch.dict(sys.modules, {"snowflake.ml.runtime": absltest.mock.Mock()}): + MLRuntimeModelTrainer( + estimator=self.estimator, + dataset=self.dataset, + session=self.dataset._session, + input_cols=["col_1", "col_2"], + label_cols=["col_1"], + sample_weight_col=None, + ) + + def test_exception_client_package_unavailable(self) -> None: + + with absltest.mock.patch.dict( + sys.modules, {key: value for key, value in sys.modules.items() if key != "snowflake.ml.runtime"} + ): + with self.assertRaises(ModuleNotFoundError): + MLRuntimeModelTrainer( + estimator=self.estimator, + dataset=self.dataset, + session=self.dataset._session, + input_cols=["col_1", "col_2"], + label_cols=["col_1"], + sample_weight_col=None, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/_internal/model_trainer.py b/snowflake/ml/modeling/_internal/model_trainer.py index 0c99a011..d471faab 100644 --- a/snowflake/ml/modeling/_internal/model_trainer.py +++ b/snowflake/ml/modeling/_internal/model_trainer.py @@ -1,4 +1,8 @@ -from typing import Protocol +from typing import List, Protocol, Tuple, Union + +import pandas as pd + +from snowflake.snowpark import DataFrame class ModelTrainer(Protocol): @@ -11,3 +15,10 @@ class ModelTrainer(Protocol): def train(self) -> object: raise NotImplementedError + + def train_fit_predict( + self, + pass_through_columns: List[str], + expected_output_cols_list: List[str], + ) -> Tuple[Union[DataFrame, pd.DataFrame], object]: + raise NotImplementedError diff --git a/snowflake/ml/modeling/_internal/model_trainer_builder.py 
b/snowflake/ml/modeling/_internal/model_trainer_builder.py index 6918ed2f..2be7d455 100644 --- a/snowflake/ml/modeling/_internal/model_trainer_builder.py +++ b/snowflake/ml/modeling/_internal/model_trainer_builder.py @@ -1,9 +1,11 @@ +import os from typing import List, Optional, Union import pandas as pd from sklearn import model_selection from snowflake.ml._internal.exceptions import error_codes, exceptions +from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR from snowflake.ml.modeling._internal.estimator_utils import ( get_module_name, is_single_node, @@ -11,6 +13,9 @@ from snowflake.ml.modeling._internal.local_implementations.pandas_trainer import ( PandasModelTrainer, ) +from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_trainer import ( + MLRuntimeModelTrainer, +) from snowflake.ml.modeling._internal.model_trainer import ModelTrainer from snowflake.ml.modeling._internal.snowpark_implementations.distributed_hpo_trainer import ( DistributedHPOTrainer, @@ -92,7 +97,6 @@ def build( sample_weight_col=sample_weight_col, ) elif isinstance(dataset, DataFrame): - trainer_klass = SnowparkModelTrainer init_args = { "estimator": estimator, "dataset": dataset, @@ -103,6 +107,10 @@ "autogenerated": autogenerated, "subproject": subproject, } + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + return MLRuntimeModelTrainer(**init_args) # type: ignore[arg-type, return-value] + + trainer_klass = SnowparkModelTrainer assert dataset._session is not None # Make MyPy happy if isinstance(estimator, model_selection.GridSearchCV) or isinstance( @@ -124,3 +132,50 @@ def build( f"Unexpected dataset type: {type(dataset)}." "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." ) + + @classmethod + def build_fit_predict( + cls, + estimator: object, + dataset: Union[DataFrame, pd.DataFrame], + input_cols: Optional[List[str]] = None, + autogenerated: bool = False, + subproject: str = "", + ) -> ModelTrainer: + """ + Builder method that creates an appropriate ModelTrainer instance based on the given params. + """ + if input_cols is None: + raise exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=ValueError( + "The input column names (input_cols) are None.\n" + "Please set input_cols when initializing the estimator.\n" + ), + ) + if isinstance(dataset, pd.DataFrame): + return PandasModelTrainer( + estimator=estimator, + dataset=dataset, + input_cols=input_cols, + label_cols=None, + sample_weight_col=None, + ) + elif isinstance(dataset, DataFrame): + trainer_klass = SnowparkModelTrainer + init_args = { + "estimator": estimator, + "dataset": dataset, + "session": dataset._session, + "input_cols": input_cols, + "label_cols": None, + "sample_weight_col": None, + "autogenerated": autogenerated, + "subproject": subproject, + } + return trainer_klass(**init_args) # type: ignore[arg-type] + else: + raise TypeError( + f"Unexpected dataset type: {type(dataset)}. " + "Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
+ ) diff --git a/snowflake/ml/modeling/_internal/model_trainer_builder_test.py b/snowflake/ml/modeling/_internal/model_trainer_builder_test.py index 0d653ec2..3eb9f0c1 100644 --- a/snowflake/ml/modeling/_internal/model_trainer_builder_test.py +++ b/snowflake/ml/modeling/_internal/model_trainer_builder_test.py @@ -1,3 +1,5 @@ +import os +import sys from typing import Any from unittest import mock @@ -8,6 +10,10 @@ from sklearn.model_selection import GridSearchCV from xgboost import XGBRegressor +from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR +from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_trainer import ( + MLRuntimeModelTrainer, +) from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder from snowflake.ml.modeling._internal.snowpark_implementations.distributed_hpo_trainer import ( DistributedHPOTrainer, @@ -28,6 +34,7 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() + os.environ.pop(IN_ML_RUNTIME_ENV_VAR, None) def get_snowpark_dataset(self) -> DataFrame: input_df_pandas = load_iris(as_frame=True).frame @@ -43,6 +50,17 @@ def test_sklearn_model_trainer(self) -> None: self.assertTrue(isinstance(trainer, SnowparkModelTrainer)) + def test_sklearn_model_trainer_in_ml_runtime(self) -> None: + model = LinearRegression() + dataset = self.get_snowpark_dataset() + os.environ[IN_ML_RUNTIME_ENV_VAR] = "True" + + with absltest.mock.patch.dict(sys.modules, {**sys.modules, **{"snowflake.ml.runtime": absltest.mock.Mock()}}): + trainer = ModelTrainerBuilder.build(estimator=model, dataset=dataset, input_cols=[]) + del os.environ[IN_ML_RUNTIME_ENV_VAR] + + self.assertTrue(isinstance(trainer, MLRuntimeModelTrainer)) + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_distributed_hpo_trainer(self, mock_is_single_node: Any) -> None: mock_is_single_node.return_value = False diff --git a/snowflake/ml/modeling/_internal/model_transformer_builder.py b/snowflake/ml/modeling/_internal/model_transformer_builder.py new file mode 100644 index 00000000..43d0560c --- /dev/null +++ b/snowflake/ml/modeling/_internal/model_transformer_builder.py @@ -0,0 +1,85 @@ +import os +from typing import Optional, Union + +import pandas as pd + +from snowflake import snowpark +from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR +from snowflake.ml.modeling._internal.local_implementations.pandas_handlers import ( + PandasTransformHandlers, +) +from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_handlers import ( + MLRuntimeTransformHandlers, +) +from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import ( + SnowparkTransformHandlers, +) +from snowflake.ml.modeling._internal.transformer_protocols import ModelTransformHandlers + + +class ModelTransformerBuilder: + """ + A builder class to create instances of model transformers for different usage configurations. + + This class provides methods to build model transformers tailored to specific machine learning + models and post-training configurations such as the dataset's location. It abstracts the creation process, + allowing the user to obtain a configured model transformer for a particular model architecture or configuration.
+ """ + + @classmethod + def build( + cls, + dataset: Union[snowpark.DataFrame, pd.DataFrame], + estimator: object, + class_name: str, + subproject: str, + autogenerated: Optional[bool] = False, + ) -> ModelTransformHandlers: + """ + Builder method that creates an appropriate ModelTrainer instance based on the given params. + These params are the specific parameters required to determine where we execute transforms + (currently remote and local) + + Args: + dataset: The dataset on which transforms will be executed. + estimator: The estimator object used to execute transformations. Must support inference and scoring. + class_name: class name to be used in telemetry. + subproject: subproject to be used in telemetry. + autogenerated: Whether the class was autogenerated from a template. + + Returns: + A ModelTransformHandlers based on function inputs + + Raises: + TypeError: Dataset is not one of the currently supported types(pd.DataFrame, snowpark.DataFrame) + """ + if isinstance(dataset, pd.DataFrame): + return PandasTransformHandlers( + dataset=dataset, + estimator=estimator, + class_name=class_name, + subproject=subproject, + autogenerated=autogenerated, + ) + + elif isinstance(dataset, snowpark.DataFrame): + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + return MLRuntimeTransformHandlers( + dataset=dataset, + estimator=estimator, + class_name=class_name, + subproject=subproject, + autogenerated=autogenerated, + ) + return SnowparkTransformHandlers( + dataset=dataset, + estimator=estimator, + class_name=class_name, + subproject=subproject, + autogenerated=autogenerated, + ) + else: + raise TypeError( + f"Unexpected dataset type: {type(dataset)}." + "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." + ) diff --git a/snowflake/ml/modeling/_internal/model_transformer_builder_test.py b/snowflake/ml/modeling/_internal/model_transformer_builder_test.py new file mode 100644 index 00000000..e74efb18 --- /dev/null +++ b/snowflake/ml/modeling/_internal/model_transformer_builder_test.py @@ -0,0 +1,73 @@ +import os +import sys + +import inflection +import pytest +from absl.testing import absltest +from sklearn.datasets import load_iris + +from snowflake import snowpark +from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR +from snowflake.ml.modeling._internal.local_implementations.pandas_handlers import ( + PandasTransformHandlers, +) +from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_handlers import ( + MLRuntimeTransformHandlers, +) +from snowflake.ml.modeling._internal.model_transformer_builder import ( + ModelTransformerBuilder, +) +from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import ( + SnowparkTransformHandlers, +) +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions + + +class ModelTransformBuilderTest(absltest.TestCase): + def setUp(self) -> None: + self._session = snowpark.Session.builder.configs(SnowflakeLoginOptions()).create() + self._pandas_dataset = load_iris(as_frame=True).frame + self._snowpark_dataset = self._get_snowpark_dataset() + + def tearDown(self) -> None: + self._session.close() + os.environ.pop(IN_ML_RUNTIME_ENV_VAR, None) + + def _get_snowpark_dataset(self) -> snowpark.DataFrame: + input_df_pandas = load_iris(as_frame=True).frame + input_df_pandas.columns = [inflection.parameterize(c, "_").upper() for c in input_df_pandas.columns] + input_df_pandas["INDEX"] = input_df_pandas.reset_index().index + input_df: snowpark.DataFrame = 
self._session.create_dataframe(input_df_pandas) + return input_df + + def test_builder_with_pd_dataset(self) -> None: + transformer_handler = ModelTransformerBuilder.build( + class_name="class_name", subproject="sub_project", dataset=self._pandas_dataset, estimator=None + ) + assert isinstance(transformer_handler, PandasTransformHandlers) + + def test_builder_with_snowpark_dataset(self) -> None: + transformer_handler = ModelTransformerBuilder.build( + class_name="class_name", subproject="sub_project", dataset=self._snowpark_dataset, estimator=None + ) + assert isinstance(transformer_handler, SnowparkTransformHandlers) + + def test_builder_with_snowpark_dataset_in_ml_runtime(self) -> None: + os.environ[IN_ML_RUNTIME_ENV_VAR] = "True" + with absltest.mock.patch.dict(sys.modules, {**sys.modules, **{"snowflake.ml.runtime": absltest.mock.Mock()}}): + transformer_handler = ModelTransformerBuilder.build( + class_name="class_name", subproject="sub_project", dataset=self._snowpark_dataset, estimator=None + ) + assert isinstance(transformer_handler, MLRuntimeTransformHandlers) + del os.environ[IN_ML_RUNTIME_ENV_VAR] + + def test_builder_with_invalid_dataset(self) -> None: + dataset_json = self._pandas_dataset.to_json() + with pytest.raises(TypeError): + ModelTransformerBuilder.build( + class_name="class_name", subproject="sub_project", dataset=dataset_json, estimator=None + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py index 477f596a..73848c29 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py @@ -9,8 +9,6 @@ import pandas as pd from snowflake.ml._internal import telemetry -from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV -from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator from snowflake.ml._internal.utils.temp_file_utils import ( @@ -41,26 +39,60 @@ def _get_rand_id() -> str: return str(uuid4()).replace("-", "_").upper() -class SnowparkHandlers: - def __init__(self, class_name: str, subproject: str, autogenerated: Optional[bool] = False) -> None: +class SnowparkTransformHandlers: + def __init__( + self, + dataset: DataFrame, + estimator: object, + class_name: str, + subproject: str, + autogenerated: Optional[bool] = False, + ) -> None: + """ + Args: + dataset: The dataset to run transform functions on. + estimator: The estimator used to run transforms. + class_name: class name to be used in telemetry. + subproject: subproject to be used in telemetry. + autogenerated: Whether the class was autogenerated from a template. 
+ """ + self.dataset = dataset + self.estimator = estimator self._class_name = class_name self._subproject = subproject self._autogenerated = autogenerated def batch_inference( self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], inference_method: str, input_cols: List[str], - pass_through_columns: List[str], - expected_output_cols_list: List[str], - expected_output_cols_type: str = "", + expected_output_cols: List[str], + pass_through_cols: List[str], + session: Session, + dependencies: List[str], + expected_output_cols_type: Optional[str] = "", *args: Any, **kwargs: Any, ) -> DataFrame: + """Run batch inference on the given dataset. + + Args: + session: An active Snowpark Session. + dependencies: List of dependencies for the transformer. + inference_method: the name of the method used by `estimator` to run inference. + input_cols: List of feature columns for inference. + pass_through_cols: columns in the dataset not used in inference. + expected_output_cols: column names (in order) of the output dataset. + expected_output_cols_type: Expected type of the output columns. + args: additional positional arguments. + kwargs: additional keyword args. + + Returns: + A new dataset of the same type as the input dataset. + """ + + dataset = self.dataset + estimator = self.estimator # Register vectorized UDF for batch inference batch_inference_udf_name = random_name_for_temp_object(TempObjectType.FUNCTION) snowpark_cols = dataset.select(input_cols).columns @@ -135,25 +167,22 @@ def vec_batch_infer(ds: PandasSeries[dict]) -> PandasSeries[dict]: # type: igno transformed_numpy_array = np.hstack(transformed_numpy_array) # type: ignore[call-overload] if len(transformed_numpy_array.shape) > 1: - if transformed_numpy_array.shape[1] != len(expected_output_cols_list): + if transformed_numpy_array.shape[1] != len(expected_output_cols): # HeterogeneousEnsemble's transform method produce results with variying shapes # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes). # It is hard to predict the response shape without using fragile introspection logic. # So, to avoid that we are packing the results into a dataframe of shape (n_samples, 1) with # each element being a list. 
- if len(expected_output_cols_list) != 1: + if len(expected_output_cols) != 1: raise TypeError( - "expected_output_cols_list must be same length as transformed array or " - "should be of length 1" + "expected_output_cols must be same length as transformed array or " "should be of length 1" ) series = pd.Series(transformed_numpy_array.tolist()) - transformed_pandas_df = pd.DataFrame(series, columns=expected_output_cols_list) + transformed_pandas_df = pd.DataFrame(series, columns=expected_output_cols) else: - transformed_pandas_df = pd.DataFrame( - transformed_numpy_array.tolist(), columns=expected_output_cols_list - ) + transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=expected_output_cols) else: - transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=expected_output_cols_list) + transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=expected_output_cols) return transformed_pandas_df.to_dict("records") # type: ignore[no-any-return] @@ -162,8 +191,8 @@ def vec_batch_infer(ds: PandasSeries[dict]) -> PandasSeries[dict]: # type: igno # Run Transform query_from_df = str(dataset.queries["queries"][0]) - outer_select_list = pass_through_columns[:] - inner_select_list = pass_through_columns[:] + outer_select_list = pass_through_cols[:] + inner_select_list = pass_through_cols[:] outer_select_list.extend( [ @@ -172,7 +201,7 @@ def vec_batch_infer(ds: PandasSeries[dict]) -> PandasSeries[dict]: # type: igno column_name=identifier.get_inferred_name(c), udf_datatype=(f"::{expected_output_cols_type}" if expected_output_cols_type else ""), ) - for c in expected_output_cols_list + for c in expected_output_cols ] ) @@ -202,61 +231,36 @@ def vec_batch_infer(ds: PandasSeries[dict]) -> PandasSeries[dict]: # type: igno return session.sql(sql) - def score_pandas( + def score( self, - dataset: pd.DataFrame, - estimator: object, input_cols: List[str], label_cols: List[str], - sample_weight_col: Optional[str], - ) -> float: - assert hasattr(estimator, "score") # make type checker happy - argspec = inspect.getfullargspec(estimator.score) - if "X" in argspec.args: - args = {"X": dataset[input_cols]} - elif "X_test" in argspec.args: - args = {"X_test": dataset[input_cols]} - else: - raise exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ATTRIBUTE, - original_exception=RuntimeError("Neither 'X' or 'X_test' exist in argument"), - ) - - if len(label_cols) > 0: - label_arg_name = "Y" if "Y" in argspec.args else "y" - args[label_arg_name] = dataset[label_cols].squeeze() - - if sample_weight_col is not None and "sample_weight" in argspec.args: - args["sample_weight"] = dataset[sample_weight_col].squeeze() - - score = estimator.score(**args) - assert isinstance(score, float) # make type checker happy - - return score - - def score_snowpark( - self, - dataset: DataFrame, session: Session, - estimator: object, dependencies: List[str], score_sproc_imports: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], + sample_weight_col: Optional[str] = None, + *args: Any, + **kwargs: Any, ) -> float: + """Score the given test dataset. + + Args: + session: An active Snowpark Session. + dependencies: score function dependencies. + score_sproc_imports: imports for score stored procedure. + input_cols: List of feature columns for inference. + label_cols: List of label columns for scoring. + sample_weight_col: A column assigning relative weights to each row for scoring. + args: additional positional arguments. 
+ kwargs: additional keyword args. + + Returns: + An accuracy score for the model on the given test data. + """ + + dataset = self.dataset + estimator = self.estimator dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(dataset) - if SNOWML_SPROC_ENV in os.environ: - statement_params = telemetry.get_function_usage_statement_params( - project=_PROJECT, - subproject=self._subproject, - function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), - api_calls=[Session.call], - custom_tags=dict([("autogen", True)]) if self._autogenerated else None, - ) - pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params) - pd_df.columns = dataset.columns - return self.score_pandas(pd_df, estimator, input_cols, label_cols, sample_weight_col) # Extract queries that generated the dataframe. We will need to pass it to score procedure. queries = dataset.queries["queries"] diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py index 315eded1..b9954d19 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py @@ -2,9 +2,10 @@ import inspect import os import posixpath -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import cloudpickle as cp +import pandas as pd from snowflake.ml._internal import telemetry from snowflake.ml._internal.exceptions import ( @@ -138,7 +139,7 @@ def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]: def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object: """ - Downloads the serialized model from a stage location and unpickels it. + Downloads the serialized model from a stage location and unpickles it. Args: dir_path: Stage directory path where results are stored. @@ -275,6 +276,128 @@ def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProc return fit_wrapper_sproc + def _build_fit_predict_wrapper_sproc( + self, + model_spec: ModelSpecifications, + ) -> Callable[[Session, List[str], str, str, List[str], Dict[str, str], List[str], List[str], str], str]: + """ + Constructs and returns a python stored procedure function to be used for training model. + + Args: + model_spec: ModelSpecifications object that contains model specific information + like required imports, package dependencies, etc. + + Returns: + A callable that can be registered as a stored procedure. + """ + imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml + + def fit_wrapper_function( + session: Session, + sql_queries: List[str], + stage_transform_file_name: str, + stage_result_file_name: str, + input_cols: List[str], + statement_params: Dict[str, str], + pass_through_columns: List[str], + expected_output_cols_list: List[str], + fit_predict_result_name: str, + ) -> str: + import os + + import cloudpickle as cp + import pandas as pd + + for import_name in imports: + importlib.import_module(import_name) + + # Execute snowpark queries and obtain the results as pandas dataframe + # NB: this implies that the result data must fit into memory. 
+ for query in sql_queries[:-1]: + _ = session.sql(query).collect(statement_params=statement_params) + sp_df = session.sql(sql_queries[-1]) + df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params) + df.columns = sp_df.columns + + local_transform_file_name = get_temp_file_path() + + session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params) + + local_transform_file_path = os.path.join( + local_transform_file_name, os.listdir(local_transform_file_name)[0] + ) + with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: + estimator = cp.load(local_transform_file_obj) + + fit_predict_result = estimator.fit_predict(df[input_cols]) + + local_result_file_name = get_temp_file_path() + + with open(local_result_file_name, mode="w+b") as local_result_file_obj: + cp.dump(estimator, local_result_file_obj) + + session.file.put( + local_result_file_name, + stage_result_file_name, + auto_compress=False, + overwrite=True, + statement_params=statement_params, + ) + + # store the predict output + if len(pass_through_columns) != 0: + df = df.copy() + fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list) + fit_predict_result_pd = pd.concat([df, fit_predict_result_pd], axis=1) + else: + fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list) + + # write into a temp table in sproc and load the table from outside + session.write_pandas( + fit_predict_result_pd, fit_predict_result_name, auto_create_table=True, table_type="temp" + ) + + # Note: you can add something like + "|" + str(df) to the return string + # to pass debug information to the caller. + return str(os.path.basename(local_result_file_name)) + + return fit_wrapper_function + + def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure: + # If the sproc already exists, don't register. + if not hasattr(self.session, "_FIT_PRE_WRAPPER_SPROCS"): + self.session._FIT_PRE_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] + + model_spec = ModelSpecificationsBuilder.build(model=self.estimator) + fit_predict_sproc_key = model_spec.__class__.__name__ + if fit_predict_sproc_key in self.session._FIT_PRE_WRAPPER_SPROCS: # type: ignore[attr-defined] + fit_sproc: StoredProcedure = self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined] + fit_predict_sproc_key + ] + return fit_sproc + + fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) + + relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=model_spec.pkgDependencies, session=self.session + ) + + fit_predict_wrapper_sproc = self.session.sproc.register( + func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec), + is_permanent=False, + name=fit_predict_sproc_name, + packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type] + replace=True, + session=self.session, + statement_params=statement_params, + ) + + self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined] + fit_predict_sproc_key + ] = fit_predict_wrapper_sproc + + return fit_predict_wrapper_sproc + def train(self) -> object: """ Trains the model by pushing down the compute into Snowflake using stored procedures. 
@@ -337,3 +460,64 @@ def train(self) -> object: file_name=sproc_export_file_name, statement_params=statement_params, ) + + def train_fit_predict( + self, + pass_through_columns: List[str], + expected_output_cols_list: List[str], + ) -> Tuple[Union[DataFrame, pd.DataFrame], object]: + """Trains the model by pushing down the compute into Snowflake using stored procedures. + This API is different from fit itself because it would also provide the predict + output. + + Args: + pass_through_columns (List[str]): The column names that would + display in the returned dataset. + expected_output_cols_list (List[str]): The output columns + name as a list. Defaults to None. + + Returns: + Tuple[Union[DataFrame, pd.DataFrame], object]: [predicted dataset, estimator] + """ + dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset) + + # Extract query that generated the dataframe. We will need to pass it to the fit procedure. + queries = dataset.queries["queries"] + + transform_stage_name = self._create_temp_stage() + (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage( + stage_name=transform_stage_name + ) + + # Call fit sproc + statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=self._subproject, + function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), + api_calls=[Session.call], + custom_tags=dict([("autogen", True)]) if self._autogenerated else None, + ) + + fit_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params) + fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE) + + sproc_export_file_name: str = fit_wrapper_sproc( + self.session, + queries, + stage_transform_file_name, + stage_result_file_name, + self.input_cols, + statement_params, + pass_through_columns, + expected_output_cols_list, + fit_predict_result_name, + ) + + output_result_sp = self.session.table(fit_predict_result_name) + fitted_estimator = self._fetch_model_from_stage( + dir_path=stage_result_file_name, + file_name=sproc_export_file_name, + statement_params=statement_params, + ) + + return output_result_sp, fitted_estimator diff --git a/snowflake/ml/modeling/_internal/transformer_protocols.py b/snowflake/ml/modeling/_internal/transformer_protocols.py new file mode 100644 index 00000000..cf293b79 --- /dev/null +++ b/snowflake/ml/modeling/_internal/transformer_protocols.py @@ -0,0 +1,191 @@ +from typing import Any, List, Optional, Protocol, TypedDict, Union + +import pandas as pd + +from snowflake import snowpark + + +class LocalModelTransformHandlers(Protocol): + """A protocol defining the behavior of a local execution model transformer.""" + + def __init__( + self, + dataset: pd.DataFrame, + estimator: object, + class_name: str, + subproject: str, + autogenerated: Optional[bool] = False, + ) -> None: + """ + Args: + dataset: The dataset to run transform functions on. + estimator: The estimator used to run transforms. + class_name: class name to be used in telemetry. + subproject: subproject to be used in telemetry. + autogenerated: Whether the class was autogenerated from a template. + """ + ... + + def batch_inference( + self, + inference_method: str, + input_cols: List[str], + expected_output_cols: List[str], + snowpark_input_cols: Optional[List[str]], + drop_input_cols: Optional[bool] = False, + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + """Run batch inference on the given dataset. 
+ + Args: + inference_method: the name of the method used by `estimator` to run inference. + input_cols: column names of the input dataset. + expected_output_cols: column names (in order) of the output dataset. + snowpark_input_cols: list of input columns used if estimator is fit in snowpark. + drop_input_cols: If set True, the response will not contain input columns. + args: additional positional arguments. + kwargs: additional keyword args. + + Returns: + A new dataset of the same type as the input dataset. + + # noqa: DAR202 + (function in protocol definition does not actually return a value) + """ + ... + + def score( + self, + input_cols: List[str], + label_cols: List[str], + sample_weight_col: Optional[str], + *args: Any, + **kwargs: Any, + ) -> float: + """Score the given test dataset. + + Args: + input_cols: List of feature columns for scoring. + label_cols: List of label columns for scoring. + sample_weight_col: A column assigning relative weights to each row for scoring. + args: additional positional arguments. + kwargs: additional keyword args. + + Returns: + An accuracy score for the model on the given test data. + + # noqa: DAR202 + (function in protocol definition does not actually return a value) + """ + ... + + +class RemoteModelTransformHandlers(Protocol): + """A protocol defining behavior of a local execution model transformer.""" + + def __init__( + self, + dataset: snowpark.DataFrame, + estimator: object, + class_name: str, + subproject: str, + autogenerated: Optional[bool] = False, + ) -> None: + """ + Args: + dataset: The dataset to run transform functions on. + estimator: The estimator used to run transforms. + class_name: class name to be used in telemetry. + subproject: subproject to be used in telemetry. + autogenerated: Whether the class was autogenerated from a template. + """ + ... + + def batch_inference( + self, + inference_method: str, + input_cols: List[str], + expected_output_cols: List[str], + pass_through_cols: List[str], + session: snowpark.Session, + dependencies: List[str], + expected_output_cols_type: Optional[str] = "", + *args: Any, + **kwargs: Any, + ) -> snowpark.DataFrame: + """Run batch inference on the given dataset. + + Args: + session: An active Snowpark Session. + dependencies: List of dependencies for the transformer. + inference_method: the name of the method used by `estimator` to run inference. + input_cols: List of feature columns for inference. + pass_through_cols: columns in the dataset not used in inference. + expected_output_cols: column names (in order) of the output dataset. + expected_output_cols_type: Expected type of the output columns. + args: additional positional arguments. + kwargs: additional keyword args. + + Returns: + A new dataset of the same type as the input dataset. + + # noqa: DAR202 + (function in protocol definition does not actually return a value) + """ + ... + + def score( + self, + input_cols: List[str], + label_cols: List[str], + session: snowpark.Session, + dependencies: List[str], + score_sproc_imports: List[str], + sample_weight_col: Optional[str] = None, + *args: Any, + **kwargs: Any, + ) -> float: + """Score the given test dataset. + + Args: + session: An active Snowpark Session. + dependencies: score function dependencies. + score_sproc_imports: imports for score stored procedure. + input_cols: List of feature columns for inference. + label_cols: List of label columns for scoring. + sample_weight_col: A column assigning relative weights to each row for scoring. + args: additional positional arguments. 
+ kwargs: additional keyword args. + + Returns: + An accuracy score for the model on the given test data. + + # noqa: DAR202 + (function in protocol definition does not actually return a value) + """ + ... + + +ModelTransformHandlers = Union[LocalModelTransformHandlers, RemoteModelTransformHandlers] + + +class BatchInferenceKwargsTypedDict(TypedDict, total=False): + """A typed dict specifying all possible optional keyword args accepted by batch_inference() methods.""" + + snowpark_input_cols: Optional[List[str]] + drop_input_cols: Optional[bool] + pass_through_cols: List[str] + session: snowpark.Session + dependencies: List[str] + expected_output_cols_type: str + n_neighbors: Optional[int] + return_distance: bool + + +class ScoreKwargsTypedDict(TypedDict, total=False): + """A typed dict specifying all possible optional keyword args accepted by score() methods.""" + + session: snowpark.Session + dependencies: List[str] + score_sproc_imports: List[str] diff --git a/snowflake/ml/modeling/framework/BUILD.bazel b/snowflake/ml/modeling/framework/BUILD.bazel index 272c0531..1d0b254a 100644 --- a/snowflake/ml/modeling/framework/BUILD.bazel +++ b/snowflake/ml/modeling/framework/BUILD.bazel @@ -16,7 +16,8 @@ py_library( "//snowflake/ml/_internal/exceptions:modeling_error_messages", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:parallelize", - "//snowflake/ml/modeling/_internal:estimator_protocols", + "//snowflake/ml/modeling/_internal:transformer_protocols", + "//snowflake/ml/modeling/_internal/local_implementations:pandas_handlers", "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_handlers", ], ) diff --git a/snowflake/ml/modeling/model_selection/BUILD.bazel b/snowflake/ml/modeling/model_selection/BUILD.bazel index 7d937456..82e857a5 100644 --- a/snowflake/ml/modeling/model_selection/BUILD.bazel +++ b/snowflake/ml/modeling/model_selection/BUILD.bazel @@ -29,6 +29,7 @@ py_library( "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/modeling/_internal:model_trainer_builder", + "//snowflake/ml/modeling/_internal:transformer_protocols", "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_handlers", ], ) @@ -41,6 +42,7 @@ py_library( "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/modeling/_internal:model_trainer_builder", + "//snowflake/ml/modeling/_internal:transformer_protocols", "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_handlers", ], ) diff --git a/snowflake/ml/modeling/model_selection/grid_search_cv.py b/snowflake/ml/modeling/model_selection/grid_search_cv.py index 93442a97..8adad90e 100644 --- a/snowflake/ml/modeling/model_selection/grid_search_cv.py +++ b/snowflake/ml/modeling/model_selection/grid_search_cv.py @@ -21,7 +21,6 @@ ModelSignature, _infer_signature, ) -from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers from snowflake.ml.modeling._internal.estimator_utils import ( gather_dependencies, original_estimator_has_callable, @@ -29,11 +28,15 @@ validate_sklearn_args, ) from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder -from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import ( - SnowparkHandlers as HandlersImpl, +from snowflake.ml.modeling._internal.model_transformer_builder import ( + ModelTransformerBuilder, +) +from snowflake.ml.modeling._internal.transformer_protocols import ( + BatchInferenceKwargsTypedDict, + 
ScoreKwargsTypedDict, ) from snowflake.ml.modeling.framework.base import BaseTransformer -from snowflake.snowpark import DataFrame +from snowflake.snowpark import DataFrame, Session from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type _PROJECT = "ModelDevelopment" @@ -43,6 +46,8 @@ _SUBPROJECT = "ModelSelection" DEFAULT_UDTF_NJOBS = 3 +DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame] + class GridSearchCV(BaseTransformer): r"""Exhaustive search over specified parameter values for an estimator @@ -265,10 +270,11 @@ def __init__( # type: ignore[no-untyped-def] self.set_drop_input_cols(drop_input_cols) self.set_sample_weight_col(sample_weight_col) self.set_passthrough_cols(passthrough_cols) - self._handlers: TransformerHandlers = HandlersImpl( - class_name=self.__class__.__name__, - subproject=_SUBPROJECT, - ) + self._autogenerated = False + self._snowpark_cols = self.input_cols + self._autogenerated = False + self._class_name = GridSearchCV.__class__.__name__ + self._subproject = _SUBPROJECT def _get_active_columns(self) -> List[str]: """ "Get the list of columns that are relevant to the transformer.""" @@ -332,14 +338,8 @@ def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]: else: return list(set(dataset.columns) - set(self.output_cols)) - def _batch_inference( - self, - dataset: DataFrame, - inference_method: str, - expected_output_cols_list: List[str], - expected_output_cols_type: str = "", - ) -> DataFrame: - """Util method to create UDF and run batch inference.""" + def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> None: + """Util method to run validate that batch inference can be run on a snowpark dataframe.""" if not self._is_fitted: raise exceptions.SnowflakeMLException( error_code=error_codes.METHOD_NOT_ALLOWED, @@ -359,120 +359,6 @@ def _batch_inference( pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT ) - return self._handlers.batch_inference( - dataset, - session, - self._sklearn_object, - self._get_dependencies(), - inference_method, - self.input_cols, - self._get_pass_through_columns(dataset), - expected_output_cols_list, - expected_output_cols_type, - ) - - def _sklearn_inference( - self, dataset: pd.DataFrame, inference_method: str, expected_output_cols_list: List[str] - ) -> pd.DataFrame: - output_cols = expected_output_cols_list.copy() - - # Model expects exact same columns names in the input df for predict call. 
- # Given the scenario that user use snowpark DataFrame in fit call, but pandas DataFrame in predict call - # input cols need to match unquoted / quoted - input_cols = self.input_cols - unquoted_input_cols = identifier.get_unescaped_names(self.input_cols) - quoted_input_cols = identifier.get_inferred_names(unquoted_input_cols) - - estimator = self._sklearn_object - - assert estimator is not None - features_required_by_estimator = ( - estimator.feature_names_in_ if hasattr(estimator, "feature_names_in_") else unquoted_input_cols - ) - missing_features = [] - features_in_dataset = set(dataset.columns) - columns_to_select = [] - for i, f in enumerate(features_required_by_estimator): - if ( - i >= len(input_cols) - or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f) - or ( - input_cols[i] not in features_in_dataset - and unquoted_input_cols[i] not in features_in_dataset - and quoted_input_cols[i] not in features_in_dataset - ) - ): - missing_features.append(f) - elif input_cols[i] in features_in_dataset: - columns_to_select.append(input_cols[i]) - elif unquoted_input_cols[i] in features_in_dataset: - columns_to_select.append(unquoted_input_cols[i]) - else: - columns_to_select.append(quoted_input_cols[i]) - - if len(missing_features) > 0: - raise exceptions.SnowflakeMLException( - error_code=error_codes.NOT_FOUND, - original_exception=ValueError( - "The feature names should match with those that were passed during fit.\n" - f"Features seen during fit call but not present in the input: {missing_features}\n" - f"Features in the input dataframe : {input_cols}\n" - ), - ) - input_df = dataset[columns_to_select] - input_df.columns = features_required_by_estimator - - transformed_numpy_array = getattr(estimator, inference_method)(input_df) - - if ( - isinstance(transformed_numpy_array, list) - and len(transformed_numpy_array) > 0 - and isinstance(transformed_numpy_array[0], np.ndarray) - ): - # In case of multioutput estimators, predict_proba(), decision_function(), etc., functions return - # a list of ndarrays. We need to concatenate them. - - # First compute output column names - if len(output_cols) == len(transformed_numpy_array): - actual_output_cols = [] - for idx, np_arr in enumerate(transformed_numpy_array): - for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]): - actual_output_cols.append(f"{output_cols[idx]}_{i}") - output_cols = actual_output_cols - - # Concatenate np arrays - transformed_numpy_array = np.concatenate(transformed_numpy_array, axis=1) - - if len(transformed_numpy_array.shape) == 3: - # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes) - # when voting = "soft" and flatten_transform = False. We can't handle unflatten transforms, - # so we ignore flatten_transform flag and flatten the results. 
- transformed_numpy_array = np.hstack(transformed_numpy_array) - - if len(transformed_numpy_array.shape) == 1: - transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1)) - - shape = transformed_numpy_array.shape - if shape[1] != len(output_cols): - if len(output_cols) != 1: - raise exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=TypeError( - "expected_output_cols_list must be same length as transformed array or " "should be of length 1" - ), - ) - actual_output_cols = [] - for i in range(shape[1]): - actual_output_cols.append(f"{output_cols[0]}_{i}") - output_cols = actual_output_cols - - if self._drop_input_cols: - dataset = pd.DataFrame(data=transformed_numpy_array, columns=output_cols) - else: - dataset = dataset.copy() - dataset[output_cols] = transformed_numpy_array - return dataset - @available_if(original_estimator_has_callable("predict")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( project=_PROJECT, @@ -492,6 +378,12 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p Transformed dataset. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. + transform_kwargs: BatchInferenceKwargsTypedDict = dict() + inference_method = "predict" + if isinstance(dataset, DataFrame): expected_type_inferred = "" # infer the datatype from label columns @@ -499,19 +391,39 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p expected_type_inferred = convert_sp_to_sf_type( self.model_signatures["predict"].outputs[0].as_snowpark_type() ) - - output_df = self._batch_inference( + self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="predict", - expected_output_cols_list=self.output_cols, + inference_method=inference_method, + ) + + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type=expected_type_inferred, ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict", - expected_output_cols_list=self.output_cols, - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self.output_cols, + **transform_kwargs, + ) return output_df @@ -533,19 +445,41 @@ def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, Transformed dataset. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + inference_method = "transform" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="transform", - expected_output_cols_list=self.output_cols, + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="transform", - expected_output_cols_list=self.output_cols, - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self.output_cols, + **transform_kwargs, + ) return output_df def _get_output_column_names(self, output_cols_prefix: str) -> List[str]: @@ -599,20 +533,42 @@ def predict_proba( Output dataset with probability of the sample for each class in the model. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. + transform_kwargs: BatchInferenceKwargsTypedDict = dict() + + inference_method = "predict_proba" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="predict_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("predict_log_proba")) # type: ignore[misc] @@ -637,20 +593,42 @@ def predict_log_proba( Output dataset with log probability of the sample for each class in the model. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. 
These kwargs + # are specific to the type of dataset used. + transform_kwargs: BatchInferenceKwargsTypedDict = dict() + + inference_method = "predict_log_proba" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="predict_log_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict_log_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("decision_function")) # type: ignore[misc] @@ -674,21 +652,43 @@ def decision_function( Returns: Output dataset with results of the decision function for the samples in input dataset. """ + super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + inference_method = "decision_function" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="decision_function", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="decision_function", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("score_samples")) # type: ignore[misc] @@ -714,20 +714,42 @@ def score_samples( Output dataset with results of the decision function for the samples in input dataset. """ super()._check_dataset_type(dataset) + + inference_method = "score_samples" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="score_samples", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="score_samples", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("score")) # type: ignore[misc] @@ -742,37 +764,44 @@ def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: Returns: Score. """ + self._infer_input_output_cols(dataset) super()._check_dataset_type(dataset) - if isinstance(dataset, pd.DataFrame): - output_score = self._handlers.score_pandas( - dataset, self._sklearn_object, self.input_cols, self.label_cols, self.sample_weight_col + + # This dictionary contains optional kwargs for scoring. These kwargs + # are specific to the type of dataset used. + transform_kwargs: ScoreKwargsTypedDict = dict() + + if isinstance(dataset, DataFrame): + selected_cols = self._get_active_columns() + if len(selected_cols) > 0: + dataset = dataset.select(selected_cols) + assert isinstance(dataset._session, Session) # keep mypy happy + transform_kwargs = dict( + session=dataset._session, + dependencies=["snowflake-snowpark-python"] + self._get_dependencies(), + score_sproc_imports=["sklearn"], ) - elif isinstance(dataset, DataFrame): - output_score = self._score_snowpark(dataset) - return output_score + elif isinstance(dataset, pd.DataFrame): + # pandas_handler.score() does not require any extra kwargs. 
+ transform_kwargs = dict() - def _score_snowpark(self, dataset: DataFrame) -> float: - # Specify input columns so column pruning will be enforced - selected_cols = self._get_active_columns() - if len(selected_cols) > 0: - dataset = dataset.select(selected_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) - session = dataset._session - assert session is not None # keep mypy happy - - score = self._handlers.score_snowpark( - dataset, - session, - self._sklearn_object, - ["snowflake-snowpark-python"] + self._get_dependencies(), - ["sklearn"], - identifier.get_unescaped_names(self.input_cols), - identifier.get_unescaped_names(self.label_cols), - identifier.get_unescaped_names(self.sample_weight_col), + output_score = transform_handlers.score( + input_cols=identifier.get_unescaped_names(self.input_cols), + label_cols=identifier.get_unescaped_names(self.label_cols), + sample_weight_col=identifier.get_unescaped_names(self.sample_weight_col), + **transform_kwargs, ) - return score + return output_score def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None: self._model_signature_dict = dict() diff --git a/snowflake/ml/modeling/model_selection/randomized_search_cv.py b/snowflake/ml/modeling/model_selection/randomized_search_cv.py index 36a50637..e4861826 100644 --- a/snowflake/ml/modeling/model_selection/randomized_search_cv.py +++ b/snowflake/ml/modeling/model_selection/randomized_search_cv.py @@ -18,7 +18,6 @@ ModelSignature, _infer_signature, ) -from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers from snowflake.ml.modeling._internal.estimator_utils import ( gather_dependencies, original_estimator_has_callable, @@ -26,11 +25,15 @@ validate_sklearn_args, ) from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder -from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import ( - SnowparkHandlers as HandlersImpl, +from snowflake.ml.modeling._internal.model_transformer_builder import ( + ModelTransformerBuilder, +) +from snowflake.ml.modeling._internal.transformer_protocols import ( + BatchInferenceKwargsTypedDict, + ScoreKwargsTypedDict, ) from snowflake.ml.modeling.framework.base import BaseTransformer -from snowflake.snowpark import DataFrame +from snowflake.snowpark import DataFrame, Session from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type _PROJECT = "ModelDevelopment" @@ -40,6 +43,8 @@ _SUBPROJECT = "ModelSelection" DEFAULT_UDTF_NJOBS = 3 +DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame] + class RandomizedSearchCV(BaseTransformer): r"""Randomized search on hyper parameters @@ -277,10 +282,12 @@ def __init__( # type: ignore[no-untyped-def] self.set_drop_input_cols(drop_input_cols) self.set_sample_weight_col(sample_weight_col) self.set_passthrough_cols(passthrough_cols) - self._handlers: TransformerHandlers = HandlersImpl( - class_name=self.__class__.__name__, - subproject=_SUBPROJECT, - ) + + self._autogenerated = False + self._snowpark_cols = self.input_cols + self._autogenerated = False + self._class_name = RandomizedSearchCV.__class__.__name__ + self._subproject = _SUBPROJECT def _get_active_columns(self) -> List[str]: """ "Get the list of columns that are relevant to the transformer.""" @@ -344,14 +351,8 @@ def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]: else: return 
list(set(dataset.columns) - set(self.output_cols)) - def _batch_inference( - self, - dataset: DataFrame, - inference_method: str, - expected_output_cols_list: List[str], - expected_output_cols_type: str = "", - ) -> DataFrame: - """Util method to create UDF and run batch inference.""" + def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> None: + """Util method to run validate that batch inference can be run on a snowpark dataframe.""" if not self._is_fitted: raise exceptions.SnowflakeMLException( error_code=error_codes.METHOD_NOT_ALLOWED, @@ -371,120 +372,6 @@ def _batch_inference( pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT ) - return self._handlers.batch_inference( - dataset, - session, - self._sklearn_object, - self._get_dependencies(), - inference_method, - self.input_cols, - self._get_pass_through_columns(dataset), - expected_output_cols_list, - expected_output_cols_type, - ) - - def _sklearn_inference( - self, dataset: pd.DataFrame, inference_method: str, expected_output_cols_list: List[str] - ) -> pd.DataFrame: - output_cols = expected_output_cols_list.copy() - - # Model expects exact same columns names in the input df for predict call. - # Given the scenario that user use snowpark DataFrame in fit call, but pandas DataFrame in predict call - # input cols need to match unquoted / quoted - input_cols = self.input_cols - unquoted_input_cols = identifier.get_unescaped_names(self.input_cols) - quoted_input_cols = identifier.get_inferred_names(unquoted_input_cols) - - estimator = self._sklearn_object - - assert estimator is not None - features_required_by_estimator = ( - estimator.feature_names_in_ if hasattr(estimator, "feature_names_in_") else unquoted_input_cols - ) - missing_features = [] - features_in_dataset = set(dataset.columns) - columns_to_select = [] - for i, f in enumerate(features_required_by_estimator): - if ( - i >= len(input_cols) - or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f) - or ( - input_cols[i] not in features_in_dataset - and unquoted_input_cols[i] not in features_in_dataset - and quoted_input_cols[i] not in features_in_dataset - ) - ): - missing_features.append(f) - elif input_cols[i] in features_in_dataset: - columns_to_select.append(input_cols[i]) - elif unquoted_input_cols[i] in features_in_dataset: - columns_to_select.append(unquoted_input_cols[i]) - else: - columns_to_select.append(quoted_input_cols[i]) - - if len(missing_features) > 0: - raise exceptions.SnowflakeMLException( - error_code=error_codes.NOT_FOUND, - original_exception=ValueError( - "The feature names should match with those that were passed during fit.\n" - f"Features seen during fit call but not present in the input: {missing_features}\n" - f"Features in the input dataframe : {input_cols}\n" - ), - ) - input_df = dataset[columns_to_select] - input_df.columns = features_required_by_estimator - - transformed_numpy_array = getattr(estimator, inference_method)(input_df) - - if ( - isinstance(transformed_numpy_array, list) - and len(transformed_numpy_array) > 0 - and isinstance(transformed_numpy_array[0], np.ndarray) - ): - # In case of multioutput estimators, predict_proba(), decision_function(), etc., functions return - # a list of ndarrays. We need to concatenate them. 
- - # First compute output column names - if len(output_cols) == len(transformed_numpy_array): - actual_output_cols = [] - for idx, np_arr in enumerate(transformed_numpy_array): - for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]): - actual_output_cols.append(f"{output_cols[idx]}_{i}") - output_cols = actual_output_cols - - # Concatenate np arrays - transformed_numpy_array = np.concatenate(transformed_numpy_array, axis=1) - - if len(transformed_numpy_array.shape) == 3: - # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes) - # when voting = "soft" and flatten_transform = False. We can't handle unflatten transforms, - # so we ignore flatten_transform flag and flatten the results. - transformed_numpy_array = np.hstack(transformed_numpy_array) - - if len(transformed_numpy_array.shape) == 1: - transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1)) - - shape = transformed_numpy_array.shape - if shape[1] != len(output_cols): - if len(output_cols) != 1: - raise exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=TypeError( - "expected_output_cols_list must be same length as transformed array or " "should be of length 1" - ), - ) - actual_output_cols = [] - for i in range(shape[1]): - actual_output_cols.append(f"{output_cols[0]}_{i}") - output_cols = actual_output_cols - - if self._drop_input_cols: - dataset = pd.DataFrame(data=transformed_numpy_array, columns=output_cols) - else: - dataset = dataset.copy() - dataset[output_cols] = transformed_numpy_array - return dataset - @available_if(original_estimator_has_callable("predict")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( project=_PROJECT, @@ -503,6 +390,12 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p Transformed dataset. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + inference_method = "predict" + if isinstance(dataset, DataFrame): expected_type_inferred = "" # infer the datatype from label columns @@ -510,19 +403,37 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p expected_type_inferred = convert_sp_to_sf_type( self.model_signatures["predict"].outputs[0].as_snowpark_type() ) - - output_df = self._batch_inference( + self._batch_inference_validate_snowpark( dataset=dataset, - inference_method="predict", - expected_output_cols_list=self.output_cols, + inference_method=inference_method, + ) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type=expected_type_inferred, ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict", - expected_output_cols_list=self.output_cols, - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self.output_cols, + **transform_kwargs, + ) return output_df @@ -544,19 +455,40 @@ def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, Transformed dataset. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. + transform_kwargs: BatchInferenceKwargsTypedDict = dict() + inference_method = "transform" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="transform", - expected_output_cols_list=self.output_cols, + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="transform", - expected_output_cols_list=self.output_cols, - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self.output_cols, + **transform_kwargs, + ) return output_df def _get_output_column_names(self, output_cols_prefix: str) -> List[str]: @@ -610,20 +542,42 @@ def predict_proba( Output dataset with probability of the sample for each class in the model. 
""" super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. + transform_kwargs: BatchInferenceKwargsTypedDict = dict() + + inference_method = "predict_proba" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="predict_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("predict_log_proba")) # type: ignore[misc] @@ -648,20 +602,42 @@ def predict_log_proba( Output dataset with log probability of the sample for each class in the model. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + + inference_method = "predict_log_proba" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="predict_log_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="predict_log_proba", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("decision_function")) # type: ignore[misc] @@ -686,20 +662,41 @@ def decision_function( Output dataset with results of the decision function for the samples in input dataset. """ super()._check_dataset_type(dataset) + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
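# _get_output_column_names(output_cols_prefix) is elided from this hunk. A
# hedged sketch of the likely behavior for classifiers, where predict_proba and
# predict_log_proba emit one prefixed column per fitted class (inferred from
# how the expected output columns are used, not the verbatim implementation):
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
    classes = getattr(self._sklearn_object, "classes_", None)
    if classes is None:
        # No fitted class labels available: fall back to a single prefixed column.
        return [output_cols_prefix]
    return [f"{output_cols_prefix}{c}" for c in classes]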
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + inference_method = "decision_function" + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="decision_function", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="decision_function", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("score_samples")) # type: ignore[misc] @@ -725,20 +722,42 @@ def score_samples( Output dataset with results of the decision function for the samples in input dataset. """ super()._check_dataset_type(dataset) + + inference_method = "score_samples" + + # This dictionary contains optional kwargs for batch inference. These kwargs + # are specific to the type of dataset used. 
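# _get_pass_through_columns(dataset) is also not shown in this hunk. Judging
# from the updated integration test later in this patch, which passes
# pass_through_cols=list(set(input_df.columns) - set(output_cols)), a sketch:
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
    # Columns to carry through batch inference untouched: everything in the
    # input that is not about to be replaced by an output column.
    if self._drop_input_cols:
        return []
    return [c for c in dataset.columns if c not in self.output_cols]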
+ transform_kwargs: BatchInferenceKwargsTypedDict = dict() + if isinstance(dataset, DataFrame): - output_df = self._batch_inference( - dataset=dataset, - inference_method="score_samples", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + assert isinstance( + dataset._session, Session + ) # mypy does not recognize the check in _batch_inference_validate_snowpark() + transform_kwargs = dict( + session=dataset._session, + dependencies=self._get_dependencies(), + pass_through_cols=self._get_pass_through_columns(dataset), expected_output_cols_type="float", ) + elif isinstance(dataset, pd.DataFrame): - output_df = self._sklearn_inference( - dataset=dataset, - inference_method="score_samples", - expected_output_cols_list=self._get_output_column_names(output_cols_prefix), - ) + transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols) + transform_handlers = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + + output_df: DATAFRAME_TYPE = transform_handlers.batch_inference( + inference_method=inference_method, + input_cols=self.input_cols, + expected_output_cols=self._get_output_column_names(output_cols_prefix), + **transform_kwargs, + ) return output_df @available_if(original_estimator_has_callable("score")) # type: ignore[misc] @@ -755,35 +774,42 @@ def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: """ self._infer_input_output_cols(dataset) super()._check_dataset_type(dataset) - if isinstance(dataset, pd.DataFrame): - output_score = self._handlers.score_pandas( - dataset, self._sklearn_object, self.input_cols, self.label_cols, self.sample_weight_col + + # This dictionary contains optional kwargs for scoring. These kwargs + # are specific to the type of dataset used. + transform_kwargs: ScoreKwargsTypedDict = dict() + + if isinstance(dataset, DataFrame): + selected_cols = self._get_active_columns() + if len(selected_cols) > 0: + dataset = dataset.select(selected_cols) + + assert isinstance(dataset._session, Session) # keep mypy happy + transform_kwargs = dict( + session=dataset._session, + dependencies=["snowflake-snowpark-python"] + self._get_dependencies(), + score_sproc_imports=["sklearn"], ) - elif isinstance(dataset, DataFrame): - output_score = self._score_snowpark(dataset) - return output_score + elif isinstance(dataset, pd.DataFrame): + # pandas_handler.score() does not require any extra kwargs. 
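# ModelTransformerBuilder.build dispatches on the dataset type. A condensed
# sketch of that dispatch: the pandas-side handler class name is assumed (the
# comment above refers to it as pandas_handler), while the Snowpark constructor
# arguments match the updated integration test later in this patch:
class ModelTransformerBuilder:
    @classmethod
    def build(cls, dataset, estimator, class_name, subproject, autogenerated=False):
        if isinstance(dataset, pd.DataFrame):
            # Local, in-memory inference and scoring path.
            return PandasTransformHandlers(
                dataset=dataset, estimator=estimator, class_name=class_name, subproject=subproject
            )
        elif isinstance(dataset, DataFrame):
            # Pushdown path: inference and scoring run inside Snowflake.
            return SnowparkTransformHandlers(
                dataset=dataset, estimator=estimator, class_name=class_name, subproject=subproject
            )
        raise TypeError(f"Unexpected dataset type: {type(dataset)}")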
+            transform_kwargs = dict()

-    def _score_snowpark(self, dataset: DataFrame) -> float:
-        # Specify input columns so column pruning will be enforced
-        selected_cols = self._get_active_columns()
-        if len(selected_cols) > 0:
-            dataset = dataset.select(selected_cols)
+        transform_handlers = ModelTransformerBuilder.build(
+            dataset=dataset,
+            estimator=self._sklearn_object,
+            class_name=self._class_name,
+            subproject=self._subproject,
+            autogenerated=self._autogenerated,
+        )

-        session = dataset._session
-        assert session is not None  # keep mypy happy
-
-        score = self._handlers.score_snowpark(
-            dataset,
-            session,
-            self._sklearn_object,
-            ["snowflake-snowpark-python"] + self._get_dependencies(),
-            ["sklearn"],
-            identifier.get_unescaped_names(self.input_cols),
-            identifier.get_unescaped_names(self.label_cols),
-            identifier.get_unescaped_names(self.sample_weight_col),
+        output_score = transform_handlers.score(
+            input_cols=identifier.get_unescaped_names(self.input_cols),
+            label_cols=identifier.get_unescaped_names(self.label_cols),
+            sample_weight_col=identifier.get_unescaped_names(self.sample_weight_col),
+            **transform_kwargs,
         )
-        return score
+        return output_score

     def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py
index f19de18c..23d7afb1 100644
--- a/snowflake/ml/registry/model_registry.py
+++ b/snowflake/ml/registry/model_registry.py
@@ -1471,6 +1471,18 @@ def load_model(self, model_name: str, model_version: str) -> Any:
         Returns:
             Restored model object.
         """
+        warnings.warn(
+            (
+                "Please use with caution: "
+                "Using the `load_model` method requires you to have the EXACT same Python environment "
+                "as the one in which you logged the model. Any differences can potentially lead to errors.\n"
+                "Also, if your model contains custom code imported via the `code_paths` argument when logging, "
+                "it will be added to your `sys.path`, which can lead to unexpected module import issues. "
+                "If you run into such problems, you will need to restart your Python or notebook kernel."
+            ),
+            category=UserWarning,
+            stacklevel=2,
+        )
         remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version)
         restored_model = None
diff --git a/snowflake/ml/registry/registry.py b/snowflake/ml/registry/registry.py
index 22387889..856090c6 100644
--- a/snowflake/ml/registry/registry.py
+++ b/snowflake/ml/registry/registry.py
@@ -123,7 +123,7 @@ def log_model(
                 Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda Channel.
                 Otherwise, defaults to False
             - relax_version: Whether or not to relax the version constraints of the dependencies.
-                It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False.
+                It detects any ==x.y.z in specifiers and replaces it with >=x.y, <(x+1). Defaults to True.
             - method_options: Per-method saving options including:
                 - case_sensitive: Indicates whether the method and its signature should be case sensitive.
                     This means when you refer to the method in SQL, you need to double-quote it.
diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl
index 6e677b35..0cda24de 100644
--- a/snowflake/ml/version.bzl
+++ b/snowflake/ml/version.bzl
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
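# A worked example of the relax_version rule documented in log_model above:
# ==x.y.z pins are loosened to >=x.y, <(x+1). Illustrative helper only, not the
# library's implementation (which handles full PEP 508 specifiers):
def relax_pin(requirement: str) -> str:
    name, _, version = requirement.partition("==")
    major, minor = version.split(".")[:2]
    return f"{name}>={major}.{minor},<{int(major) + 1}"

assert relax_pin("xgboost==1.7.3") == "xgboost>=1.7,<2"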
-VERSION = "1.2.3" +VERSION = "1.3.0" diff --git a/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py b/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py index 377f94d8..0ad2758f 100644 --- a/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py +++ b/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py @@ -8,7 +8,7 @@ from sklearn.linear_model import LinearRegression as SkLinearRegression from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import ( - SnowparkHandlers, + SnowparkTransformHandlers, ) from tests.integ.snowflake.ml.test_utils import common_test_base @@ -17,7 +17,17 @@ class SnowparkHandlersTest(common_test_base.CommonTestBase): def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" super().setUp() - self._handlers = SnowparkHandlers(class_name="test", subproject="subproject") + sklearn_estimator = SkLinearRegression() + + self.input_df_pandas, self.input_cols, self.label_cols = self._get_test_dataset() + self.fit_estimator = sklearn_estimator.fit( + X=self.input_df_pandas[self.input_cols], y=self.input_df_pandas[self.label_cols].squeeze() + ) + self.input_df = self.session.create_dataframe(self.input_df_pandas) + + self._handlers = SnowparkTransformHandlers( + dataset=self.input_df, estimator=self.fit_estimator, class_name="test", subproject="subproject" + ) def _get_test_dataset(self) -> Tuple[pd.DataFrame, List[str], List[str]]: """Constructs input dataset to be used in the integration test. @@ -52,51 +62,39 @@ def _get_test_dataset(self) -> Tuple[pd.DataFrame, List[str], List[str]]: @common_test_base.CommonTestBase.sproc_test(additional_packages=["inflection"]) def test_batch_inference(self) -> None: - sklearn_estimator = SkLinearRegression() - input_df_pandas, input_cols, label_cols = self._get_test_dataset() - input_df = self.session.create_dataframe(input_df_pandas) - - fit_estimator = sklearn_estimator.fit(X=input_df_pandas[input_cols], y=input_df_pandas[label_cols].squeeze()) - output_cols = ["OUTPUT_" + c for c in label_cols] + output_cols = ["OUTPUT_" + c for c in self.label_cols] predictions = self._handlers.batch_inference( - dataset=input_df, session=self.session, - estimator=fit_estimator, dependencies=["snowflake-snowpark-python", "numpy", "scikit-learn", "cloudpickle"], inference_method="predict", - input_cols=input_cols, - pass_through_columns=list(set(input_df.columns) - set(output_cols)), - expected_output_cols_list=output_cols, + input_cols=self.input_cols, + pass_through_cols=list(set(self.input_df.columns) - set(output_cols)), + expected_output_cols=output_cols, expected_output_cols_type="INT", ) - sklearn_numpy_arr = fit_estimator.predict(input_df_pandas[input_cols]) + sklearn_numpy_arr = self.fit_estimator.predict(self.input_df_pandas[self.input_cols]) sf_numpy_arr = predictions.to_pandas().sort_values(by="INDEX")[output_cols].to_numpy().flatten() np.testing.assert_allclose(sklearn_numpy_arr, sf_numpy_arr, rtol=1.0e-1, atol=1.0e-2) @common_test_base.CommonTestBase.sproc_test(additional_packages=["inflection"]) def test_score_snowpark(self) -> None: - sklearn_estimator = SkLinearRegression() - input_df_pandas, input_cols, label_cols = self._get_test_dataset() - input_df = self.session.create_dataframe(input_df_pandas) - fit_estimator = sklearn_estimator.fit(X=input_df_pandas[input_cols], y=input_df_pandas[label_cols].squeeze()) - - score = self._handlers.score_snowpark( - dataset=input_df, + score = self._handlers.score( session=self.session, - 
estimator=fit_estimator, dependencies=["snowflake-snowpark-python", "numpy", "scikit-learn", "cloudpickle"], score_sproc_imports=["sklearn"], - input_cols=input_cols, - label_cols=label_cols, + input_cols=self.input_cols, + label_cols=self.label_cols, sample_weight_col=None, ) - sklearn_score = fit_estimator.score(input_df_pandas[input_cols], input_df_pandas[label_cols].squeeze()) + sklearn_score = self.fit_estimator.score( + self.input_df_pandas[self.input_cols], self.input_df_pandas[self.label_cols].squeeze() + ) np.testing.assert_allclose(score, sklearn_score) diff --git a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel index 9f5058c9..1f73826c 100644 --- a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel +++ b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel @@ -127,30 +127,6 @@ py_test( ], ) -py_test( - name = "fit_predict_test", - srcs = ["fit_predict_test.py"], - shard_count = 3, - deps = [ - "//snowflake/ml/modeling/cluster:agglomerative_clustering", - "//snowflake/ml/modeling/cluster:dbscan", - "//snowflake/ml/modeling/cluster:optics", - "//snowflake/ml/utils:connection_params", - ], -) - -py_test( - name = "fit_transform_test", - srcs = ["fit_transform_test.py"], - shard_count = 3, - deps = [ - "//snowflake/ml/modeling/manifold:mds", - "//snowflake/ml/modeling/manifold:spectral_embedding", - "//snowflake/ml/modeling/manifold:tsne", - "//snowflake/ml/utils:connection_params", - ], -) - py_test( name = "decimal_type_test", srcs = ["decimal_type_test.py"], @@ -169,3 +145,15 @@ py_test( "//snowflake/ml/utils:connection_params", ], ) + +py_test( + name = "fit_transform_test", + srcs = ["fit_transform_test.py"], + shard_count = 3, + deps = [ + "//snowflake/ml/modeling/manifold:mds", + "//snowflake/ml/modeling/manifold:spectral_embedding", + "//snowflake/ml/modeling/manifold:tsne", + "//snowflake/ml/utils:connection_params", + ], +) diff --git a/tests/integ/snowflake/ml/extra_tests/fit_predict_test.py b/tests/integ/snowflake/ml/extra_tests/fit_predict_test.py deleted file mode 100644 index 1d828a1f..00000000 --- a/tests/integ/snowflake/ml/extra_tests/fit_predict_test.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np -import pandas as pd -from absl.testing.absltest import TestCase, main -from sklearn.cluster import ( - DBSCAN as SKDBSCAN, - OPTICS as SKOPTICS, - AgglomerativeClustering as SKAgglomerativeClustering, -) - -from snowflake.ml.modeling.cluster import DBSCAN, OPTICS, AgglomerativeClustering -from snowflake.ml.utils.connection_params import SnowflakeLoginOptions -from snowflake.snowpark import Session - - -class FitPredictTest(TestCase): - def setUp(self): - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(SnowflakeLoginOptions()).create() - - def tearDown(self): - self._session.close() - - def test_aggolomerative(self): - sample_data = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]]) - pd_df = pd.DataFrame(sample_data) - pd_df.columns = [str(c) for c in pd_df.columns] - sp_df = self._session.create_dataframe(pd_df) - agg = AgglomerativeClustering(input_cols=sp_df.columns) - sk_agg = SKAgglomerativeClustering() - - return_label = agg.fit_predict(sp_df) - sk_label = sk_agg.fit_predict(sample_data) - - np.testing.assert_allclose(return_label, sk_label, rtol=1.0e-1, atol=1.0e-2) - - def test_dbscan(self): - sample_data = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]]) - pd_df = pd.DataFrame(sample_data) - pd_df.columns = [str(c) for c in 
pd_df.columns] - sp_df = self._session.create_dataframe(pd_df) - dbs = DBSCAN(input_cols=sp_df.columns, eps=3, min_samples=2) - sk_dbs = SKDBSCAN(eps=3, min_samples=2) - - return_label = dbs.fit_predict(sp_df) - sk_label = sk_dbs.fit_predict(sample_data) - - np.testing.assert_allclose(return_label, sk_label, rtol=1.0e-1, atol=1.0e-2) - - def test_optics(self): - sample_data = np.array([[1, 2], [2, 5], [3, 6], [8, 7], [8, 8], [7, 3]]) - pd_df = pd.DataFrame(sample_data) - pd_df.columns = [str(c) for c in pd_df.columns] - sp_df = self._session.create_dataframe(pd_df) - opt = OPTICS(input_cols=sp_df.columns, min_samples=2) - sk_opt = SKOPTICS(min_samples=2) - - return_label = opt.fit_predict(sp_df) - sk_label = sk_opt.fit_predict(sample_data) - - np.testing.assert_allclose(return_label, sk_label, rtol=1.0e-1, atol=1.0e-2) - - -if __name__ == "__main__": - main() diff --git a/tests/integ/snowflake/ml/feature_store/BUILD.bazel b/tests/integ/snowflake/ml/feature_store/BUILD.bazel index 7e62c6b0..d8a15d5a 100644 --- a/tests/integ/snowflake/ml/feature_store/BUILD.bazel +++ b/tests/integ/snowflake/ml/feature_store/BUILD.bazel @@ -70,3 +70,17 @@ py_test( "//snowflake/ml/utils:connection_params", ], ) + +py_test( + name = "feature_store_access_test", + srcs = [ + "access_utils.py", + "feature_store_access_test.py", + ], + shard_count = 16, + deps = [ + ":common_utils", + "//snowflake/ml/feature_store:feature_store_lib", + "//snowflake/ml/utils:connection_params", + ], +) diff --git a/tests/integ/snowflake/ml/feature_store/access_utils.py b/tests/integ/snowflake/ml/feature_store/access_utils.py new file mode 100644 index 00000000..bae00f71 --- /dev/null +++ b/tests/integ/snowflake/ml/feature_store/access_utils.py @@ -0,0 +1,110 @@ +from enum import Enum +from typing import Dict, List + +from snowflake.ml.feature_store.feature_store import ( + _FEATURE_STORE_OBJECT_TAG, + _FEATURE_VIEW_ENTITY_TAG, + _FEATURE_VIEW_TS_COL_TAG, + FeatureStore, +) +from snowflake.snowpark import Session, exceptions + + +class FeatureStoreRole(Enum): + NONE = 0 + CONSUMER = 1 + PRODUCER = 2 + ADMIN = 9 + + +# Lists of permissions as tuples of (OBJECT_TYPE, [PRIVILEGES, ...]) +_PRIVILEGE_LEVELS: Dict[FeatureStoreRole, Dict[str, List[str]]] = { + FeatureStoreRole.ADMIN: { + "database {database}": ["CREATE SCHEMA"], + f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_ENTITY_TAG}": ["OWNERSHIP"], + f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_TS_COL_TAG}": ["OWNERSHIP"], + f"tag {{database}}.{{schema}}.{_FEATURE_STORE_OBJECT_TAG}": ["OWNERSHIP"], + "schema {database}.{schema}": ["OWNERSHIP"], + }, + FeatureStoreRole.PRODUCER: { + "schema {database}.{schema}": [ + "CREATE DYNAMIC TABLE", + "CREATE TABLE", + "CREATE TAG", + "CREATE VIEW", + ], + f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_ENTITY_TAG}": ["APPLY"], + f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_TS_COL_TAG}": ["APPLY"], + f"tag {{database}}.{{schema}}.{_FEATURE_STORE_OBJECT_TAG}": ["APPLY"], + # TODO: The below privileges should be granted on a per-resource level + # between producers (e.g. 
PRODUCER_A grants PRODUCER_B operate access + # to FEATURE_VIEW_0, but not FEATURE_VIEW_1) + "future tables in schema {database}.{schema}": ["INSERT"], + "all tables in schema {database}.{schema}": ["INSERT"], + "future dynamic tables in schema {database}.{schema}": ["OPERATE"], + "all dynamic tables in schema {database}.{schema}": ["OPERATE"], + "future tasks in schema {database}.{schema}": ["OPERATE"], + "all tasks in schema {database}.{schema}": ["OPERATE"], + }, + FeatureStoreRole.CONSUMER: { + # "warehouse {warehouse}": ["USAGE"], + "database {database}": ["USAGE"], + "schema {database}.{schema}": ["USAGE"], + "future dynamic tables in schema {database}.{schema}": [ + "SELECT", + "MONITOR", + ], + "all dynamic tables in schema {database}.{schema}": [ + "SELECT", + "MONITOR", + ], + "future views in schema {database}.{schema}": [ + "SELECT", + "REFERENCES", + ], + "all views in schema {database}.{schema}": [ + "SELECT", + "REFERENCES", + ], + }, + FeatureStoreRole.NONE: {}, +} + + +def configure_roles( + feature_store: FeatureStore, + admin_role_name: str = "FS_ADMIN", + producer_role_name: str = "FS_PRODUCER", + consumer_role_name: str = "FS_CONSUMER", +) -> None: + session = feature_store._session + session_info = { + "account": session.get_current_account(), + "database": feature_store._config.database, + "schema": feature_store._config.schema, + "warehouse": session.get_current_warehouse(), + } + + def _grant_privileges(session: Session, role_name: str, access_level: FeatureStoreRole) -> None: + for scope, privilege_list in _PRIVILEGE_LEVELS[access_level].items(): + session.sql( + f"grant {','.join(privilege_list)} on {scope.format(**session_info)} to role {role_name}" + ).collect() + + # Try ensuring roles exist. If fail (no CREATE ROLE privilege), just continue + try: + session.sql(f"create role if not exists {admin_role_name}").collect() + session.sql(f"create role if not exists {producer_role_name}").collect() + session.sql(f"create role if not exists {consumer_role_name}").collect() + except exceptions.SnowparkSQLException: + pass + + # Grant privileges to roles + _grant_privileges(session, admin_role_name, FeatureStoreRole.ADMIN) + _grant_privileges(session, producer_role_name, FeatureStoreRole.PRODUCER) + _grant_privileges(session, consumer_role_name, FeatureStoreRole.CONSUMER) + + # Build role hierarchy + # session.sql(f"grant role {consumer_role_name} to role {producer_role_name}").collect() + # session.sql(f"grant role {producer_role_name} to role {admin_role_name}").collect() + # session.sql(f"grant role {admin_role_name} to role {session.get_current_role()}").collect() diff --git a/tests/integ/snowflake/ml/feature_store/common_utils.py b/tests/integ/snowflake/ml/feature_store/common_utils.py index fae6dde9..cd6fb491 100644 --- a/tests/integ/snowflake/ml/feature_store/common_utils.py +++ b/tests/integ/snowflake/ml/feature_store/common_utils.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from unittest.mock import Mock from uuid import uuid4 @@ -34,9 +34,11 @@ DB_OBJECT_EXPIRE_HOURS = 24 -def create_random_schema(session: Session, prefix: str, database: str = FS_INTEG_TEST_DB) -> str: +def create_random_schema( + session: Session, prefix: str, database: str = FS_INTEG_TEST_DB, additional_options: str = "" +) -> str: schema = prefix + "_" + uuid4().hex.upper() - session.sql(f"CREATE SCHEMA IF NOT EXISTS {database}.{schema}").collect() + session.sql(f"CREATE 
SCHEMA IF NOT EXISTS {database}.{schema} {additional_options}").collect() return schema @@ -69,6 +71,30 @@ def dispatch(*args: Any) -> Any: return session +def create_mock_table( + session: Session, database: Optional[str] = None, schema: Optional[str] = None, table_prefix: str = "TEST_TABLE" +) -> str: + test_table = f"{table_prefix}_{uuid4().hex.upper()}" + if schema: + test_table = schema + "." + test_table + if database: + assert bool(schema) + test_table = database + "." + test_table + session.sql( + f"""CREATE TABLE IF NOT EXISTS {test_table} + (name VARCHAR(64), id INT, title VARCHAR(128), age INT, dept VARCHAR(64), ts INT) + """ + ).collect() + session.sql( + f"""INSERT OVERWRITE INTO {test_table} (name, id, title, age, dept, ts) + VALUES + ('john', 1, 'boss', 20, 'sales', 100), + ('porter', 2, 'manager', 30, 'engineer', 200) + """ + ).collect() + return test_table + + def get_test_warehouse_name(session: Session) -> str: session_warehouse = session.get_current_warehouse() return session_warehouse if session_warehouse else "REGTEST_ML_4XL_MULTI" diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_access_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_access_test.py new file mode 100644 index 00000000..34c5e2e2 --- /dev/null +++ b/tests/integ/snowflake/ml/feature_store/feature_store_access_test.py @@ -0,0 +1,434 @@ +from inspect import isclass +from typing import Any, Callable, Dict, Optional, Type, Union +from uuid import uuid4 + +from absl.testing import absltest, parameterized +from access_utils import FeatureStoreRole as Role, configure_roles +from common_utils import ( + FS_INTEG_TEST_DB, + cleanup_temporary_objects, + create_mock_table, + create_random_schema, + get_test_warehouse_name, +) + +from snowflake.ml.feature_store.entity import Entity +from snowflake.ml.feature_store.feature_store import CreationMode, FeatureStore +from snowflake.ml.feature_store.feature_view import FeatureView, FeatureViewStatus +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +from snowflake.snowpark import Session, exceptions as snowpark_exceptions + +_TEST_ROLE_ADMIN = "FS_ROLE_ADMIN" +_TEST_ROLE_PRODUCER = "FS_ROLE_PRODUCER" +_TEST_ROLE_CONSUMER = "FS_ROLE_CONSUMER" +_TEST_ROLE_NONE = "FS_ROLE_NONE" + + +class FeatureStoreAccessTest(parameterized.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls._session = Session.builder.configs(SnowflakeLoginOptions()).create() + cleanup_temporary_objects(cls._session) + cls._test_roles = { + Role.ADMIN: _TEST_ROLE_ADMIN, + Role.PRODUCER: _TEST_ROLE_PRODUCER, + Role.CONSUMER: _TEST_ROLE_CONSUMER, + Role.NONE: _TEST_ROLE_NONE, + } + cls._test_warehouse = get_test_warehouse_name(cls._session) + cls._session.use_warehouse(cls._test_warehouse) + cls._test_database = FS_INTEG_TEST_DB + cls._test_admin = cls._session.get_current_role() + + try: + cls._test_schema = create_random_schema( + cls._session, "FS_TEST", database=cls._test_database, additional_options="WITH MANAGED ACCESS" + ) + cls._feature_store = FeatureStore( + cls._session, + cls._test_database, + cls._test_schema, + cls._test_warehouse, + creation_mode=CreationMode.CREATE_IF_NOT_EXIST, + ) + + configure_roles( + cls._feature_store, + admin_role_name=cls._test_roles[Role.ADMIN], + producer_role_name=cls._test_roles[Role.PRODUCER], + consumer_role_name=cls._test_roles[Role.CONSUMER], + ) + + cls._mock_table = cls._init_test_data() + for role_id in cls._test_roles.values(): + # Grant read access to mock source data table + 
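# For example, with database FS_DB and schema FS_SCHEMA (hypothetical names),
# the CONSUMER entry '"schema {database}.{schema}": ["USAGE"]' in
# _PRIVILEGE_LEVELS expands via _grant_privileges to:
#     grant USAGE on schema FS_DB.FS_SCHEMA to role FS_CONSUMER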
cls._session.sql(f"GRANT SELECT ON TABLE {cls._mock_table} to role {role_id}").collect() + + except Exception as e: + cls.tearDownClass() + raise Exception(f"Test setup failed: {e}") + + @classmethod + def tearDownClass(cls) -> None: + cls._session.use_role(cls._test_admin) + cls._session.sql(f"DROP SCHEMA IF EXISTS {cls._test_database}.{cls._test_schema}").collect() + cls._session.close() + + def setUp(self) -> None: + self._session.use_role(self._test_admin) + + @classmethod + def _init_test_data(cls) -> str: + prev_role = cls._session.get_current_role() + try: + cls._session.use_role(cls._test_roles[Role.ADMIN]) + test_table: str = create_mock_table(cls._session, cls._test_database, cls._test_schema) + + # Create Entities + e = Entity("foo", ["id"]) + cls._feature_store.register_entity(e) + + fv1 = FeatureView( + name="fv1", + entities=[e], + feature_df=cls._session.sql(f"SELECT id, name, ts FROM {test_table}"), + timestamp_col="ts", + refresh_freq="DOWNSTREAM", + ) + fv1 = cls._feature_store.register_feature_view(feature_view=fv1, version="v1", block=True) + + fv2 = FeatureView( + name="fv2", + entities=[e], + feature_df=cls._session.sql(f"SELECT id, title, ts FROM {test_table}"), + timestamp_col="ts", + refresh_freq="DOWNSTREAM", + ) + fv2 = cls._feature_store.register_feature_view(feature_view=fv2, version="v1", block=True) + + return test_table + + finally: + cls._session.use_role(prev_role) + + def _test_access( + self, + method: Callable[[], Any], + required_access: Role, + test_access: Role, + expected_result: Optional[Union[Type[Exception], Callable[[Any], Optional[bool]], Any]] = None, + expected_access_exception: Type[Exception] = RuntimeError, + access_exception_dict: Optional[Dict[Role, Type[Exception]]] = None, + ) -> Any: + """ + Test a Feature Store API given a specified access level. + + Args: + method: Parameterless callable wrapping method under test + required_access: Expected minimum access needed to execute method + test_access: Access level to execute test under + expected_result: Expected outcome of method call with sufficient access. + May be an exception if method is expected to throw, a constant value, + or a callback with custom assertions and/or returns True on acceptance. + expected_access_exception: Expected exception on insufficient access. + access_exception_dict: Level-specific expected exceptions. Takes precedence + over expected_access_exception for matching access levels. + """ + prev_role = self._session.get_current_role() + try: + self._session.use_role(self._test_roles[test_access]) + if test_access.value < required_access.value: + # Access level specific exception types + if isinstance(access_exception_dict, dict) and test_access in access_exception_dict: + expected_access_exception = access_exception_dict[test_access] + + # TODO: Error pattern + with self.assertRaises(expected_access_exception): + return method() + elif isclass(expected_result) and issubclass(expected_result, Exception): + # TODO: Error pattern + with self.assertRaises(expected_result): + return method() + else: + result = method() + if expected_result is not None: + if callable(expected_result): + # TODO: Use original (admin) role to execute validator? 
+ validate_result = expected_result(result) + self.assertTrue(validate_result is None or validate_result is True) + else: + self.assertEqual(expected_result, result) + return result + finally: + self._session.use_role(prev_role) + + @parameterized.product( + [ + { + "init_args": {"creation_mode": CreationMode.CREATE_IF_NOT_EXIST}, + "required_access": Role.ADMIN, + "expected_result": None, + }, + { + "init_args": {"creation_mode": CreationMode.FAIL_IF_NOT_EXIST}, + "required_access": Role.CONSUMER, + "expected_result": ValueError, + }, + ], + test_access=list(Role), + ) # type: ignore[misc] + def test_init( + self, + init_args: Dict[str, Any], + required_access: Role, + test_access: Role, + expected_result: Optional[Type[Exception]], + ) -> None: + schema_name = f"FS_TEST_{uuid4().hex.upper()}" + + def unit_under_test() -> FeatureStore: + return FeatureStore( + self._session, + self._test_database, + schema_name, + self._test_warehouse, + **init_args, + ) + + try: + self._test_access( + unit_under_test, + required_access, + test_access, + expected_result=expected_result, + access_exception_dict={Role.NONE: ValueError}, + ) + finally: + self._session.sql(f"DROP SCHEMA IF EXISTS {self._test_database}.{schema_name}").collect() + + @parameterized.product(required_access=[Role.ADMIN], test_access=list(Role)) # type: ignore[misc] + def test_clear(self, required_access: Role, test_access: Role) -> None: + # Create isolated Feature Store to test clearing + schema_admin = self._session.get_current_role() + schema = create_random_schema( + self._session, "FS_TEST", database=self._test_database, additional_options="WITH MANAGED ACCESS" + ) + try: + fs = FeatureStore( + self._session, + self._test_database, + schema, + self._test_warehouse, + creation_mode=CreationMode.CREATE_IF_NOT_EXIST, + ) + configure_roles( + fs, + admin_role_name=self._test_roles[Role.ADMIN], + producer_role_name=self._test_roles[Role.PRODUCER], + consumer_role_name=self._test_roles[Role.CONSUMER], + ) + + self._session.use_role(self._test_roles[Role.ADMIN]) + e = Entity(f"test_entity_{uuid4().hex.upper()}"[:32], ["test_key"]) + fs.register_entity(e) + + entity_count = len(fs.list_entities().collect()) + self.assertGreater(entity_count, 0) + self._test_access( + fs.clear, + required_access, + test_access, + ) + + # Do validation on FileSet contents outside _test_access since we need admin access + expected_entity_count = entity_count if test_access.value < Role.ADMIN.value else 0 + self.assertEqual(len(fs.list_entities().collect()), expected_entity_count) + finally: + self._session.use_role(schema_admin) + self._session.sql(f"DROP SCHEMA IF EXISTS {self._test_database}.{schema}").collect() + + @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] + def test_merge_features(self, required_access: Role, test_access: Role) -> None: + fv1 = self._feature_store.get_feature_view("fv1", "v1") + fv2 = self._feature_store.get_feature_view("fv2", "v1") + + self._test_access( + lambda: self._feature_store.merge_features([fv1, fv2], "merged_fv"), + required_access, + test_access, + access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, + ) + + @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] + def test_register_feature_view(self, required_access: Role, test_access: Role) -> None: + e = self._feature_store.get_entity("foo") + fv = FeatureView( + name=f"test_fv_{uuid4().hex.upper()}"[:32], + entities=[e], + 
feature_df=self._session.sql(f"SELECT id, name, ts FROM {self._mock_table}"), + timestamp_col="ts", + refresh_freq="DOWNSTREAM", + ) + + fv = self._test_access( + lambda: self._feature_store.register_feature_view(fv, "test"), + required_access, + test_access, + expected_result=lambda _fv: self.assertIn( + _fv.status, (FeatureViewStatus.RUNNING, FeatureViewStatus.ACTIVE) + ), + ) + + @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] + def test_suspend_feature_view(self, required_access: Role, test_access: Role) -> None: + e = self._feature_store.get_entity("foo") + fv = FeatureView( + name="test_fv", + entities=[e], + feature_df=self._session.sql(f"SELECT id, name, ts FROM {self._mock_table}"), + timestamp_col="ts", + refresh_freq="DOWNSTREAM", + ) + fv = self._feature_store.register_feature_view(fv, "test", override=True) + + try: + self._test_access( + lambda: self._feature_store.suspend_feature_view(fv), + required_access, + test_access, + lambda _fv: self.assertEqual(FeatureViewStatus.SUSPENDED, _fv.status), + ), + finally: + self._feature_store.delete_feature_view(fv) + + @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] + def test_resume_feature_view(self, required_access: Role, test_access: Role) -> None: + e = self._feature_store.get_entity("foo") + fv = FeatureView( + name="test_fv", + entities=[e], + feature_df=self._session.sql(f"SELECT id, name, ts FROM {self._mock_table}"), + timestamp_col="ts", + refresh_freq="DOWNSTREAM", + ) + fv = self._feature_store.register_feature_view(fv, "test", override=True) + fv = self._feature_store.suspend_feature_view(fv) + + try: + self._test_access( + lambda: self._feature_store.resume_feature_view(fv), + required_access, + test_access, + expected_result=lambda _fv: self.assertIn( + _fv.status, (FeatureViewStatus.RUNNING, FeatureViewStatus.ACTIVE) + ), + ), + finally: + self._feature_store.delete_feature_view(fv) + + @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] + def test_generate_dataset(self, required_access: Role, test_access: Role) -> None: + spine_df = self._session.sql(f"SELECT id FROM {self._mock_table}") + fv1 = self._feature_store.get_feature_view("fv1", "v1") + fv2 = self._feature_store.get_feature_view("fv2", "v1") + dataset_name = f"FS_TEST_DATASET_{uuid4().hex.upper()}" + + self._test_access( + lambda: self._feature_store.generate_dataset(spine_df, [fv1, fv2], materialized_table=dataset_name), + required_access, + test_access, + access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, + ) + + @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] + def test_delete_feature_view(self, required_access: Role, test_access: Role) -> None: + e = self._feature_store.get_entity("foo") + fv = FeatureView( + name="test_fv", + entities=[e], + feature_df=self._session.sql(f"SELECT id, name, ts FROM {self._mock_table}"), + timestamp_col="ts", + refresh_freq="DOWNSTREAM", + ) + + self._session.use_role(self._test_roles[Role.PRODUCER]) + fv = self._feature_store.register_feature_view(fv, "test", override=True) + + try: + self._test_access( + lambda: self._feature_store.delete_feature_view(fv), + required_access, + test_access, + expected_access_exception=snowpark_exceptions.SnowparkSQLException, + access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, + ) + finally: + 
self._feature_store.delete_feature_view(fv) + + @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] + def test_delete_entity(self, required_access: Role, test_access: Role) -> None: + e = Entity(f"test_entity_{uuid4().hex.upper()}"[:32], ["test_key"]) + + self._session.use_role(self._test_roles[Role.PRODUCER]) + self._feature_store.register_entity(e) + + self._test_access( + lambda: self._feature_store.delete_entity(e.name), + required_access, + test_access, + ) + + @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] + def test_list_entities(self, required_access: Role, test_access: Role) -> None: + self._test_access( + self._feature_store.list_entities, + required_access, + test_access, + expected_result=lambda rst: self.assertGreater(len(rst.collect()), 0), + access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, + ) + + @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] + def test_get_entity(self, required_access: Role, test_access: Role) -> None: + self._test_access( + lambda: self._feature_store.get_entity("foo"), + required_access, + test_access, + expected_result=lambda rst: self.assertIsInstance(rst, Entity), + ) + + @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] + def test_list_feature_views(self, required_access: Role, test_access: Role) -> None: + self._test_access( + lambda: self._feature_store.list_feature_views(as_dataframe=False), + required_access, + test_access, + expected_result=lambda rst: self.assertGreater(len(rst), 0), + access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, + ) + + @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] + def test_get_feature_view(self, required_access: Role, test_access: Role) -> None: + self._test_access( + lambda: self._feature_store.get_feature_view("fv1", "v1"), + required_access, + test_access, + expected_result=lambda rst: self.assertIsInstance(rst, FeatureView), + ) + + @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] + def test_retrieve_feature_values(self, required_access: Role, test_access: Role) -> None: + spine_df = self._session.sql(f"SELECT id FROM {self._mock_table}") + fv1 = self._feature_store.get_feature_view("fv1", "v1") + fv2 = self._feature_store.get_feature_view("fv2", "v1") + + self._test_access( + lambda: self._feature_store.retrieve_feature_values(spine_df, [fv1, fv2]), + required_access, + test_access, + access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_test.py index 0b816cff..4442068f 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_test.py @@ -447,7 +447,9 @@ def test_create_and_delete_feature_views(self) -> None: # register feature view fv0 = fs.register_feature_view(feature_view=fv0, version="FIRST") self.assertEqual(fv0.version, "FIRST") - self.assertEqual(fv0.status, FeatureViewStatus.RUNNING) + self.assertTrue( + fv0.status == FeatureViewStatus.ACTIVE or fv0.status == FeatureViewStatus.RUNNING + ) # fv0.status == FeatureViewStatus.RUNNING can be removed after BCR 2024_02 gets 
fully deployed self.assertEqual(fv0.refresh_freq, "1 minute") self.assertEqual(fv0, fs.get_feature_view("fv0", "FIRST")) @@ -497,7 +499,9 @@ def test_create_and_delete_feature_views(self) -> None: self.assertEqual(fv.name, "FV1") self.assertEqual(fv.version, "FIRST") self.assertEqual(fv.query, sql0) - self.assertEqual(fv.status, FeatureViewStatus.RUNNING) + self.assertTrue( + fv.status == FeatureViewStatus.ACTIVE or fv.status == FeatureViewStatus.RUNNING + ) # fv.status == FeatureViewStatus.RUNNING can be removed after BCR 2024_02 gets fully deployed self.assertEqual(fv.refresh_freq, "5 minutes") self.assertEqual(fv.warehouse, alternate_warehouse) self.assertEqual(fv.desc, "my_fv1") @@ -547,17 +551,12 @@ def test_resume_and_suspend_feature_view(self) -> None: refresh_freq="DOWNSTREAM", ) my_fv = fs.register_feature_view(feature_view=my_fv, version="v1", block=True) - - with self.assertRaisesRegex(ValueError, "FeatureView.*is not in suspended status.*"): - fs.resume_feature_view(my_fv) - my_fv = fs.suspend_feature_view(my_fv) - - with self.assertRaisesRegex(ValueError, "FeatureView.*is not in running status.*"): - fs.suspend_feature_view(my_fv) - + self.assertEqual(my_fv.status, FeatureViewStatus.SUSPENDED) my_fv = fs.resume_feature_view(my_fv) - self.assertEqual(my_fv.status, FeatureViewStatus.RUNNING) + self.assertTrue( + my_fv.status == FeatureViewStatus.ACTIVE or my_fv.status == FeatureViewStatus.RUNNING + ) # my_fv.status == FeatureViewStatus.RUNNING can be removed after BCR 2024_02 gets fully deployed def test_resume_and_suspend_feature_view_system_error(self) -> None: fs = self._create_feature_store() @@ -1681,6 +1680,104 @@ def test_generate_dataset_point_in_time_join(self) -> None: sort_cols=["CUSTOMER_ID"], ) + def test_cross_feature_store_interop(self) -> None: + # create first feature store and register feature views + first_fs = self._create_feature_store() + + first_entity = Entity("foo", ["id"]) + first_fs.register_entity(first_entity) + first_fv = FeatureView( + name="fv", + entities=[first_entity], + feature_df=self._session.table(self._mock_table).select(["NAME", "ID", "AGE", "TS"]), + timestamp_col="ts", + desc="foobar", + ) + first_fv = first_fs.register_feature_view(feature_view=first_fv, version="v1") + + # create second feature store and register feature views + second_fs = self._create_feature_store() + + second_entity = Entity("foo", ["id"]) + second_fs.register_entity(second_entity) + second_fv = FeatureView( + name="fv", + entities=[second_entity], + feature_df=self._session.table(self._mock_table).select(["ID", "DEPT", "TITLE", "TS"]), + timestamp_col="ts", + desc="foobar", + ) + second_fv = second_fs.register_feature_view(feature_view=second_fv, version="v1") + + # make sure these two feature views are in different feature store + self.assertNotEqual(first_fv.schema, second_fv.schema) + + # generate dataset by joining feature views from different feature store + spine_df = self._session.create_dataframe([(1, 101)], schema=["id", "ts"]) + for fs in [first_fs, second_fs]: + ds = fs.generate_dataset( + spine_df=spine_df, + features=[first_fv, second_fv], + spine_timestamp_col="ts", + ) + compare_dataframe( + actual_df=ds.df.to_pandas(), + target_data={ + "ID": [1], + "TS": [101], + "NAME": ["jonh"], + "AGE": [20], + "DEPT": ["sales"], + "TITLE": ["boss"], + }, + sort_cols=["ID"], + ) + + def test_generate_dataset_left_join(self) -> None: + # testing case for join features without timestamp, which is a left join with the spine + fs = self._create_feature_store() + + 
e1 = Entity("foo", ["id", "name"]) + fs.register_entity(e1) + + sql1 = f"SELECT id, name, title FROM {self._mock_table}" + fv1 = FeatureView( + name="fv1", + entities=[e1], + feature_df=self._session.sql(sql1), + refresh_freq="DOWNSTREAM", + ) + fv1 = fs.register_feature_view(feature_view=fv1, version="v1", block=True) + + e2 = Entity("bar", ["id"]) + fs.register_entity(e2) + + sql2 = f"SELECT id, age FROM {self._mock_table}" + fv2 = FeatureView( + name='"FvfV2"', + entities=[e2], + feature_df=self._session.sql(sql2), + refresh_freq="DOWNSTREAM", + ) + fv2 = fs.register_feature_view(feature_view=fv2, version="v1", block=True) + spine_df = self._session.create_dataframe([(1, "jonh"), (2, "porter"), (3, "johnny")], schema=["id", "name"]) + + ds = fs.generate_dataset( + spine_df=spine_df, + features=[fv1, fv2], + ) + + compare_dataframe( + actual_df=ds.df.to_pandas(), + target_data={ + "ID": [1, 2, 3], + "NAME": ["jonh", "porter", "johnny"], + "TITLE": ["boss", "manager", None], + "AGE": [20, 30, None], + }, + sort_cols=["ID"], + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/fileset/sfcfs_integ_test.py b/tests/integ/snowflake/ml/fileset/sfcfs_integ_test.py index 9f989a5f..c56d29cb 100644 --- a/tests/integ/snowflake/ml/fileset/sfcfs_integ_test.py +++ b/tests/integ/snowflake/ml/fileset/sfcfs_integ_test.py @@ -1,3 +1,5 @@ +import pickle + import fsspec from absl.testing import absltest @@ -298,6 +300,15 @@ def test_negative_optimize_read(self) -> None: with self.assertRaises(fileset_errors.StageNotFoundError): fs.optimize_read(["@ML_DATASETS.public.stage_does_not_exist/aaa"]) + def test_fs_serializability(self) -> None: + """Test if an object of Snowflake FS can be serialized using pickle.""" + + sfcfs_pickle = sfcfs.SFFileSystem(sf_connection=self.sf_connection) + + pickled_data = pickle.dumps(sfcfs_pickle) + sfcfs_deserialized = pickle.loads(pickled_data) + assert sfcfs_deserialized._conn is not None + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel index 2426fd20..9005d060 100644 --- a/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel @@ -27,3 +27,17 @@ py_test( "//tests/integ/snowflake/ml/test_utils:model_factory", ], ) + +py_test( + name = "input_validation_integ_test", + timeout = "long", + srcs = ["input_validation_integ_test.py"], + deps = [ + "//snowflake/ml/model:custom_model", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/registry", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", + "//tests/integ/snowflake/ml/test_utils:db_manager", + ], +) diff --git a/tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py new file mode 100644 index 00000000..75b0d226 --- /dev/null +++ b/tests/integ/snowflake/ml/model/_client/model/input_validation_integ_test.py @@ -0,0 +1,119 @@ +import uuid + +import pandas as pd +from absl.testing import absltest, parameterized + +from snowflake.ml.model import custom_model, model_signature +from snowflake.ml.registry import registry +from snowflake.ml.utils import connection_params +from snowflake.snowpark import Session +from tests.integ.snowflake.ml.test_utils import dataframe_utils, db_manager + +MODEL_NAME = "TEST_MODEL" +VERSION_NAME = "V1" + + +class 
DemoModel(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + def predict(self, input: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame({"output": input["c1"]}) + + +class TestInputValidationInteg(parameterized.TestCase): + @classmethod + def setUpClass(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + login_options = connection_params.SnowflakeLoginOptions() + + self._run_id = uuid.uuid4().hex + self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() + self._test_schema = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( + self._run_id, "schema" + ).upper() + + self._session = Session.builder.configs( + { + **login_options, + **{"database": self._test_db, "schema": self._test_schema}, + } + ).create() + + self._db_manager = db_manager.DBManager(self._session) + self._db_manager.create_database(self._test_db) + self._db_manager.create_schema(self._test_schema) + self._db_manager.cleanup_databases(expire_hours=6) + self.registry = registry.Registry(self._session) + + lm = DemoModel(custom_model.ModelContext()) + + self._mv = self.registry.log_model( + model=lm, + model_name=MODEL_NAME, + version_name=VERSION_NAME, + signatures={ + "predict": model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(name="c1", dtype=model_signature.DataType.INT8), + model_signature.FeatureSpec(name="c2", dtype=model_signature.DataType.INT8), + model_signature.FeatureSpec(name="c3", dtype=model_signature.DataType.INT8), + ], + outputs=[ + model_signature.FeatureSpec(name="output", dtype=model_signature.DataType.INT8), + ], + ) + }, + ) + + @classmethod + def tearDownClass(self) -> None: + self._db_manager.drop_database(self._test_db) + self._session.close() + + def test_default_non_strict(self) -> None: + pd.testing.assert_frame_equal( + self._mv.run(pd.DataFrame([[1, 2, 3], [4, 2, 5]])), + pd.DataFrame([1, 4], columns=["output"]), + check_dtype=False, + ) + + pd.testing.assert_frame_equal( + self._mv.run(pd.DataFrame([[1, 2, 3], [257, 2, 5]])), + pd.DataFrame([1, 1], columns=["output"]), + check_dtype=False, + ) + + sp_df = self._session.create_dataframe([[1, 2, 3], [4, 2, 5]], schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1, 2, 3, 1], [4, 2, 5, 4]], columns=["c1", "c2", "c3", "output"]) + dataframe_utils.check_sp_df_res(self._mv.run(sp_df), y_df_expected, check_dtype=False) + + sp_df = self._session.create_dataframe([[1, 2, 3], [257, 2, 5]], schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1, 2, 3, 1], [257, 2, 5, 1]], columns=["c1", "c2", "c3", "output"]) + dataframe_utils.check_sp_df_res(self._mv.run(sp_df), y_df_expected, check_dtype=False) + + def test_strict(self) -> None: + pd.testing.assert_frame_equal( + self._mv.run(pd.DataFrame([[1, 2, 3], [4, 2, 5]]), strict_input_validation=True), + pd.DataFrame([1, 4], columns=["output"]), + check_dtype=False, + ) + + with self.assertRaisesRegex(ValueError, "Data Validation Error"): + self._mv.run(pd.DataFrame([[1, 2, 4], [257, 2, 5]]), strict_input_validation=True) + + sp_df = self._session.create_dataframe([[1, 2, 3], [4, 2, 5]], schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1, 2, 3, 1], [4, 2, 5, 4]], columns=["c1", "c2", "c3", "output"]) + dataframe_utils.check_sp_df_res( + self._mv.run(sp_df, strict_input_validation=True), y_df_expected, check_dtype=False + ) + + sp_df = 
self._session.create_dataframe([[1, 2, 3], [257, 2, 5]], schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1, 2, 3, 1], [257, 2, 5, 257]], columns=["c1", "c2", "c3", "output"]) + with self.assertRaisesRegex(ValueError, "Data Validation Error"): + self._mv.run(sp_df, strict_input_validation=True) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/model/model_badcase_integ_test.py b/tests/integ/snowflake/ml/model/model_badcase_integ_test.py index 98459bfc..d9365edd 100644 --- a/tests/integ/snowflake/ml/model/model_badcase_integ_test.py +++ b/tests/integ/snowflake/ml/model/model_badcase_integ_test.py @@ -98,7 +98,9 @@ def test_custom_demo_model(self) -> None: stage_path=posixpath.join(tmp_stage, "custom_demo_model"), model=lm, conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server( + self._session, "snowflake-snowpark-python!=1.12.0" + ) ], sample_input=pd_df, metadata={"author": "halu", "version": "1"}, diff --git a/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py b/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py index 14ecd17b..7bcebfd1 100644 --- a/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py +++ b/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py @@ -26,7 +26,7 @@ def base_test_case( ) -> None: tmp_stage = db._session.get_session_stage() conda_dependencies = [ - test_env_utils.get_latest_package_version_spec_in_server(db._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server(db._session, "snowflake-snowpark-python!=1.12.0") ] if additional_dependencies: conda_dependencies.extend(additional_dependencies) diff --git a/tests/integ/snowflake/ml/modeling/framework/utils.py b/tests/integ/snowflake/ml/modeling/framework/utils.py index a7e3c2d5..4ed71b5f 100644 --- a/tests/integ/snowflake/ml/modeling/framework/utils.py +++ b/tests/integ/snowflake/ml/modeling/framework/utils.py @@ -154,7 +154,8 @@ def gen_fuzz_data( high: upper bound(s) of the output interval (exclusive) Returns: - A tuple of generated data and column names + A tuple of generated data and column names. + The 1st column of test data is "ID". 
Raises: ValueError: if data type is not supported diff --git a/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel b/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel index d5d65045..e1954e23 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel +++ b/tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel @@ -69,6 +69,7 @@ py_test( srcs = ["d2_absolute_error_score_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:regression", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -81,6 +82,7 @@ py_test( srcs = ["d2_pinball_score_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:regression", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -105,6 +107,7 @@ py_test( srcs = ["f1_score_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:classification", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -117,6 +120,7 @@ py_test( srcs = ["fbeta_score_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:classification", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -129,6 +133,7 @@ py_test( srcs = ["log_loss_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:classification", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -201,6 +206,7 @@ py_test( srcs = ["precision_recall_fscore_support_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:classification", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -213,6 +219,7 @@ py_test( srcs = ["precision_score_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:classification", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", @@ -225,6 +232,7 @@ py_test( srcs = ["recall_score_test.py"], shard_count = SHARD_COUNT, deps = [ + ":generator", "//snowflake/ml/modeling/metrics:classification", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/modeling/framework:utils", diff --git a/tests/integ/snowflake/ml/modeling/metrics/d2_absolute_error_score_test.py b/tests/integ/snowflake/ml/modeling/metrics/d2_absolute_error_score_test.py index fb9216c5..ffaf1f32 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/d2_absolute_error_score_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/d2_absolute_error_score_test.py @@ -1,7 +1,8 @@ -from typing import Any, Dict +from typing import Optional, Union from unittest import mock import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import metrics as sklearn_metrics @@ -10,35 +11,29 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, 
- types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] _Y_PRED_COLS = [_SF_SCHEMA[3], _SF_SCHEMA[4]] _SAMPLE_WEIGHT_COL = _SF_SCHEMA[5] -_MULTILABEL_DATA = [ - [1, 0, 1, 0.8, 0.3, 0.6], - [0, 1, 0, 0.2, 0.7, 0.4], - [1, 1, 0, 0.9, 0.6, 0.2], - [0, 0, 1, 0.1, 0.4, 0.8], -] -_MULTILABEL_SCHEMA = ["Y_0", "Y_1", "Y_2", "S_0", "S_1", "S_2"] -_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[0], _MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2]] -_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[3], _MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5]] + +_MULTILABEL_TYPES = [utils.DataType.INTEGER] * 3 + [utils.DataType.FLOAT] * 3 +_MULTILABEL_LOW, _MULTILABEL_HIGH = 0, [2, 2, 2, 1, 1, 1] +_MULTILABEL_DATA_LIST, _MULTILABEL_SCHEMA = generator.gen_test_cases( + _MULTILABEL_TYPES, _MULTILABEL_LOW, _MULTILABEL_HIGH +) +_REGULAR_MULTILABEL_DATA_LIST, _LARGE_MULTILABEL_DATA = _MULTILABEL_DATA_LIST[:-1], _MULTILABEL_DATA_LIST[-1] +_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2], _MULTILABEL_SCHEMA[3]] +_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5], _MULTILABEL_SCHEMA[6]] class D2AbsoluteErrorScoreTest(parameterized.TestCase): @@ -51,61 +46,73 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.d2_absolute_error_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.d2_absolute_error_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.d2_absolute_error_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + 
sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.d2_absolute_error_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + multioutput=["raw_values", "uniform_average", [0.2, 1.0, 1.66]], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_loss = snowml_metrics.d2_absolute_error_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_loss = sklearn_metrics.d2_absolute_error_score( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"multioutput": ["raw_values", "uniform_average", [0.2, 1.0, 1.66]]}}, + def test_multioutput(self, data_index: int, multioutput: Union[str, npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) + + actual_loss = snowml_metrics.d2_absolute_error_score( + df=input_df, + y_true_col_names=_MULTILABEL_Y_TRUE_COLS, + y_pred_col_names=_MULTILABEL_Y_PRED_COLS, + multioutput=multioutput, + ) + sklearn_loss = sklearn_metrics.d2_absolute_error_score( + pandas_df[_MULTILABEL_Y_TRUE_COLS], + pandas_df[_MULTILABEL_Y_PRED_COLS], + multioutput=multioutput, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), ) - def test_multioutput(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) - - for multioutput in params["multioutput"]: - actual_loss = snowml_metrics.d2_absolute_error_score( - df=input_df, - y_true_col_names=_MULTILABEL_Y_TRUE_COLS, - y_pred_col_names=_MULTILABEL_Y_PRED_COLS, - multioutput=multioutput, - ) - sklearn_loss = sklearn_metrics.d2_absolute_error_score( - pandas_df[_MULTILABEL_Y_TRUE_COLS], - pandas_df[_MULTILABEL_Y_PRED_COLS], - multioutput=multioutput, - ) - np.testing.assert_allclose(actual_loss, sklearn_loss) - - def test_multilabel(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) + def test_multilabel(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) actual_loss = snowml_metrics.d2_absolute_error_score( df=input_df, @@ -116,11 +123,14 @@ def test_multilabel(self) -> None: pandas_df[_MULTILABEL_Y_TRUE_COLS], pandas_df[_MULTILABEL_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + ) 
@mock.patch("snowflake.ml.modeling.metrics.regression.result._RESULT_SIZE_THRESHOLD", 0) - def test_metric_size_threshold(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_metric_size_threshold(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) actual_loss = snowml_metrics.d2_absolute_error_score( df=input_df, @@ -131,7 +141,7 @@ def test_metric_size_threshold(self) -> None: pandas_df[_Y_TRUE_COLS], pandas_df[_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/modeling/metrics/d2_pinball_score_test.py b/tests/integ/snowflake/ml/modeling/metrics/d2_pinball_score_test.py index 3ce2f309..23436f0f 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/d2_pinball_score_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/d2_pinball_score_test.py @@ -1,7 +1,8 @@ -from typing import Any, Dict +from typing import Optional, Union from unittest import mock import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import metrics as sklearn_metrics @@ -10,35 +11,29 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] _Y_PRED_COLS = [_SF_SCHEMA[3], _SF_SCHEMA[4]] _SAMPLE_WEIGHT_COL = _SF_SCHEMA[5] -_MULTILABEL_DATA = [ - [1, 0, 1, 0.8, 0.3, 0.6], - [0, 1, 0, 0.2, 0.7, 0.4], - [1, 1, 0, 0.9, 0.6, 0.2], - [0, 0, 1, 0.1, 0.4, 0.8], -] -_MULTILABEL_SCHEMA = ["Y_0", "Y_1", "Y_2", "S_0", "S_1", "S_2"] -_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[0], _MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2]] -_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[3], _MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5]] + +_MULTILABEL_TYPES = [utils.DataType.INTEGER] * 3 + [utils.DataType.FLOAT] * 3 +_MULTILABEL_LOW, _MULTILABEL_HIGH = 0, [2, 2, 2, 1, 1, 1] +_MULTILABEL_DATA_LIST, _MULTILABEL_SCHEMA = generator.gen_test_cases( + _MULTILABEL_TYPES, _MULTILABEL_LOW, _MULTILABEL_HIGH +) +_REGULAR_MULTILABEL_DATA_LIST, _LARGE_MULTILABEL_DATA = _MULTILABEL_DATA_LIST[:-1], _MULTILABEL_DATA_LIST[-1] +_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2], _MULTILABEL_SCHEMA[3]] +_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5], _MULTILABEL_SCHEMA[6]] class D2PinballScoreTest(parameterized.TestCase): @@ -51,109 +46,131 @@ def setUp(self) -> 
None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_loss = snowml_metrics.d2_pinball_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_loss = sklearn_metrics.d2_pinball_score( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "alpha": [0.1, 0.5, 0.99], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.d2_pinball_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.d2_pinball_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_alpha(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for alpha in params["alpha"]: - actual_loss = snowml_metrics.d2_pinball_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - alpha=alpha, - ) - sklearn_loss = sklearn_metrics.d2_pinball_score( - pandas_df[y_true], - pandas_df[y_pred], - alpha=alpha, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"multioutput": ["raw_values", "uniform_average", [0.2, 1.0, 1.66]]}}, + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.d2_pinball_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + sample_weight_col_name=sample_weight_col_name, + ) + 
sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.d2_pinball_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + alpha=[0.1, 0.5, 0.99], ) - def test_multioutput(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) - - for multioutput in params["multioutput"]: - actual_loss = snowml_metrics.d2_pinball_score( - df=input_df, - y_true_col_names=_MULTILABEL_Y_TRUE_COLS, - y_pred_col_names=_MULTILABEL_Y_PRED_COLS, - multioutput=multioutput, - ) - sklearn_loss = sklearn_metrics.d2_pinball_score( - pandas_df[_MULTILABEL_Y_TRUE_COLS], - pandas_df[_MULTILABEL_Y_PRED_COLS], - multioutput=multioutput, - ) - np.testing.assert_allclose(actual_loss, sklearn_loss) - - def test_multilabel(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) + def test_alpha_binary(self, data_index: int, alpha: float) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.d2_pinball_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + alpha=alpha, + ) + sklearn_loss = sklearn_metrics.d2_pinball_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + alpha=alpha, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + alpha=[0.1, 0.5, 0.99], + ) + def test_alpha_multiclass(self, data_index: int, alpha: float) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.d2_pinball_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + alpha=alpha, + ) + sklearn_loss = sklearn_metrics.d2_pinball_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + alpha=alpha, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss, rtol=1.0e-6, atol=1.0e-6) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + multioutput=["raw_values", "uniform_average", [0.2, 1.0, 1.66]], + ) + def test_multioutput(self, data_index: int, multioutput: Union[str, npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) actual_loss = snowml_metrics.d2_pinball_score( df=input_df, y_true_col_names=_MULTILABEL_Y_TRUE_COLS, y_pred_col_names=_MULTILABEL_Y_PRED_COLS, + multioutput=multioutput, ) sklearn_loss = sklearn_metrics.d2_pinball_score( pandas_df[_MULTILABEL_Y_TRUE_COLS], pandas_df[_MULTILABEL_Y_PRED_COLS], + multioutput=multioutput, ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + ) + def test_multilabel(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) + + actual_loss = snowml_metrics.d2_pinball_score( + df=input_df, + 
y_true_col_names=_MULTILABEL_Y_TRUE_COLS, + y_pred_col_names=_MULTILABEL_Y_PRED_COLS, + ) + sklearn_loss = sklearn_metrics.d2_pinball_score( + pandas_df[_MULTILABEL_Y_TRUE_COLS], + pandas_df[_MULTILABEL_Y_PRED_COLS], + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + ) @mock.patch("snowflake.ml.modeling.metrics.regression.result._RESULT_SIZE_THRESHOLD", 0) - def test_metric_size_threshold(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_metric_size_threshold(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) actual_loss = snowml_metrics.d2_pinball_score( df=input_df, @@ -164,7 +181,7 @@ def test_metric_size_threshold(self) -> None: pandas_df[_Y_TRUE_COLS], pandas_df[_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/modeling/metrics/f1_score_test.py b/tests/integ/snowflake/ml/modeling/metrics/f1_score_test.py index d2d6bc67..5b3ef1d2 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/f1_score_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/f1_score_test.py @@ -1,6 +1,7 @@ -from typing import Any, Dict +from typing import List, Optional, Union import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import exceptions, metrics as sklearn_metrics @@ -9,21 +10,15 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] @@ -41,146 +36,155 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - {"params": {"labels": [None, [2, 0, 4]]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + labels=[None, [2, 0, 4]], ) - def test_labels(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for labels in params["labels"]: - actual_f = snowml_metrics.f1_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - average=None, - labels=labels, - ) - sklearn_f = sklearn_metrics.f1_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - average=None, - labels=labels, - ) - 
np.testing.assert_allclose(actual_f, sklearn_f) + def test_labels(self, data_index: int, labels: Optional[npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=None, + labels=labels, + ) + sklearn_f = sklearn_metrics.f1_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=None, + labels=labels, + ) + np.testing.assert_allclose(actual_f, sklearn_f) - @parameterized.parameters( # type: ignore[misc] - {"params": {"pos_label": [0, 2, 4]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + pos_label=[0, 2, 4], ) - def test_pos_label(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for pos_label in params["pos_label"]: - actual_f = snowml_metrics.f1_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - average=None, - pos_label=pos_label, - ) - sklearn_f = sklearn_metrics.f1_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - average=None, - pos_label=pos_label, - ) - np.testing.assert_allclose(actual_f, sklearn_f) + def test_pos_label(self, data_index: int, pos_label: Union[str, int]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=None, + pos_label=pos_label, + ) + sklearn_f = sklearn_metrics.f1_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=None, + pos_label=pos_label, + ) + np.testing.assert_allclose(actual_f, sklearn_f) - @parameterized.parameters( # type: ignore[misc] - {"params": {"average": [None, "micro", "macro", "weighted"]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + average=[None, "micro", "macro", "weighted"], ) - def test_average_multiclass(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for average in params["average"]: - actual_f = snowml_metrics.f1_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - average=average, - ) - sklearn_f = sklearn_metrics.f1_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - average=average, - ) - np.testing.assert_allclose(actual_f, sklearn_f) - - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "average": ["binary", "samples"], - "y_true": [_Y_TRUE_COL, _Y_TRUE_COLS], - "y_pred": [_Y_PRED_COL, _Y_PRED_COLS], - } - }, + def test_average_multiclass(self, data_index: int, average: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=average, + ) + sklearn_f = sklearn_metrics.f1_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=average, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + @parameterized.product( + ( + dict(y_true=_Y_TRUE_COL, y_pred=_Y_PRED_COL, average="binary"), + dict(y_true=_Y_TRUE_COLS, y_pred=_Y_PRED_COLS, average="samples"), + ), + 
data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), ) - def test_average_binary(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_average_binary_samples( + self, + y_true: Union[str, List[str]], + y_pred: Union[str, List[str]], + average: Optional[str], + data_index: int, + ) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=y_true, + y_pred_col_names=y_pred, + average=average, + ) + sklearn_f = sklearn_metrics.f1_score( + pandas_df[y_true], + pandas_df[y_pred], + average=average, + ) + np.testing.assert_allclose(actual_f, sklearn_f) - for idx, average in enumerate(params["average"]): - y_true = params["y_true"][idx] - y_pred = params["y_pred"][idx] - actual_f = snowml_metrics.f1_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - average=average, - ) - sklearn_f = sklearn_metrics.f1_score( - pandas_df[y_true], - pandas_df[y_pred], - average=average, - ) - np.testing.assert_allclose(actual_f, sklearn_f) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + average=None, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_f = sklearn_metrics.f1_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + average=None, + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_f, sklearn_f) - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_f = snowml_metrics.f1_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - average=None, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_f = sklearn_metrics.f1_score( - pandas_df[y_true], - pandas_df[y_pred], - average=None, - sample_weight=sample_weight, - ) - np.testing.assert_allclose(actual_f, sklearn_f) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"zero_division": [0, 1]}}, + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], 
_SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=None, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_f = sklearn_metrics.f1_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=None, + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + @parameterized.product( # type: ignore[misc] + zero_division=[0, 1], ) - def test_zero_division(self, params: Dict[str, Any]) -> None: + def test_zero_division(self, zero_division: Union[str, int]) -> None: data = [ [0, 0, 0, 0, 0, 0], ] pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - for zero_division in params["zero_division"]: - if zero_division == "warn": - continue - + if zero_division != "warn": actual_f = snowml_metrics.f1_score( df=input_df, y_true_col_names=_Y_TRUE_COL, @@ -210,6 +214,38 @@ def test_zero_division(self, params: Dict[str, Any]) -> None: ) np.testing.assert_allclose(actual_f, sklearn_f) + def test_with_large_num_of_rows_binary(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_BINARY_DATA, _SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + average=None, + ) + sklearn_f = sklearn_metrics.f1_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + average=None, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + def test_with_large_num_of_rows_multiclass(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_MULTICLASS_DATA, _SF_SCHEMA) + + actual_f = snowml_metrics.f1_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=None, + ) + sklearn_f = sklearn_metrics.f1_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=None, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + if __name__ == "__main__": main() diff --git a/tests/integ/snowflake/ml/modeling/metrics/fbeta_score_test.py b/tests/integ/snowflake/ml/modeling/metrics/fbeta_score_test.py index 0e59c6d0..a959b607 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/fbeta_score_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/fbeta_score_test.py @@ -1,6 +1,7 @@ -from typing import Any, Dict +from typing import List, Optional, Union import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import exceptions, metrics as sklearn_metrics @@ -9,21 +10,15 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] 
+_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] @@ -41,190 +36,216 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "beta": [1.0, 0.5], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + beta=[1.0, 0.5], ) - def test_beta(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for beta in params["beta"]: - actual_f = snowml_metrics.fbeta_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - beta=beta, - average=None, - ) - sklearn_f = sklearn_metrics.fbeta_score( - pandas_df[y_true], - pandas_df[y_pred], - beta=beta, - average=None, - ) - np.testing.assert_allclose(actual_f, sklearn_f) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"labels": [None, [2, 0, 4]]}}, + def test_beta_binary(self, data_index: int, beta: float) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + beta=beta, + average=None, + ) + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + beta=beta, + average=None, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + beta=[1.0, 0.5], ) - def test_labels(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) + def test_beta_multiclass(self, data_index: int, beta: float) -> None: + pandas_df, input_df =
utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) - for pos_label in params["pos_label"]: - actual_f = snowml_metrics.fbeta_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - beta=0.5, - average=None, - pos_label=pos_label, - ) - sklearn_f = sklearn_metrics.fbeta_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - beta=0.5, - average=None, - pos_label=pos_label, - ) - np.testing.assert_allclose(actual_f, sklearn_f) + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + beta=0.5, + average=None, + labels=labels, + ) + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + beta=0.5, + average=None, + labels=labels, + ) + np.testing.assert_allclose(actual_f, sklearn_f) - @parameterized.parameters( # type: ignore[misc] - {"params": {"average": [None, "micro", "macro", "weighted"]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + pos_label=[0, 2, 4], ) - def test_average_multiclass(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) + def test_pos_label(self, data_index: int, pos_label: Union[str, int]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) - for average in params["average"]: - actual_f = snowml_metrics.fbeta_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - beta=0.5, - average=average, - ) - sklearn_f = sklearn_metrics.fbeta_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - beta=0.5, - average=average, - ) - np.testing.assert_allclose(actual_f, sklearn_f) + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + beta=0.5, + average=None, + pos_label=pos_label, + ) + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + beta=0.5, + average=None, + pos_label=pos_label, + ) + np.testing.assert_allclose(actual_f, sklearn_f) - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "average": ["binary", "samples"], - "y_true": [_Y_TRUE_COL, _Y_TRUE_COLS], - "y_pred": [_Y_PRED_COL, _Y_PRED_COLS], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + average=[None, "micro", "macro", "weighted"], ) - def test_average_binary(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_average_multiclass(self, data_index: int, average: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) - for idx, average in enumerate(params["average"]): - y_true = params["y_true"][idx] - y_pred = params["y_pred"][idx] - actual_f = snowml_metrics.fbeta_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - beta=0.5, - average=average, - ) - sklearn_f = sklearn_metrics.fbeta_score( - pandas_df[y_true], - pandas_df[y_pred], - beta=0.5, - average=average, - ) - np.testing.assert_allclose(actual_f, sklearn_f) + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + beta=0.5, + average=average, + ) + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COL], + 
pandas_df[_Y_PRED_COL], + beta=0.5, + average=average, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + @parameterized.product( + ( + dict(y_true=_Y_TRUE_COL, y_pred=_Y_PRED_COL, average="binary"), + dict(y_true=_Y_TRUE_COLS, y_pred=_Y_PRED_COLS, average="samples"), + ), + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=(None, _SAMPLE_WEIGHT_COL), + ) + def test_average_binary_samples( + self, + y_true: Union[str, List[str]], + y_pred: Union[str, List[str]], + average: Optional[str], + data_index: int, + sample_weight_col_name: Optional[str], + ) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=y_true, + y_pred_col_names=y_pred, + beta=0.5, + average=average, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[y_true], + pandas_df[y_pred], + beta=0.5, + average=average, + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + beta=0.5, + average=None, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + beta=0.5, + average=None, + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_f, sklearn_f) - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_f = snowml_metrics.fbeta_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - beta=0.5, - average=None, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_f = sklearn_metrics.fbeta_score( - pandas_df[y_true], - pandas_df[y_pred], - beta=0.5, - average=None, - sample_weight=sample_weight, - ) - np.testing.assert_allclose(actual_f, sklearn_f) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"zero_division": [0, 1]}}, + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> 
None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + beta=0.5, + average=None, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + beta=0.5, + average=None, + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + @parameterized.product( # type: ignore[misc] + zero_division=[0, 1], ) - def test_zero_division(self, params: Dict[str, Any]) -> None: + def test_zero_division(self, zero_division: Union[str, int]) -> None: data = [ [0, 0, 0, 0, 0, 0], ] pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - for zero_division in params["zero_division"]: - if zero_division == "warn": - continue - + if zero_division != "warn": actual_f = snowml_metrics.fbeta_score( df=input_df, y_true_col_names=_Y_TRUE_COL, @@ -258,6 +279,42 @@ def test_zero_division(self, params: Dict[str, Any]) -> None: ) np.testing.assert_allclose(actual_f, sklearn_f) + def test_with_large_num_of_rows_binary(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_BINARY_DATA, _SF_SCHEMA) + + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + beta=0.5, + average=None, + ) + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + beta=0.5, + average=None, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + + def test_with_large_num_of_rows_multiclass(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_MULTICLASS_DATA, _SF_SCHEMA) + + actual_f = snowml_metrics.fbeta_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + beta=0.5, + average=None, + ) + sklearn_f = sklearn_metrics.fbeta_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + beta=0.5, + average=None, + ) + np.testing.assert_allclose(actual_f, sklearn_f) + if __name__ == "__main__": main() diff --git a/tests/integ/snowflake/ml/modeling/metrics/generator.py b/tests/integ/snowflake/ml/modeling/metrics/generator.py index 012296d2..7c3fc275 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/generator.py +++ b/tests/integ/snowflake/ml/modeling/metrics/generator.py @@ -25,7 +25,8 @@ def gen_test_cases( high: upper bound(s) of the output interval (exclusive) Returns: - A tuple of test data of multiple sizes and column names + A tuple of test data of multiple sizes and column names. + The 1st column of test data is "ID". 
""" data_list = [] snowflake_identifiers: List[str] = [] diff --git a/tests/integ/snowflake/ml/modeling/metrics/log_loss_test.py b/tests/integ/snowflake/ml/modeling/metrics/log_loss_test.py index d4f96e77..e61b6bea 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/log_loss_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/log_loss_test.py @@ -1,5 +1,7 @@ -from typing import Any, Dict +from typing import Optional, Union +import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import metrics as sklearn_metrics @@ -8,37 +10,29 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] + [utils.DataType.FLOAT] * 4 -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=[2, 1, 1, 1, 1], -) +_BINARY_LOW, _BINARY_HIGH = 0, [2, 1, 1, 1, 1] +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, [3, 1, 1, 1, 1] +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _BINARY_Y_TRUE_COL = _SF_SCHEMA[1] _BINARY_Y_PRED_COL = _SF_SCHEMA[2] -_MULTICLASS_DATA = [ - [0, 2, 0.29, 0.49, 0.22, 0.18], - [1, 0, 0.33, 0.16, 0.51, 0.69], - [2, 1, 0.54, 0.29, 0.17, 0.04], - [3, 2, 0.27, 0.68, 0.05, 0.17], - [4, 1, 0.82, 0.12, 0.06, 0.91], - [5, 2, 0.08, 0.46, 0.46, 0.76], -] _MULTICLASS_Y_TRUE_COL = _SF_SCHEMA[1] _MULTICLASS_Y_PRED_COLS = [_SF_SCHEMA[2], _SF_SCHEMA[3], _SF_SCHEMA[4]] _SAMPLE_WEIGHT_COL = _SF_SCHEMA[5] -_MULTILABEL_DATA = [ - [1, 0, 1, 0.8, 0.3, 0.6], - [0, 1, 0, 0.2, 0.7, 0.4], - [1, 1, 0, 0.9, 0.6, 0.2], - [0, 0, 1, 0.1, 0.4, 0.8], -] -_MULTILABEL_SCHEMA = ["Y_0", "Y_1", "Y_2", "S_0", "S_1", "S_2"] -_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[0], _MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2]] -_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[3], _MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5]] + +_MULTILABEL_TYPES = [utils.DataType.INTEGER] * 3 + [utils.DataType.FLOAT] * 3 +_MULTILABEL_LOW, _MULTILABEL_HIGH = 0, [2, 2, 2, 1, 1, 1] +_MULTILABEL_DATA_LIST, _MULTILABEL_SCHEMA = generator.gen_test_cases( + _MULTILABEL_TYPES, _MULTILABEL_LOW, _MULTILABEL_HIGH +) +_REGULAR_MULTILABEL_DATA_LIST, _LARGE_MULTILABEL_DATA = _MULTILABEL_DATA_LIST[:-1], _MULTILABEL_DATA_LIST[-1] +_MULTILABEL_Y_TRUE_COLS = [_MULTILABEL_SCHEMA[1], _MULTILABEL_SCHEMA[2], _MULTILABEL_SCHEMA[3]] +_MULTILABEL_Y_PRED_COLS = [_MULTILABEL_SCHEMA[4], _MULTILABEL_SCHEMA[5], _MULTILABEL_SCHEMA[6]] class LogLossTest(parameterized.TestCase): @@ -51,125 +45,195 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "eps": ["auto", 0.01, 0.1, 0.5, 0.9, 0.99], - "values": [ - {"data": _BINARY_DATA, "y_true": _BINARY_Y_TRUE_COL, "y_pred": _BINARY_Y_PRED_COL}, - {"data": _MULTICLASS_DATA, "y_true": _MULTICLASS_Y_TRUE_COL, "y_pred": _MULTICLASS_Y_PRED_COLS}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + eps=["auto", 
0.01, 0.1, 0.5, 0.9, 0.99], ) - def test_eps(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for eps in params["eps"]: - actual_loss = snowml_metrics.log_loss( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - eps=eps, - ) - sklearn_loss = sklearn_metrics.log_loss( - pandas_df[y_true], - pandas_df[y_pred], - eps=eps, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "normalize": [True, False], - "values": [ - {"data": _BINARY_DATA, "y_true": _BINARY_Y_TRUE_COL, "y_pred": _BINARY_Y_PRED_COL}, - {"data": _MULTICLASS_DATA, "y_true": _MULTICLASS_Y_TRUE_COL, "y_pred": _MULTICLASS_Y_PRED_COLS}, - ], - } - }, + def test_eps_binary(self, data_index: int, eps: Union[float, str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_BINARY_Y_TRUE_COL, + y_pred_col_names=_BINARY_Y_PRED_COL, + eps=eps, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_BINARY_Y_TRUE_COL], + pandas_df[_BINARY_Y_PRED_COL], + eps=eps, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + eps=["auto", 0.01, 0.1, 0.5, 0.9, 0.99], ) - def test_normalize(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for normalize in params["normalize"]: - actual_loss = snowml_metrics.log_loss( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - normalize=normalize, - ) - sklearn_loss = sklearn_metrics.log_loss( - pandas_df[y_true], - pandas_df[y_pred], - normalize=normalize, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _BINARY_Y_TRUE_COL, "y_pred": _BINARY_Y_PRED_COL}, - {"data": _MULTICLASS_DATA, "y_true": _MULTICLASS_Y_TRUE_COL, "y_pred": _MULTICLASS_Y_PRED_COLS}, - ], - } - }, + def test_eps_multiclass(self, data_index: int, eps: Union[float, str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_MULTICLASS_Y_TRUE_COL, + y_pred_col_names=_MULTICLASS_Y_PRED_COLS, + eps=eps, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_MULTICLASS_Y_TRUE_COL], + pandas_df[_MULTICLASS_Y_PRED_COLS], + eps=eps, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + normalize=[True, False], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_loss = snowml_metrics.log_loss( - df=input_df, - y_true_col_names=y_true, 
- y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_loss = sklearn_metrics.log_loss( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"labels": [None, [2, 0, 4]]}}, + def test_normalize_binary(self, data_index: int, normalize: bool) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_BINARY_Y_TRUE_COL, + y_pred_col_names=_BINARY_Y_PRED_COL, + normalize=normalize, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_BINARY_Y_TRUE_COL], + pandas_df[_BINARY_Y_PRED_COL], + normalize=normalize, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + normalize=[True, False], + ) + def test_normalize_multiclass(self, data_index: int, normalize: bool) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_MULTICLASS_Y_TRUE_COL, + y_pred_col_names=_MULTICLASS_Y_PRED_COLS, + normalize=normalize, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_MULTICLASS_Y_TRUE_COL], + pandas_df[_MULTICLASS_Y_PRED_COLS], + normalize=normalize, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_BINARY_Y_TRUE_COL, + y_pred_col_names=_BINARY_Y_PRED_COL, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_BINARY_Y_TRUE_COL], + pandas_df[_BINARY_Y_PRED_COL], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_labels(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for labels in params["labels"]: - actual_loss = snowml_metrics.log_loss( - df=input_df, - y_true_col_names=_MULTICLASS_Y_TRUE_COL, - y_pred_col_names=_MULTICLASS_Y_PRED_COLS, - labels=labels, - ) - sklearn_loss = sklearn_metrics.log_loss( - pandas_df[_MULTICLASS_Y_TRUE_COL], - pandas_df[_MULTICLASS_Y_PRED_COLS], - labels=labels, - ) - self.assertAlmostEqual(sklearn_loss, actual_loss) - - def test_multilabel(self) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTILABEL_DATA, _MULTILABEL_SCHEMA) + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, 
_REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_MULTICLASS_Y_TRUE_COL, + y_pred_col_names=_MULTICLASS_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_MULTICLASS_Y_TRUE_COL], + pandas_df[_MULTICLASS_Y_PRED_COLS], + sample_weight=sample_weight, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + labels=[None, [2, 0, 4]], + ) + def test_labels(self, data_index: int, labels: Optional[npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_MULTICLASS_Y_TRUE_COL, + y_pred_col_names=_MULTICLASS_Y_PRED_COLS, + labels=labels, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_MULTICLASS_Y_TRUE_COL], + pandas_df[_MULTICLASS_Y_PRED_COLS], + labels=labels, + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTILABEL_DATA_LIST))), + ) + def test_multilabel(self, data_index: int) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTILABEL_DATA_LIST[data_index], _MULTILABEL_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_MULTILABEL_Y_TRUE_COLS, + y_pred_col_names=_MULTILABEL_Y_PRED_COLS, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_MULTILABEL_Y_TRUE_COLS], + pandas_df[_MULTILABEL_Y_PRED_COLS], + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + def test_with_large_num_of_rows_binary(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_BINARY_DATA, _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_BINARY_Y_TRUE_COL, + y_pred_col_names=_BINARY_Y_PRED_COL, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_BINARY_Y_TRUE_COL], + pandas_df[_BINARY_Y_PRED_COL], + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + def test_with_large_num_of_rows_multiclass(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_MULTICLASS_DATA, _SF_SCHEMA) + + actual_loss = snowml_metrics.log_loss( + df=input_df, + y_true_col_names=_MULTICLASS_Y_TRUE_COL, + y_pred_col_names=_MULTICLASS_Y_PRED_COLS, + ) + sklearn_loss = sklearn_metrics.log_loss( + pandas_df[_MULTICLASS_Y_TRUE_COL], + pandas_df[_MULTICLASS_Y_PRED_COLS], + ) + np.testing.assert_allclose(actual_loss, sklearn_loss) + + def test_with_large_num_of_rows_multilabel(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_MULTILABEL_DATA, _MULTILABEL_SCHEMA) actual_loss = snowml_metrics.log_loss( df=input_df, @@ -180,7 +244,7 @@ def test_multilabel(self) -> None: pandas_df[_MULTILABEL_Y_TRUE_COLS], pandas_df[_MULTILABEL_Y_PRED_COLS], ) - self.assertAlmostEqual(sklearn_loss, actual_loss) + np.testing.assert_allclose(actual_loss, sklearn_loss) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/modeling/metrics/precision_recall_fscore_support_test.py b/tests/integ/snowflake/ml/modeling/metrics/precision_recall_fscore_support_test.py index d0d26c1e..47341474 100644 --- 
a/tests/integ/snowflake/ml/modeling/metrics/precision_recall_fscore_support_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/precision_recall_fscore_support_test.py @@ -1,6 +1,7 @@ -from typing import Any, Dict +from typing import List, Optional, Union import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import exceptions, metrics as sklearn_metrics @@ -9,21 +10,15 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] @@ -41,155 +36,209 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "beta": [1.0, 0.5], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + beta=[1.0, 0.5], ) - def test_beta(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for beta in params["beta"]: - actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - beta=beta, - ) - sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( - pandas_df[y_true], - pandas_df[y_pred], - beta=beta, - ) - np.testing.assert_allclose( - np.array((actual_p, actual_r, actual_f, actual_s)), - np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), - ) + def test_beta_binary(self, data_index: int, beta: float) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) - @parameterized.parameters( # type: ignore[misc] - {"params": {"labels": [None, [2, 0, 4]]}}, + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + beta=beta, + ) + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + beta=beta, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) + 
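+ # Multiclass counterpart of test_beta_binary: same beta sweep on single-label columns, checked against sklearn.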
+ @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + beta=[1.0, 0.5], ) - def test_labels(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) + def test_beta_multiclass(self, data_index: int, beta: float) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) - for labels in params["labels"]: - actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - labels=labels, - ) - sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - labels=labels, - ) - np.testing.assert_allclose( - np.array((actual_p, actual_r, actual_f, actual_s)), - np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), - ) + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + beta=beta, + ) + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + beta=beta, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) - @parameterized.parameters( # type: ignore[misc] - {"params": {"pos_label": [0, 2, 4]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + labels=[None, [2, 0, 4]], ) - def test_pos_label(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) + def test_labels(self, data_index: int, labels: Optional[npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) - for pos_label in params["pos_label"]: - actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - pos_label=pos_label, - ) - sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - pos_label=pos_label, - ) - np.testing.assert_allclose( - np.array((actual_p, actual_r, actual_f, actual_s)), - np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), - ) + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + labels=labels, + ) + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + labels=labels, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + pos_label=[0, 2, 4], ) - def 
test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - ) - np.testing.assert_allclose( - np.array((actual_p, actual_r, actual_f, actual_s)), - np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), - ) + def test_pos_label(self, data_index: int, pos_label: Union[str, int]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) - @parameterized.parameters( # type: ignore[misc] - {"params": {"average": [None, "micro", "macro", "weighted"]}}, + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + pos_label=pos_label, + ) + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + pos_label=pos_label, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_average_multiclass(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) - for average in params["average"]: - actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - average=average, - ) - sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - average=average, - ) - np.testing.assert_allclose( - np.array((actual_p, actual_r, actual_f, actual_s), dtype=np.float_), - np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s), dtype=np.float_), - ) + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) + + @parameterized.product( # type: 
ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + sample_weight_col_name=sample_weight_col_name, + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + sample_weight=sample_weight, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + average=[None, "micro", "macro", "weighted"], + ) + def test_average_binary(self, data_index: int, average: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + average=average, + ) + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + average=average, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s), dtype=np.float_), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s), dtype=np.float_), + ) + + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + average=[None, "micro", "macro", "weighted"], + ) + def test_average_multiclass(self, data_index: int, average: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=average, + ) + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=average, + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s), dtype=np.float_), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s), dtype=np.float_), + ) @parameterized.product( ( dict(y_true=_Y_TRUE_COL, y_pred=_Y_PRED_COL, average="binary"), dict(y_true=_Y_TRUE_COLS, y_pred=_Y_PRED_COLS, average="samples"), ), + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), sample_weight_col_name=(None, _SAMPLE_WEIGHT_COL), ) - def test_average_binary_samples(self, y_true, y_pred, average, sample_weight_col_name) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_average_binary_samples( + self, + y_true: Union[str, List[str]], + y_pred: Union[str, List[str]], + average: Optional[str], + data_index: int, + sample_weight_col_name: Optional[str], + ) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) 
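+ # average="binary" pairs with the single label columns and "samples" with the multilabel columns, per the decorator above.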
actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( df=input_df, @@ -200,27 +249,45 @@ def test_average_binary_samples(self, y_true, y_pred, average, sample_weight_col ) sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( - pandas_df[y_true], pandas_df[y_pred], average=average, sample_weight=sample_weight + pandas_df[y_true], + pandas_df[y_pred], + average=average, + sample_weight=sample_weight, ) np.testing.assert_allclose( np.array((actual_p, actual_r, actual_f, actual_s), dtype=np.float_), np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s), dtype=np.float_), ) - @parameterized.parameters( # type: ignore[misc] - {"params": {"zero_division": ["warn", 0, 1]}}, + @parameterized.product( + zero_division=["warn", 0, 1], ) - def test_zero_division(self, params: Dict[str, Any]) -> None: + def test_zero_division(self, zero_division: Union[str, int]) -> None: data = [ [0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], ] pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - for zero_division in params["zero_division"]: - if zero_division == "warn": - continue + if zero_division == "warn": + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + zero_division="warn", + ) + with self.assertWarns(exceptions.UndefinedMetricWarning): + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + zero_division="warn", + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) + else: actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( df=input_df, y_true_col_names=_Y_TRUE_COL, @@ -237,28 +304,43 @@ def test_zero_division(self, params: Dict[str, Any]) -> None: np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), ) - # warn + def test_no_sample(self) -> None: + data = [] + pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) + + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + ) sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( pandas_df[_Y_TRUE_COL], pandas_df[_Y_PRED_COL], - zero_division="warn", + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), ) - with self.assertWarns(exceptions.UndefinedMetricWarning): - actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - zero_division="warn", - ) - np.testing.assert_allclose( - np.array((actual_p, actual_r, actual_f, actual_s)), - np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), - ) + def test_with_large_num_of_rows_binary(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_BINARY_DATA, _SF_SCHEMA) - def test_no_sample(self) -> None: - data = [] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) + actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( + df=input_df, + 
y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + ) + sklearn_p, sklearn_r, sklearn_f, sklearn_s = sklearn_metrics.precision_recall_fscore_support( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + ) + np.testing.assert_allclose( + np.array((actual_p, actual_r, actual_f, actual_s)), + np.array((sklearn_p, sklearn_r, sklearn_f, sklearn_s)), + ) + + def test_with_large_num_of_rows_multiclass(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_MULTICLASS_DATA, _SF_SCHEMA) actual_p, actual_r, actual_f, actual_s = snowml_metrics.precision_recall_fscore_support( df=input_df, diff --git a/tests/integ/snowflake/ml/modeling/metrics/precision_score_test.py b/tests/integ/snowflake/ml/modeling/metrics/precision_score_test.py index a661c5dc..bba57655 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/precision_score_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/precision_score_test.py @@ -1,6 +1,7 @@ -from typing import Any, Dict +from typing import List, Optional, Union import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import exceptions, metrics as sklearn_metrics @@ -9,21 +10,15 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] @@ -41,147 +36,156 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - {"params": {"labels": [None, [2, 0, 4]]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + labels=[None, [2, 0, 4]], ) - def test_labels(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for labels in params["labels"]: - actual_p = snowml_metrics.precision_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - labels=labels, - average=None, - ) - sklearn_p = sklearn_metrics.precision_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - labels=labels, - average=None, - ) - np.testing.assert_allclose(actual_p, sklearn_p) + def test_labels(self, data_index: int, labels: Optional[npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + labels=labels, + average=None, + ) + sklearn_p =
sklearn_metrics.precision_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + labels=labels, + average=None, + ) + np.testing.assert_allclose(actual_p, sklearn_p) - @parameterized.parameters( # type: ignore[misc] - {"params": {"pos_label": [0, 2, 4]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + pos_label=[0, 2, 4], ) - def test_pos_label(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for pos_label in params["pos_label"]: - actual_p = snowml_metrics.precision_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - pos_label=pos_label, - average="micro", - ) - sklearn_p = sklearn_metrics.precision_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - pos_label=pos_label, - average="micro", - ) - np.testing.assert_allclose(actual_p, sklearn_p) + def test_pos_label(self, data_index: int, pos_label: Union[str, int]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + pos_label=pos_label, + average="micro", + ) + sklearn_p = sklearn_metrics.precision_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + pos_label=pos_label, + average="micro", + ) + np.testing.assert_allclose(actual_p, sklearn_p) - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, - ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for sample_weight_col_name in params["sample_weight_col_name"]: - actual_p = snowml_metrics.precision_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - average="micro", - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_p = sklearn_metrics.precision_score( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - average="micro", - ) - np.testing.assert_allclose(actual_p, sklearn_p) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"average": [None, "micro", "macro", "weighted"]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_average_multiclass(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for average in params["average"]: - actual_p = snowml_metrics.precision_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - average=average, - ) - sklearn_p = sklearn_metrics.precision_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - average=average, - ) - np.testing.assert_allclose(actual_p, sklearn_p) + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, 
_REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + average="micro", + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_p = sklearn_metrics.precision_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + average="micro", + ) + np.testing.assert_allclose(actual_p, sklearn_p) - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "average": ["binary", "samples"], - "y_true": [_Y_TRUE_COL, _Y_TRUE_COLS], - "y_pred": [_Y_PRED_COL, _Y_PRED_COLS], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_average_binary(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + sample_weight_col_name=sample_weight_col_name, + average="micro", + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_p = sklearn_metrics.precision_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + sample_weight=sample_weight, + average="micro", + ) + np.testing.assert_allclose(actual_p, sklearn_p) - for idx, average in enumerate(params["average"]): - y_true = params["y_true"][idx] - y_pred = params["y_pred"][idx] - actual_p = snowml_metrics.precision_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - average=average, - ) - sklearn_p = sklearn_metrics.precision_score( - pandas_df[y_true], - pandas_df[y_pred], - average=average, - ) - np.testing.assert_allclose(actual_p, sklearn_p) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + average=[None, "micro", "macro", "weighted"], + ) + def test_average_multiclass(self, data_index: int, average: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=average, + ) + sklearn_p = sklearn_metrics.precision_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=average, + ) + np.testing.assert_allclose(actual_p, sklearn_p) + + @parameterized.product( + ( + dict(y_true=_Y_TRUE_COL, y_pred=_Y_PRED_COL, average="binary"), + dict(y_true=_Y_TRUE_COLS, y_pred=_Y_PRED_COLS, average="samples"), + ), + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + ) + def test_average_binary_samples( + self, + y_true: Union[str, List[str]], + y_pred: Union[str, List[str]], + average: Optional[str], + data_index: int, + ) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=y_true, + y_pred_col_names=y_pred, + average=average, + ) + sklearn_p = sklearn_metrics.precision_score( + 
pandas_df[y_true], + pandas_df[y_pred], + average=average, + ) + np.testing.assert_allclose(actual_p, sklearn_p) - @parameterized.parameters( # type: ignore[misc] - {"params": {"zero_division": ["warn", 0, 1]}}, + @parameterized.product( # type: ignore[misc] + zero_division=[0, 1], ) - def test_zero_division(self, params: Dict[str, Any]) -> None: + def test_zero_division(self, zero_division: Union[str, int]) -> None: data = [ [0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], ] pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - for zero_division in params["zero_division"]: - if zero_division == "warn": - continue - + if zero_division != "warn": actual_p = snowml_metrics.precision_score( df=input_df, y_true_col_names=_Y_TRUE_COL, @@ -211,6 +215,38 @@ def test_zero_division(self, params: Dict[str, Any]) -> None: ) np.testing.assert_allclose(actual_p, sklearn_p) + def test_with_large_num_of_rows_binary(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_BINARY_DATA, _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + average=None, + ) + sklearn_p = sklearn_metrics.precision_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + average=None, + ) + np.testing.assert_allclose(actual_p, sklearn_p) + + def test_with_large_num_of_rows_multiclass(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_MULTICLASS_DATA, _SF_SCHEMA) + + actual_p = snowml_metrics.precision_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=None, + ) + sklearn_p = sklearn_metrics.precision_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=None, + ) + np.testing.assert_allclose(actual_p, sklearn_p) + if __name__ == "__main__": main() diff --git a/tests/integ/snowflake/ml/modeling/metrics/recall_score_test.py b/tests/integ/snowflake/ml/modeling/metrics/recall_score_test.py index 323546f8..950ddd0e 100644 --- a/tests/integ/snowflake/ml/modeling/metrics/recall_score_test.py +++ b/tests/integ/snowflake/ml/modeling/metrics/recall_score_test.py @@ -1,6 +1,7 @@ -from typing import Any, Dict +from typing import List, Optional, Union import numpy as np +import numpy.typing as npt from absl.testing import parameterized from absl.testing.absltest import main from sklearn import exceptions, metrics as sklearn_metrics @@ -9,21 +10,15 @@ from snowflake.ml.modeling import metrics as snowml_metrics from snowflake.ml.utils import connection_params from tests.integ.snowflake.ml.modeling.framework import utils +from tests.integ.snowflake.ml.modeling.metrics import generator -_ROWS = 100 _TYPES = [utils.DataType.INTEGER] * 4 + [utils.DataType.FLOAT] -_BINARY_DATA, _SF_SCHEMA = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=2, -) -_MULTICLASS_DATA, _ = utils.gen_fuzz_data( - rows=_ROWS, - types=_TYPES, - low=0, - high=5, -) +_BINARY_LOW, _BINARY_HIGH = 0, 2 +_MULTICLASS_LOW, _MULTICLASS_HIGH = 0, 5 +_BINARY_DATA_LIST, _SF_SCHEMA = generator.gen_test_cases(_TYPES, _BINARY_LOW, _BINARY_HIGH) +_MULTICLASS_DATA_LIST, _ = generator.gen_test_cases(_TYPES, _MULTICLASS_LOW, _MULTICLASS_HIGH) +_REGULAR_BINARY_DATA_LIST, _LARGE_BINARY_DATA = _BINARY_DATA_LIST[:-1], _BINARY_DATA_LIST[-1] +_REGULAR_MULTICLASS_DATA_LIST, _LARGE_MULTICLASS_DATA = _MULTICLASS_DATA_LIST[:-1], _MULTICLASS_DATA_LIST[-1] _Y_TRUE_COL = _SF_SCHEMA[1] _Y_PRED_COL = _SF_SCHEMA[2] _Y_TRUE_COLS = [_SF_SCHEMA[1], _SF_SCHEMA[2]] @@ -41,147 +36,156 @@ def 
setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @parameterized.parameters( # type: ignore[misc] - {"params": {"labels": [None, [2, 0, 4]]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + labels=[None, [2, 0, 4]], ) - def test_labels(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for labels in params["labels"]: - actual_r = snowml_metrics.recall_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - labels=labels, - average=None, - ) - sklearn_r = sklearn_metrics.recall_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - labels=labels, - average=None, - ) - np.testing.assert_allclose(actual_r, sklearn_r) + def test_labels(self, data_index: int, labels: Optional[npt.ArrayLike]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + labels=labels, + average=None, + ) + sklearn_r = sklearn_metrics.recall_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + labels=labels, + average=None, + ) + np.testing.assert_allclose(actual_r, sklearn_r) - @parameterized.parameters( # type: ignore[misc] - {"params": {"pos_label": [0, 2, 4]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + pos_label=[0, 2, 4], ) - def test_pos_label(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for pos_label in params["pos_label"]: - actual_r = snowml_metrics.recall_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - pos_label=pos_label, - average="micro", - ) - sklearn_r = sklearn_metrics.recall_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - pos_label=pos_label, - average="micro", - ) - np.testing.assert_allclose(actual_r, sklearn_r) + def test_pos_label(self, data_index: int, pos_label: Union[str, int]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + pos_label=pos_label, + average="micro", + ) + sklearn_r = sklearn_metrics.recall_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + pos_label=pos_label, + average="micro", + ) + np.testing.assert_allclose(actual_r, sklearn_r) - @parameterized.parameters( # type: ignore[misc] - {"params": {"average": [None, "micro", "macro", "weighted"]}}, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + average=[None, "micro", "macro", "weighted"], ) - def test_average_multiclass(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _MULTICLASS_DATA, _SF_SCHEMA) - - for average in params["average"]: - actual_r = snowml_metrics.recall_score( - df=input_df, - y_true_col_names=_Y_TRUE_COL, - y_pred_col_names=_Y_PRED_COL, - average=average, - ) - sklearn_r = sklearn_metrics.recall_score( - pandas_df[_Y_TRUE_COL], - pandas_df[_Y_PRED_COL], - average=average, - ) - np.testing.assert_allclose(actual_r, sklearn_r) - - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "average": ["binary", "samples"], -
"y_true": [_Y_TRUE_COL, _Y_TRUE_COLS], - "y_pred": [_Y_PRED_COL, _Y_PRED_COLS], - } - }, + def test_average_multiclass(self, data_index: int, average: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=average, + ) + sklearn_r = sklearn_metrics.recall_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=average, + ) + np.testing.assert_allclose(actual_r, sklearn_r) + + @parameterized.product( + ( + dict(y_true=_Y_TRUE_COL, y_pred=_Y_PRED_COL, average="binary"), + dict(y_true=_Y_TRUE_COLS, y_pred=_Y_PRED_COLS, average="samples"), + ), + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), ) - def test_average_binary(self, params: Dict[str, Any]) -> None: - pandas_df, input_df = utils.get_df(self._session, _BINARY_DATA, _SF_SCHEMA) + def test_average_binary_samples( + self, + y_true: Union[str, List[str]], + y_pred: Union[str, List[str]], + average: Optional[str], + data_index: int, + ) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=y_true, + y_pred_col_names=y_pred, + average=average, + ) + sklearn_r = sklearn_metrics.recall_score( + pandas_df[y_true], + pandas_df[y_pred], + average=average, + ) + np.testing.assert_allclose(actual_r, sklearn_r) - for idx, average in enumerate(params["average"]): - y_true = params["y_true"][idx] - y_pred = params["y_pred"][idx] - actual_r = snowml_metrics.recall_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - average=average, - ) - sklearn_r = sklearn_metrics.recall_score( - pandas_df[y_true], - pandas_df[y_pred], - average=average, - ) - np.testing.assert_allclose(actual_r, sklearn_r) + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_BINARY_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], + ) + def test_sample_weight_binary(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_BINARY_DATA_LIST[data_index], _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + sample_weight_col_name=sample_weight_col_name, + average="micro", + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_r = sklearn_metrics.recall_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + sample_weight=sample_weight, + average="micro", + ) + np.testing.assert_allclose(actual_r, sklearn_r) - @parameterized.parameters( # type: ignore[misc] - { - "params": { - "sample_weight_col_name": [None, _SAMPLE_WEIGHT_COL], - "values": [ - {"data": _BINARY_DATA, "y_true": _Y_TRUE_COLS, "y_pred": _Y_PRED_COLS}, - {"data": _MULTICLASS_DATA, "y_true": _Y_TRUE_COL, "y_pred": _Y_PRED_COL}, - ], - } - }, + @parameterized.product( # type: ignore[misc] + data_index=list(range(len(_REGULAR_MULTICLASS_DATA_LIST))), + sample_weight_col_name=[None, _SAMPLE_WEIGHT_COL], ) - def test_sample_weight(self, params: Dict[str, Any]) -> None: - for values in params["values"]: - data = values["data"] - y_true = values["y_true"] - y_pred = values["y_pred"] - pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - - for 
sample_weight_col_name in params["sample_weight_col_name"]: - actual_r = snowml_metrics.recall_score( - df=input_df, - y_true_col_names=y_true, - y_pred_col_names=y_pred, - sample_weight_col_name=sample_weight_col_name, - average="micro", - ) - sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None - sklearn_r = sklearn_metrics.recall_score( - pandas_df[y_true], - pandas_df[y_pred], - sample_weight=sample_weight, - average="micro", - ) - np.testing.assert_allclose(actual_r, sklearn_r) - - @parameterized.parameters( # type: ignore[misc] - {"params": {"zero_division": ["warn", 0, 1]}}, + def test_sample_weight_multiclass(self, data_index: int, sample_weight_col_name: Optional[str]) -> None: + pandas_df, input_df = utils.get_df(self._session, _REGULAR_MULTICLASS_DATA_LIST[data_index], _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + sample_weight_col_name=sample_weight_col_name, + average="micro", + ) + sample_weight = pandas_df[sample_weight_col_name].to_numpy() if sample_weight_col_name else None + sklearn_r = sklearn_metrics.recall_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + sample_weight=sample_weight, + average="micro", + ) + np.testing.assert_allclose(actual_r, sklearn_r) + + @parameterized.product( # type: ignore[misc] + zero_division=[0, 1], ) - def test_zero_division(self, params: Dict[str, Any]) -> None: + def test_zero_division(self, zero_division: Union[str, int]) -> None: data = [ [0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0], ] pandas_df, input_df = utils.get_df(self._session, data, _SF_SCHEMA) - for zero_division in params["zero_division"]: - if zero_division == "warn": - continue - + if zero_division != "warn": actual_r = snowml_metrics.recall_score( df=input_df, y_true_col_names=_Y_TRUE_COL, @@ -211,6 +215,38 @@ def test_zero_division(self, params: Dict[str, Any]) -> None: ) np.testing.assert_allclose(actual_r, sklearn_r) + def test_with_large_num_of_rows_binary(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_BINARY_DATA, _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=_Y_TRUE_COLS, + y_pred_col_names=_Y_PRED_COLS, + average=None, + ) + sklearn_r = sklearn_metrics.recall_score( + pandas_df[_Y_TRUE_COLS], + pandas_df[_Y_PRED_COLS], + average=None, + ) + np.testing.assert_allclose(actual_r, sklearn_r) + + def test_with_large_num_of_rows_multiclass(self) -> None: + pandas_df, input_df = utils.get_df(self._session, _LARGE_MULTICLASS_DATA, _SF_SCHEMA) + + actual_r = snowml_metrics.recall_score( + df=input_df, + y_true_col_names=_Y_TRUE_COL, + y_pred_col_names=_Y_PRED_COL, + average=None, + ) + sklearn_r = sklearn_metrics.recall_score( + pandas_df[_Y_TRUE_COL], + pandas_df[_Y_PRED_COL], + average=None, + ) + np.testing.assert_allclose(actual_r, sklearn_r) + if __name__ == "__main__": main() diff --git a/tests/integ/snowflake/ml/modeling/model_selection/check_sklearn_inference_test.py b/tests/integ/snowflake/ml/modeling/model_selection/check_sklearn_inference_test.py index 53916ad1..e0a7f59a 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/check_sklearn_inference_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/check_sklearn_inference_test.py @@ -6,7 +6,6 @@ from sklearn.datasets import load_iris from sklearn.linear_model import LinearRegression -from snowflake.ml._internal.exceptions import exceptions from snowflake.ml.modeling.model_selection 
import ( # type: ignore[attr-defined] GridSearchCV, RandomizedSearchCV, @@ -42,38 +41,37 @@ def setUp(self) -> None: def test_sklearn_inference_gridsearch(self) -> None: reg = GridSearchCV( - estimator=LinearRegression(), param_grid={"fit_intercept": [True, False], "positive": [True, False]} + estimator=LinearRegression(), + input_cols=self._input_cols, + param_grid={"fit_intercept": [True, False], "positive": [True, False]}, ) - reg.set_input_cols(self._input_cols) reg.set_label_cols(self._label_col) reg.set_drop_input_cols(True) reg.fit(self._input_df_pandas) # In predict function, the pandas dataframe's column name is actually wrong (["1"]) # it would raise error - with self.assertRaises(exceptions.SnowflakeMLException): - reg._sklearn_inference(pd.DataFrame({"1": []}), "predict", [""]) - # in the pandas dataframe's column name, some of them are single quoted - # some of them are double quoted + with self.assertRaises(ValueError): + reg.predict(pd.DataFrame({"1": []})) + + # Column names quoted inconsistently with the training input should also make predict raise. test_pd = self._input_df_pandas test_pd.columns = [ '"sepal_length_cm"', - "sepal_width_cm", + "sepal_width_cm", # mismatched quotation style from input '"petal_length_cm"', - "petal_width_cm", + '"petal_width_cm"', '"target"', '"index"', ] - reg._sklearn_inference(test_pd, "predict", [""]) - # When output cols is an empty array ([]) - # it would raise error - with self.assertRaises(exceptions.SnowflakeMLException): - reg._sklearn_inference(self._input_df_pandas, "predict", []) + with self.assertRaises(ValueError): + reg.predict(test_pd) def test_sklearn_inference_randomizedsearch(self) -> None: reg = RandomizedSearchCV( estimator=LinearRegression(), + input_cols=self._input_cols, param_distributions={"fit_intercept": [True, False], "positive": [True, False]}, ) reg.set_input_cols(self._input_cols) @@ -82,26 +80,23 @@ def test_sklearn_inference_randomizedsearch(self) -> None: reg.set_drop_input_cols(True) reg.fit(self._input_df_pandas) # In predict function, the pandas dataframe's column name is actually wrong (["1"]) # it would raise error - with self.assertRaises(exceptions.SnowflakeMLException): - reg._sklearn_inference(pd.DataFrame({"1": []}), "predict", [""]) + with self.assertRaises(ValueError): + reg.predict(pd.DataFrame({"1": []})) # in the pandas dataframe's column name, some of them are single quoted # some of them are double quoted test_pd = self._input_df_pandas test_pd.columns = [ '"sepal_length_cm"', - "sepal_width_cm", + "sepal_width_cm", # mismatched quotation style from input '"petal_length_cm"', - "petal_width_cm", + '"petal_width_cm"', '"target"', '"index"', ] - reg._sklearn_inference(test_pd, "predict", [""]) - # When output cols is an empty array ([]) - # it would raise error - with self.assertRaises(exceptions.SnowflakeMLException): - reg._sklearn_inference(self._input_df_pandas, "predict", []) + with self.assertRaises(ValueError): + reg.predict(test_pd) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/registry/model/BUILD.bazel b/tests/integ/snowflake/ml/registry/model/BUILD.bazel index 4dd54706..e6ed87e2 100644 --- a/tests/integ/snowflake/ml/registry/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/model/BUILD.bazel @@ -2,6 +2,25 @@ load("//bazel:py_rules.bzl", "py_library", "py_test") package(default_visibility = ["//tests/integ/snowflake/ml:__subpackages__"]) +filegroup( + name = "ext_module", + srcs = glob([ + "my_module/**", + ]), +) + +py_test( + name = "additional_import_test", + srcs = ["additional_import_test.py"], + data = [":ext_module"], + deps = [ + "//snowflake/ml/model:type_hints",
"//snowflake/ml/registry", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:db_manager", + ], +) + py_library( name = "registry_model_test_base", testonly = True, diff --git a/tests/integ/snowflake/ml/registry/model/additional_import_test.py b/tests/integ/snowflake/ml/registry/model/additional_import_test.py new file mode 100644 index 00000000..7316e36d --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/additional_import_test.py @@ -0,0 +1,129 @@ +import importlib +import uuid +from functools import partial +from typing import Literal + +import importlib_resources +import numpy as np +import pandas as pd +import xgboost as xgb +from absl.testing import absltest, parameterized +from sklearn import compose, datasets, impute, pipeline, preprocessing + +from snowflake.ml import registry +from snowflake.ml.model import custom_model, model_signature +from snowflake.ml.utils import connection_params +from snowflake.snowpark import Session +from tests.integ.snowflake.ml.registry.model.my_module.utils import column_labeller +from tests.integ.snowflake.ml.test_utils import db_manager + + +class ModelWithAdditionalImportTest(parameterized.TestCase): + def setUp(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + login_options = connection_params.SnowflakeLoginOptions() + + self._run_id = uuid.uuid4().hex + self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() + self._test_schema = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( + self._run_id, "schema" + ).upper() + + self._session = Session.builder.configs( + { + **login_options, + **{"database": self._test_db, "schema": self._test_schema}, + } + ).create() + + self._db_manager = db_manager.DBManager(self._session) + self._db_manager.create_database(self._test_db) + self._db_manager.create_schema(self._test_schema) + self._db_manager.cleanup_databases(expire_hours=6) + self.registry = registry.Registry(self._session) + + def tearDown(self) -> None: + self._db_manager.drop_database(self._test_db) + self._session.close() + + @parameterized.parameters([{"import_method": "ext_modules"}, {"import_method": "code_paths"}]) + def test_additional_import(self, import_method: Literal["ext_modules", "code_paths"]) -> None: + name = f"model_{self._run_id}" + version = f"ver_{import_method}" + + X, y = datasets.make_classification() + X = pd.DataFrame(X, columns=["X" + str(i) for i in range(20)]) + log_trans = pipeline.Pipeline( + [ + ("impute", impute.SimpleImputer()), + ("scaler", preprocessing.MinMaxScaler()), + ( + "logger", + preprocessing.FunctionTransformer( + np.log1p, + feature_names_out=partial(column_labeller, "LOG"), + ), + ), + ] + ) + preproc_pipe = compose.ColumnTransformer( + [("log", log_trans, ["X0", "X1"])], + remainder="passthrough", + verbose_feature_names_out=False, + ) + preproc_pipe.set_output(transform="pandas") + preproc_pipe.fit(X, y) + + xgb_data = xgb.DMatrix(preproc_pipe.transform(X), y) + booster = xgb.train(dict(max_depth=5), xgb_data, num_boost_round=10) + + class MyModel(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + def predict(self, X: pd.DataFrame) -> pd.DataFrame: + xgb_data = xgb.DMatrix(self.context.model_ref("pipeline").transform(X)) + preds = self.context.model_ref("model").predict(xgb_data) + res_df = pd.DataFrame({"output": preds}) + return res_df + + my_model = MyModel( + 
custom_model.ModelContext( + models={ + "pipeline": preproc_pipe, + "model": booster, + }, + artifacts={}, + ) + ) + + sig = model_signature.ModelSignature( + inputs=[model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name=f"X{i}") for i in range(20)], + outputs=[model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="output")], + ) + + if import_method == "ext_modules": + my_module = importlib.import_module("tests.integ.snowflake.ml.registry.model.my_module") + mv = self.registry.log_model( + my_model, + model_name=name, + version_name=version, + signatures={"predict": sig}, + ext_modules=[my_module], + ) + else: + code_path = importlib_resources.files("tests").joinpath("") + mv = self.registry.log_model( + my_model, + model_name=name, + version_name=version, + signatures={"predict": sig}, + code_paths=[code_path], + ) + + mv.run(X, function_name="predict") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/my_module/__init__.py b/tests/integ/snowflake/ml/registry/model/my_module/__init__.py new file mode 100644 index 00000000..091ce622 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/my_module/__init__.py @@ -0,0 +1,3 @@ +from tests.integ.snowflake.ml.registry.model.my_module import utils + +__all__ = ["utils"] diff --git a/tests/integ/snowflake/ml/registry/model/my_module/utils.py b/tests/integ/snowflake/ml/registry/model/my_module/utils.py new file mode 100644 index 00000000..48c9ddc2 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/my_module/utils.py @@ -0,0 +1,3 @@ +# Used via partial(column_labeller, "LOG") as FunctionTransformer(feature_names_out=...); sklearn calls the callback with (transformer, input_feature_names), so "self" is the transformer instance. +def column_labeller(suffix, self, columns): + return [suffix + "_" + c for c in columns] diff --git a/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py index c8ee0936..057cec3f 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py +++ b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py @@ -48,7 +48,7 @@ def _test_registry_model( options: Optional[model_types.ModelSaveOption] = None, ) -> None: conda_dependencies = [ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python!=1.12.0") ] if additional_dependencies: conda_dependencies.extend(additional_dependencies) diff --git a/tests/integ/snowflake/ml/registry/model_registry_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_integ_test.py index 802fe2a5..2ae50924 100644 --- a/tests/integ/snowflake/ml/registry/model_registry_integ_test.py +++ b/tests/integ/snowflake/ml/registry/model_registry_integ_test.py @@ -64,7 +64,9 @@ def test_basic_workflow(self) -> None: model=model, tags=model_tags, conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server( + self._session, "snowflake-snowpark-python!=1.12.0" + ) ], sample_input_data=test_features, options={"embed_local_ml_library": True}, @@ -79,7 +81,9 @@ def test_basic_workflow(self) -> None: model=model, tags={"stage": "testing", "classifier_type": "svm.SVC"}, conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server( + self._session, "snowflake-snowpark-python!=1.12.0" + ) ],
sample_input_data=test_features, options={"embed_local_ml_library": True}, @@ -281,7 +285,9 @@ def test_snowml_model(self, model_prepare_callable: callable) -> None: model_version=model_version, model=model, conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server( + self._session, "snowflake-snowpark-python!=1.12.0" + ) ], options={"embed_local_ml_library": True}, ) @@ -330,7 +336,9 @@ def test_snowml_pipeline(self) -> None: model_version=model_version, model=model, conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server( + self._session, "snowflake-snowpark-python!=1.12.0" + ) ], options={"embed_local_ml_library": True}, ) @@ -425,7 +433,9 @@ def test_log_model_with_dataset(self) -> None: model_version=version, model=model, conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + test_env_utils.get_latest_package_version_spec_in_server( + self._session, "snowflake-snowpark-python!=1.12.0" + ) ], options={"embed_local_ml_library": True}, artifacts=[atf_ref], diff --git a/tests/integ/snowflake/ml/test_utils/BUILD.bazel b/tests/integ/snowflake/ml/test_utils/BUILD.bazel index 4587f2e6..94d06d77 100644 --- a/tests/integ/snowflake/ml/test_utils/BUILD.bazel +++ b/tests/integ/snowflake/ml/test_utils/BUILD.bazel @@ -96,6 +96,7 @@ py_library( py_test( name = "common_test_base_test", + timeout = "long", srcs = ["common_test_base_test.py"], deps = [ ":common_test_base",