diff --git a/CHANGELOG.md b/CHANGELOG.md index 011abdfe..d70fb5ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Release History +## 1.1.2 + +### Bug Fixes + +- Generic: Fix an issue where the stack trace was unexpectedly hidden by telemetry. +- Model Development: Execute model signature inference without materializing full dataframe in memory. +- Model Registry: Fix occasional 'snowflake-ml-python library does not exist' error when deploying to SPCS. + +### Behavior Changes + +- Model Registry: When calling `predict` with Snowpark DataFrame, both inferred and normalized column names are accepted. +- Model Registry: When logging a Snowpark ML Modeling Model, sample input data or manually provided signature will be + ignored since they are not necessary. + +### New Features + +- Model Development: SQL implementation of binary `precision_score` metric. + ## 1.1.1 ### Bug Fixes @@ -7,8 +25,6 @@ - Model Registry: The `predict` target method on registered models is now compatible with unsupervised estimators. - Model Development: Fix confusion_matrix incorrect results when the row number cannot be divided by the batch size. -### Behavior Changes - ### New Features - Introduced passthrough_col param in Modeling API. 
This new param is helpful in scenarios diff --git a/bazel/environments/conda-env-snowflake.yml b/bazel/environments/conda-env-snowflake.yml index 2f727aa3..0c670d26 100644 --- a/bazel/environments/conda-env-snowflake.yml +++ b/bazel/environments/conda-env-snowflake.yml @@ -31,11 +31,13 @@ dependencies: - packaging==23.0 - pandas==1.5.3 - protobuf==3.20.3 + - pyarrow==10.0.1 - pytest==7.4.0 - pytimeparse==1.1.8 - pytorch==2.0.1 - pyyaml==6.0 - requests==2.29.0 + - retrying==1.3.3 - ruamel.yaml==0.17.21 - s3fs==2023.3.0 - scikit-learn==1.3.0 diff --git a/bazel/environments/conda-env.yml b/bazel/environments/conda-env.yml index 9ac7c323..febfa4bd 100644 --- a/bazel/environments/conda-env.yml +++ b/bazel/environments/conda-env.yml @@ -36,11 +36,13 @@ dependencies: - packaging==23.0 - pandas==1.5.3 - protobuf==3.20.3 + - pyarrow==10.0.1 - pytest==7.4.0 - pytimeparse==1.1.8 - pytorch==2.0.1 - pyyaml==6.0 - requests==2.29.0 + - retrying==1.3.3 - ruamel.yaml==0.17.21 - s3fs==2023.3.0 - scikit-learn==1.3.0 diff --git a/bazel/environments/conda-gpu-env.yml b/bazel/environments/conda-gpu-env.yml index 6d213617..3385d0ff 100755 --- a/bazel/environments/conda-gpu-env.yml +++ b/bazel/environments/conda-gpu-env.yml @@ -37,12 +37,14 @@ dependencies: - packaging==23.0 - pandas==1.5.3 - protobuf==3.20.3 + - pyarrow==10.0.1 - pytest==7.4.0 - pytimeparse==1.1.8 - pytorch::pytorch-cuda==11.7.* - pytorch::pytorch==2.0.1 - pyyaml==6.0 - requests==2.29.0 + - retrying==1.3.3 - ruamel.yaml==0.17.21 - s3fs==2023.3.0 - scikit-learn==1.3.0 diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 6aa16d66..00b04f5c 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.1.1 + version: 1.1.2 requirements: build: - python @@ -33,9 +33,11 @@ requirements: - numpy>=1.23,<2 - packaging>=20.9,<24 - pandas>=1.0.0,<2 + - pyarrow - pytimeparse>=1.1.8,<2 - pyyaml>=6.0,<7 - requests + - 
retrying>=1.3.3,<2 - s3fs>=2022.11,<2024 - scikit-learn>=1.2.1,<1.4 - scipy>=1.9,<2 diff --git a/codegen/codegen_rules.bzl b/codegen/codegen_rules.bzl index af9eec6e..01518ba6 100644 --- a/codegen/codegen_rules.bzl +++ b/codegen/codegen_rules.bzl @@ -92,6 +92,8 @@ def autogen_estimators(module, estimator_info_list): "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_signatures:utils", "//snowflake/ml/modeling/_internal:estimator_utils", + "//snowflake/ml/modeling/_internal:model_trainer", + "//snowflake/ml/modeling/_internal:model_trainer_builder", ], ) diff --git a/codegen/sklearn_wrapper_generator.py b/codegen/sklearn_wrapper_generator.py index 2b590b2b..fe022c1c 100644 --- a/codegen/sklearn_wrapper_generator.py +++ b/codegen/sklearn_wrapper_generator.py @@ -16,44 +16,58 @@ LOAD_DIABETES = "load_diabetes" -ADDITIONAL_PARAM_DESCRIPTIONS = """ - +ADDITIONAL_PARAM_DESCRIPTIONS = { + "input_cols": """ input_cols: Optional[Union[str, List[str]]] A string or list of strings representing column names that contain features. If this parameter is not specified, all columns in the input DataFrame except the columns specified by label_cols, sample_weight_col, and passthrough_cols - parameters are considered input columns. - + parameters are considered input columns. Input columns can also be set after + initialization with the `set_input_cols` method. + """, + "label_cols": """ label_cols: Optional[Union[str, List[str]]] A string or list of strings representing column names that contain labels. - This is a required param for estimators, as there is no way to infer these - columns. If this parameter is not specified, then object is fitted without - labels (like a transformer). - + Label columns must be specified with this parameter during initialization + or with the `set_label_cols` method before fitting. 
+""", + "output_cols": """ output_cols: Optional[Union[str, List[str]]] A string or list of strings representing column names that will store the output of predict and transform operations. The length of output_cols must - match the expected number of output columns from the specific estimator or + match the expected number of output columns from the specific predictor or transformer class used. - If this parameter is not specified, output column names are derived by - adding an OUTPUT_ prefix to the label column names. These inferred output - column names work for estimator's predict() method, but output_cols must - be set explicitly for transformers. - + If you omit this parameter, output column names are derived by adding an + OUTPUT_ prefix to the label column names for supervised estimators, or + OUTPUT_for unsupervised estimators. These inferred output column names + work for predictors, but output_cols must be set explicitly for transformers. + In general, explicitly specifying output column names is clearer, especially + if you don’t specify the input column names. + To transform in place, pass the same names for input_cols and output_cols. + Output columns can also be set after + initialization with the `set_output_cols` method. +""", + "sample_weight_col": """ sample_weight_col: Optional[str] A string representing the column name containing the sample weights. - This argument is only required when working with weighted datasets. - + This argument is only required when working with weighted datasets. Sample + weight column can also be set after initialization with the + `set_sample_weight_col` method. +""", + "passthrough_cols": """ passthrough_cols: Optional[Union[str, List[str]]] A string or a list of strings indicating column names to be excluded from any operations (such as train, transform, or inference). These specified column(s) will remain untouched throughout the process. 
This option is helpful in scenarios requiring automatic input_cols inference, but need to avoid using specific - columns, like index columns, during training or inference. - + columns, like index columns, during training or inference. Passthrough columns + can also be set after initialization with the `set_passthrough_cols` method. +""", + "drop_input_cols": """ drop_input_cols: Optional[bool], default=False If set, the response of predict(), transform() methods will not contain input columns. -""" +""", +} ADDITIONAL_METHOD_DESCRIPTION = """ Raises: @@ -448,7 +462,6 @@ class WrapperGeneratorBase: is contained in. estimator_imports GENERATED Imports needed for the estimator / fit() call. - wrapper_provider_class GENERATED Class name of wrapper provider. ------------------------------------------------------------------------------------ SIGNATURES AND ARGUMENTS ------------------------------------------------------------------------------------ @@ -545,7 +558,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: self.estimator_imports = "" self.estimator_imports_list: List[str] = [] self.score_sproc_imports: List[str] = [] - self.wrapper_provider_class = "" self.additional_import_statements = "" # Test strings @@ -630,10 +642,11 @@ def _populate_class_doc_fields(self) -> None: class_docstring = inspect.getdoc(self.class_object[1]) or "" class_docstring = class_docstring.rsplit("Attributes\n", 1)[0] + parameters_heading = "Parameters\n----------\n" class_description, param_description = ( - class_docstring.rsplit("Parameters\n", 1) - if len(class_docstring.rsplit("Parameters\n", 1)) == 2 - else (class_docstring, "----------\n") + class_docstring.rsplit(parameters_heading, 1) + if len(class_docstring.rsplit(parameters_heading, 1)) == 2 + else (class_docstring, "") ) # Extract the first sentence of the class description @@ -645,9 +658,11 @@ def _populate_class_doc_fields(self) -> None: f"]\n({self.get_doc_link()})" ) - # Add SnowML 
specific param descriptions. - param_description = "Parameters\n" + param_description.strip() - param_description += ADDITIONAL_PARAM_DESCRIPTIONS + # Add SnowML specific param descriptions before third party parameters. + snowml_parameters = "" + for d in ADDITIONAL_PARAM_DESCRIPTIONS.values(): + snowml_parameters += d + param_description = f"{parameters_heading}{snowml_parameters}\n{param_description.strip()}" class_docstring = f"{class_description}\n\n{param_description}" class_docstring = textwrap.indent(class_docstring, " ").strip() @@ -718,12 +733,23 @@ def _populate_function_names_and_signatures(self) -> None: for member in inspect.getmembers(self.class_object[1]): if member[0] == "__init__": self.original_init_signature = inspect.signature(member[1]) + elif member[0] == "fit": + original_fit_signature = inspect.signature(member[1]) + if original_fit_signature.parameters["y"].default is None: + # The fit does not require labels, so our label_cols argument is optional. + ADDITIONAL_PARAM_DESCRIPTIONS[ + "label_cols" + ] = """ +label_cols: Optional[Union[str, List[str]]] + This parameter is optional and will be ignored during fit. It is present here for API consistency by convention. 
+ """ signature_lines = [] sklearn_init_lines = [] init_member_args = [] has_kwargs = False sklearn_init_args_dict_list = [] + for k, v in self.original_init_signature.parameters.items(): if k == "self": signature_lines.append("self") @@ -855,9 +881,9 @@ def generate(self) -> "WrapperGeneratorBase": self._populate_flags() self._populate_class_names() self._populate_import_statements() - self._populate_class_doc_fields() self._populate_function_doc_fields() self._populate_function_names_and_signatures() + self._populate_class_doc_fields() self._populate_file_paths() self._populate_integ_test_fields() return self @@ -876,13 +902,8 @@ def generate(self) -> "SklearnWrapperGenerator": # Populate all the common values super().generate() - is_model_selector = WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "BaseSearchCV") - # Populate SKLearn specific values self.estimator_imports_list.extend(["import sklearn", f"import {self.root_module_name}"]) - self.wrapper_provider_class = ( - "SklearnModelSelectionWrapperProvider" if is_model_selector else "SklearnWrapperProvider" - ) self.score_sproc_imports = ["sklearn"] if "random_state" in self.original_init_signature.parameters.keys(): @@ -982,6 +1003,9 @@ def generate(self) -> "SklearnWrapperGenerator": if self._is_hist_gradient_boosting_regressor: self.test_estimator_input_args_list.extend(["min_samples_leaf=1", "max_leaf_nodes=100"]) + self.deps = ( + "f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'" + ) self.supported_export_method = "to_sklearn" self.unsupported_export_methods = ["to_xgboost", "to_lightgbm"] self._construct_string_from_lists() @@ -1010,10 +1034,10 @@ def generate(self) -> "XGBoostWrapperGenerator": ["random_state=0", "subsample=1.0", "colsample_bynode=1.0", "n_jobs=1"] ) self.score_sproc_imports = ["xgboost"] - self.wrapper_provider_class = "XGBoostWrapperProvider" # TODO(snandamuri): Replace cloudpickle with joblib after latest version 
of joblib is added to snowflake conda. self.supported_export_method = "to_xgboost" self.unsupported_export_methods = ["to_sklearn", "to_lightgbm"] + self.deps = "f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'" self._construct_string_from_lists() return self @@ -1039,8 +1063,8 @@ def generate(self) -> "LightGBMWrapperGenerator": self.estimator_imports_list.append("import lightgbm") self.test_estimator_input_args_list.extend(["random_state=0", "n_jobs=1"]) self.score_sproc_imports = ["lightgbm"] - self.wrapper_provider_class = "LightGBMWrapperProvider" + self.deps = "f'numpy=={np.__version__}', f'lightgbm=={lightgbm.__version__}', f'cloudpickle=={cp.__version__}'" self.supported_export_method = "to_lightgbm" self.unsupported_export_methods = ["to_sklearn", "to_xgboost"] self._construct_string_from_lists() diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index 96173e1a..f1270070 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -16,17 +16,19 @@ from sklearn.utils.metaestimators import available_if from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols from snowflake.ml._internal import telemetry from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages +from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV from snowflake.ml._internal.utils import pkg_version_utils, identifier -from snowflake.snowpark import DataFrame +from snowflake.snowpark import DataFrame, Session from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl +from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder +from snowflake.ml.modeling._internal.model_trainer import ModelTrainer from 
snowflake.ml.modeling._internal.estimator_utils import ( gather_dependencies, original_estimator_has_callable, transform_snowml_obj_to_sklearn_obj, validate_sklearn_args, ) -from snowflake.ml.modeling._internal.snowpark_handlers import {transform.wrapper_provider_class} from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers from snowflake.ml.model.model_signature import ( @@ -46,7 +48,6 @@ _PROJECT = "ModelDevelopment" _SUBPROJECT = "".join([s.capitalize() for s in "{transform.root_module_name}".replace("sklearn.", "").split("_")]) - class {transform.original_class_name}(BaseTransformer): r"""{transform.estimator_class_docstring} """ @@ -57,7 +58,7 @@ class {transform.original_class_name}(BaseTransformer): super().__init__() {transform.estimator_init_member_args} - deps = set({transform.wrapper_provider_class}().dependencies) + deps: Set[str] = set([{transform.deps}]) {transform.estimator_args_gathering_calls} self._deps = list(deps) {transform.estimator_args_transform_calls} @@ -66,13 +67,14 @@ class {transform.original_class_name}(BaseTransformer): args=init_args, klass={transform.root_module_name}.{transform.original_class_name} ) - self._sklearn_object = {transform.root_module_name}.{transform.original_class_name}( + self._sklearn_object: Any = {transform.root_module_name}.{transform.original_class_name}( {transform.sklearn_init_arguments} ) self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols self._snowpark_cols: Optional[List[str]] = self.input_cols - self._handlers: FitPredictHandlers = HandlersImpl(class_name={transform.original_class_name}.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True, wrapper_provider={transform.wrapper_provider_class}()) + self._handlers: FitPredictHandlers = HandlersImpl(class_name={transform.original_class_name}.__class__.__name__, subproject=_SUBPROJECT, 
autogenerated=True) + self._autogenerated = True def _get_rand_id(self) -> str: """ @@ -118,54 +120,48 @@ class {transform.original_class_name}(BaseTransformer): self """ self._infer_input_output_cols(dataset) - if isinstance(dataset, pd.DataFrame): - assert self._sklearn_object is not None # keep mypy happy - self._sklearn_object = self._handlers.fit_pandas( - dataset, - self._sklearn_object, - self.input_cols, - self.label_cols, - self.sample_weight_col - ) - elif isinstance(dataset, DataFrame): - self._fit_snowpark(dataset) - else: - raise TypeError( - f"Unexpected dataset type: {{type(dataset)}}." - "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." - ) + if isinstance(dataset, DataFrame): + session = dataset._session + assert session is not None # keep mypy happy + # Validate that key package version in user workspace are supported in snowflake conda channel + # If customer doesn't have package in conda channel, replace the ones have the closest versions + self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT) + + # Specify input columns so column pruning will be enforced + selected_cols = self._get_active_columns() + if len(selected_cols) > 0: + dataset = dataset.select(selected_cols) + + self._snowpark_cols = dataset.select(self.input_cols).columns + + # If we are already in a stored procedure, no need to kick off another one. 
+ if SNOWML_SPROC_ENV in os.environ: + statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=_SUBPROJECT, + function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), {transform.original_class_name}.__class__.__name__), + api_calls=[Session.call], + custom_tags=dict([("autogen", True)]) if self._autogenerated else None, + ) + pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params) + pd_df.columns = dataset.columns + dataset = pd_df + + model_trainer = ModelTrainerBuilder.build( + estimator=self._sklearn_object, + dataset=dataset, + input_cols=self.input_cols, + label_cols=self.label_cols, + sample_weight_col=self.sample_weight_col, + autogenerated=self._autogenerated, + subproject=_SUBPROJECT + ) + self._sklearn_object = model_trainer.train() self._is_fitted = True self._get_model_signatures(dataset) return self - def _fit_snowpark(self, dataset: DataFrame) -> None: - session = dataset._session - assert session is not None # keep mypy happy - # Validate that key package version in user workspace are supported in snowflake conda channel - # If customer doesn't have package in conda channel, replace the ones have the closest versions - self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT) - - # Specify input columns so column pruning will be enforced - selected_cols = self._get_active_columns() - if len(selected_cols) > 0: - dataset = dataset.select(selected_cols) - - estimator = self._sklearn_object - assert estimator is not None # Keep mypy happy - - self._snowpark_cols = dataset.select(self.input_cols).columns - - self._sklearn_object = self._handlers.fit_snowpark( - dataset, - session, - estimator, - ["snowflake-snowpark-python"] + self._get_dependencies(), - self.input_cols, - self.label_cols, - self.sample_weight_col, - ) - def 
_get_pass_through_columns(self, dataset: DataFrame) -> List[str]: if self._drop_input_cols: return [] @@ -353,11 +349,6 @@ class {transform.original_class_name}(BaseTransformer): subproject=_SUBPROJECT, custom_tags=dict([("autogen", True)]), ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - custom_tags=dict([("autogen", True)]), - ) def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: """{transform.predict_docstring} @@ -402,11 +393,6 @@ class {transform.original_class_name}(BaseTransformer): subproject=_SUBPROJECT, custom_tags=dict([("autogen", True)]), ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - custom_tags=dict([("autogen", True)]), - ) def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: """{transform.transform_docstring} @@ -447,7 +433,8 @@ class {transform.original_class_name}(BaseTransformer): if {transform.fit_predict_cluster_function_support}: self.fit(dataset) assert self._sklearn_object is not None - return self._sklearn_object.labels_ + labels : npt.NDArray[Any] = self._sklearn_object.labels_ + return labels else: # TODO(xinyi): support fit_predict for mixture classes raise NotImplementedError @@ -484,6 +471,7 @@ class {transform.original_class_name}(BaseTransformer): output_cols = [] # Make sure column names are valid snowflake identifiers. 
+ assert output_cols is not None # Make MyPy happy rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols] return rv @@ -494,11 +482,6 @@ class {transform.original_class_name}(BaseTransformer): subproject=_SUBPROJECT, custom_tags=dict([("autogen", True)]), ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - custom_tags=dict([("autogen", True)]), - ) def predict_proba( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_" ) -> Union[DataFrame, pd.DataFrame]: @@ -531,11 +514,6 @@ class {transform.original_class_name}(BaseTransformer): subproject=_SUBPROJECT, custom_tags=dict([("autogen", True)]), ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - custom_tags=dict([("autogen", True)]), - ) def predict_log_proba( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_" ) -> Union[DataFrame, pd.DataFrame]: @@ -564,16 +542,6 @@ class {transform.original_class_name}(BaseTransformer): return output_df @available_if(original_estimator_has_callable("decision_function")) # type: ignore[misc] - @telemetry.send_api_usage_telemetry( - project=_PROJECT, - subproject=_SUBPROJECT, - custom_tags=dict([("autogen", True)]), - ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - custom_tags=dict([("autogen", True)]), - ) def decision_function( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_" ) -> Union[DataFrame, pd.DataFrame]: @@ -656,11 +624,6 @@ class {transform.original_class_name}(BaseTransformer): subproject=_SUBPROJECT, custom_tags=dict([("autogen", True)]), ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - custom_tags=dict([("autogen", True)]), - ) def kneighbors( self, dataset: Union[DataFrame, pd.DataFrame], @@ -713,9 +676,9 @@ class {transform.original_class_name}(BaseTransformer): # For classifier, 
the type of predict is the same as the type of label if self._sklearn_object._estimator_type == 'classifier': # label columns is the desired type for output - outputs = _infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True) + outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True)) # rename the output columns - outputs = model_signature_utils.rename_features(outputs, self.output_cols) + outputs = list(model_signature_utils.rename_features(outputs, self.output_cols)) self._model_signature_dict["predict"] = ModelSignature(inputs, ([] if self._drop_input_cols else inputs) + outputs) diff --git a/docs/source/_templates/autosummary/class.rst b/docs/source/_templates/autosummary/class.rst index 9939db8b..714c1ce2 100644 --- a/docs/source/_templates/autosummary/class.rst +++ b/docs/source/_templates/autosummary/class.rst @@ -9,9 +9,7 @@ .. autosummary:: {% for item in methods %} - {%- if item not in inherited_members %} ~{{ name }}.{{ item }} - {%- endif %} {%- endfor %} {% endif %} {% endblock %} @@ -23,9 +21,7 @@ .. autosummary:: {% for item in attributes %} - {%- if item not in inherited_members %} ~{{ name }}.{{ item }} - {%- endif %} {%- endfor %} {% endif %} {% endblock %} diff --git a/docs/source/modeling.rst b/docs/source/modeling.rst index 2df9caa9..3884591a 100644 --- a/docs/source/modeling.rst +++ b/docs/source/modeling.rst @@ -304,7 +304,7 @@ snowflake.ml.modeling.metrics .. currentmodule:: snowflake.ml.modeling.metrics -.. rubric:: Classes +.. rubric:: Functions .. 
autosummary:: :toctree: api/modeling @@ -312,12 +312,20 @@ snowflake.ml.modeling.metrics accuracy_score confusion_matrix correlation + covariance + d2_absolute_error_score + d2_pinball_score + explained_variance_score f1_score fbeta_score log_loss + mean_absolute_error + mean_absolute_percentage_error + mean_squared_error precision_recall_curve precision_recall_fscore_support precision_score + r2_score recall_score roc_auc_score roc_curve @@ -448,6 +456,7 @@ snowflake.ml.modeling.preprocessing Normalizer OneHotEncoder Binarizer + PolynomialFeatures snowflake.ml.modeling.semi_supervised diff --git a/docs/sphinxconf/unsupported_functions_by_class.csv b/docs/sphinxconf/unsupported_functions_by_class.csv index 621e5464..54c45018 100644 --- a/docs/sphinxconf/unsupported_functions_by_class.csv +++ b/docs/sphinxconf/unsupported_functions_by_class.csv @@ -1,153 +1,153 @@ -AdditiveChi2Sampler, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -Nystroem, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -KernelRidge, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -SkewedChi2Sampler, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -PolynomialCountSketch, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -SVC, transform, to_xgboost, to_lightgbm -RBFSampler, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -ARDRegression, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -HistGradientBoostingClassifier, predict_log_proba, transform, to_xgboost, to_lightgbm -LGBMClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_sklearn -LinearSVC, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm -OrthogonalMatchingPursuit, predict_log_proba, predict_proba, decision_function, transform, 
to_xgboost, to_lightgbm -NearestNeighbors, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -PoissonRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -Lars, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -Ridge, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -RandomForestClassifier, decision_function, transform, to_xgboost, to_lightgbm -ExtraTreesRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -ComplementNB, decision_function, transform, to_xgboost, to_lightgbm -MultiTaskElasticNetCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -Lasso, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LinearSVR, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -PassiveAggressiveRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -Perceptron, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm -MultiTaskElasticNet, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -RANSACRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -ColumnTransformer, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -NeighborhoodComponentsAnalysis, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -VotingRegressor, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -HuberRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -SGDClassifier, transform, to_xgboost, to_lightgbm -SGDOneClassSVM, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm 
-TheilSenRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -SGDRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -TweedieRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -AdaBoostClassifier, transform, to_xgboost, to_lightgbm -BaggingClassifier, transform, to_xgboost, to_lightgbm -AdaBoostRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -BaggingRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -ExtraTreesClassifier, decision_function, transform, to_xgboost, to_lightgbm -GradientBoostingClassifier, transform, to_xgboost, to_lightgbm -RidgeCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -NuSVC, transform, to_xgboost, to_lightgbm -IsolationForest, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm -LarsCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -RadiusNeighborsRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -VotingClassifier, predict_log_proba, decision_function, to_xgboost, to_lightgbm -SVR, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -StackingRegressor, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -MultiTaskLassoCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -NuSVR, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -ExtraTreeClassifier, decision_function, transform, to_xgboost, to_lightgbm -BernoulliRBM, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -DBSCAN, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm 
-DecisionTreeRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -MLPClassifier, decision_function, transform, to_xgboost, to_lightgbm -MLPRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -DecisionTreeClassifier, decision_function, transform, to_xgboost, to_lightgbm -ExtraTreeRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -CalibratedClassifierCV, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm -AffinityPropagation, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -SpectralBiclustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -AgglomerativeClustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -Birch, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -MeanShift, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -KMeans, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -FeatureAgglomeration, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -BisectingKMeans, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -HistGradientBoostingRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -OPTICS, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -PolynomialFeatures, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -IterativeImputer, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -MultiTaskLasso, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -RandomForestRegressor, predict_log_proba, predict_proba, 
decision_function, transform, to_xgboost, to_lightgbm -GraphicalLasso, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -MinCovDet, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LedoitWolf, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -OAS, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -EmpiricalCovariance, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -ShrunkCovariance, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -FactorAnalysis, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -DictionaryLearning, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -FastICA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -IncrementalPCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -KernelPCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -QuadraticDiscriminantAnalysis, transform, to_xgboost, to_lightgbm -PCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -MiniBatchDictionaryLearning, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -LinearDiscriminantAnalysis, to_xgboost, to_lightgbm -TruncatedSVD, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -SparsePCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -GenericUnivariateSelect, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -SelectFwe, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -SelectFpr, predict, 
predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -SequentialFeatureSelector, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -GaussianProcessClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm -VarianceThreshold, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -GaussianProcessRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -KNNImputer, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -MissingIndicator, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -TSNE, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -Isomap, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -SpectralEmbedding, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LGBMRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_sklearn -BayesianGaussianMixture, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm +AdditiveChi2Sampler, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +Nystroem, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +KernelRidge, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +SkewedChi2Sampler, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +PolynomialCountSketch, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SVC, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RBFSampler, predict, 
predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +ARDRegression, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +HistGradientBoostingClassifier, predict_log_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LGBMClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_sklearn, kneighbors, fit_predict +LinearSVC, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +OrthogonalMatchingPursuit, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +NearestNeighbors, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, fit_predict +PoissonRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +Lars, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +Ridge, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RandomForestClassifier, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ExtraTreesRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ComplementNB, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MultiTaskElasticNetCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +Lasso, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LinearSVR, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +PassiveAggressiveRegressor, predict_log_proba, predict_proba, 
decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +Perceptron, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MultiTaskElasticNet, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RANSACRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ColumnTransformer, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors, score, fit_predict +NeighborhoodComponentsAnalysis, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +VotingRegressor, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors, fit_predict +HuberRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +SGDClassifier, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +SGDOneClassSVM, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, score, kneighbors +TheilSenRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +SGDRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +TweedieRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +AdaBoostClassifier, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +BaggingClassifier, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +AdaBoostRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +BaggingRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict 
+ExtraTreesClassifier, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +GradientBoostingClassifier, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RidgeCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +NuSVC, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +IsolationForest, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, score, kneighbors +LarsCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RadiusNeighborsRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +VotingClassifier, predict_log_proba, decision_function, to_xgboost, to_lightgbm, kneighbors, fit_predict +SVR, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +StackingRegressor, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors, fit_predict +MultiTaskLassoCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +NuSVR, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ExtraTreeClassifier, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +BernoulliRBM, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +DBSCAN, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors +DecisionTreeRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MLPClassifier, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MLPRegressor, predict_log_proba, predict_proba, 
decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +DecisionTreeClassifier, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ExtraTreeRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +CalibratedClassifierCV, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +AffinityPropagation, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors +SpectralBiclustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +AgglomerativeClustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors +Birch, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors +MeanShift, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors +KMeans, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors +FeatureAgglomeration, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +BisectingKMeans, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors +HistGradientBoostingRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +OPTICS, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors +PolynomialFeatures, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +IterativeImputer, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +MultiTaskLasso, 
predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RandomForestRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +GraphicalLasso, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MinCovDet, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LedoitWolf, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +OAS, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +EmpiricalCovariance, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ShrunkCovariance, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +FactorAnalysis, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors, fit_predict +DictionaryLearning, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +FastICA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +IncrementalPCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +KernelPCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +QuadraticDiscriminantAnalysis, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +PCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors, fit_predict +MiniBatchDictionaryLearning, predict, predict_log_proba, 
predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +LinearDiscriminantAnalysis, to_xgboost, to_lightgbm, kneighbors, fit_predict +TruncatedSVD, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SparsePCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +GenericUnivariateSelect, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SelectFwe, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SelectFpr, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SequentialFeatureSelector, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +GaussianProcessClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +VarianceThreshold, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +GaussianProcessRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +KNNImputer, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +MissingIndicator, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +TSNE, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +Isomap, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SpectralEmbedding, predict, predict_log_proba, predict_proba, decision_function, 
transform, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +LGBMRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_sklearn, kneighbors, fit_predict +BayesianGaussianMixture, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors GridSearchCV, to_xgboost, to_lightgbm -GaussianMixture, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm -XGBRFClassifier, predict_log_proba, decision_function, transform, to_sklearn, to_lightgbm -OneVsOneClassifier, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm +GaussianMixture, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors +XGBRFClassifier, predict_log_proba, decision_function, transform, to_sklearn, to_lightgbm, kneighbors, fit_predict +OneVsOneClassifier, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict RandomizedSearchCV, to_xgboost, to_lightgbm -OneVsRestClassifier, predict_log_proba, transform, to_xgboost, to_lightgbm -GaussianNB, decision_function, transform, to_xgboost, to_lightgbm -BernoulliNB, decision_function, transform, to_xgboost, to_lightgbm -OutputCodeClassifier, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -CategoricalNB, decision_function, transform, to_xgboost, to_lightgbm -MiniBatchKMeans, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -LabelSpreading, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm -KNeighborsClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm -KNeighborsRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LocalOutlierFactor, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm -NearestCentroid, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -KernelDensity, predict, 
predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -RidgeClassifier, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm -SpectralClustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -RidgeClassifierCV, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm -RadiusNeighborsClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm -LabelPropagation, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm -TransformedTargetRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -PassiveAggressiveClassifier, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm -XGBRFRegressor, predict_log_proba, predict_proba, decision_function, transform, to_sklearn, to_lightgbm -XGBClassifier, predict_log_proba, decision_function, transform, to_sklearn, to_lightgbm -GradientBoostingRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -MultinomialNB, decision_function, transform, to_xgboost, to_lightgbm -LogisticRegression, transform, to_xgboost, to_lightgbm -MDS, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -SelectPercentile, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -SelectFdr, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -XGBRegressor, predict_log_proba, predict_proba, decision_function, transform, to_sklearn, to_lightgbm -GraphicalLassoCV, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -SelectKBest, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -MiniBatchSparsePCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm -EllipticEnvelope, predict_log_proba, 
predict_proba, transform, to_xgboost, to_lightgbm -LassoLars, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -SpectralCoclustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LogisticRegressionCV, transform, to_xgboost, to_lightgbm -LassoLarsCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LassoCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -ElasticNet, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LassoLarsIC, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -LinearRegression, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -GammaRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -ElasticNetCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm -BayesianRidge, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm +OneVsRestClassifier, predict_log_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +GaussianNB, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +BernoulliNB, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +OutputCodeClassifier, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +CategoricalNB, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MiniBatchKMeans, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, kneighbors +LabelSpreading, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +KNeighborsClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, fit_predict
+KNeighborsRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, fit_predict +LocalOutlierFactor, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, score +NearestCentroid, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +KernelDensity, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RidgeClassifier, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +SpectralClustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors +RidgeClassifierCV, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +RadiusNeighborsClassifier, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LabelPropagation, predict_log_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +TransformedTargetRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +PassiveAggressiveClassifier, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +XGBRFRegressor, predict_log_proba, predict_proba, decision_function, transform, to_sklearn, to_lightgbm, kneighbors, fit_predict +XGBClassifier, predict_log_proba, decision_function, transform, to_sklearn, to_lightgbm, kneighbors, fit_predict +GradientBoostingRegressor, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MultinomialNB, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LogisticRegression, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +MDS, predict, predict_log_proba, predict_proba, decision_function, 
transform, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SelectPercentile, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +SelectFdr, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +XGBRegressor, predict_log_proba, predict_proba, decision_function, transform, to_sklearn, to_lightgbm, kneighbors, fit_predict +GraphicalLassoCV, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +SelectKBest, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +MiniBatchSparsePCA, predict, predict_log_proba, predict_proba, decision_function, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +EllipticEnvelope, predict_log_proba, predict_proba, transform, to_xgboost, to_lightgbm, kneighbors +LassoLars, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +SpectralCoclustering, predict, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, score, kneighbors, fit_predict +LogisticRegressionCV, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LassoLarsCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LassoCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ElasticNet, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LassoLarsIC, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +LinearRegression, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +GammaRegressor, 
predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +ElasticNetCV, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict +BayesianRidge, predict_log_proba, predict_proba, decision_function, transform, to_xgboost, to_lightgbm, kneighbors, fit_predict diff --git a/requirements.txt b/requirements.txt index 05a963a1..6a700073 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,9 +28,11 @@ packaging==23.0 pandas==1.5.3 peft==0.5.0 protobuf==3.20.3 +pyarrow==10.0.1 pytest==7.4.0 pytimeparse==1.1.8 pyyaml==6.0 +retrying==1.3.3 ruamel.yaml==0.17.21 s3fs==2023.3.0 scikit-learn==1.3.0 diff --git a/requirements.yml b/requirements.yml index 48a658b7..3b8e37e4 100644 --- a/requirements.yml +++ b/requirements.yml @@ -181,6 +181,9 @@ - spcs_inference - name: protobuf dev_version: 3.20.3 +- name: pyarrow + dev_version: 10.0.1 + version_requirements: '' - name: pytest dev_version: 7.4.0 - name_pypi: torch @@ -193,6 +196,9 @@ - deployment_core - udf_inference - spcs_inference +- name: retrying + dev_version: 1.3.3 + version_requirements: '>=1.3.3,<2' # For fsspec[http] in conda - name_conda: requests dev_version_conda: 2.29.0 diff --git a/snowflake/cortex/_complete.py b/snowflake/cortex/_complete.py index 4dd01e92..9cec5d36 100644 --- a/snowflake/cortex/_complete.py +++ b/snowflake/cortex/_complete.py @@ -23,7 +23,7 @@ def Complete( A column of string responses. """ - return _complete_impl("snowflake.ml.complete", model, prompt, session=session) + return _complete_impl("snowflake.cortex.complete", model, prompt, session=session) def _complete_impl( diff --git a/snowflake/cortex/_extract_answer.py b/snowflake/cortex/_extract_answer.py index 80e6e8ea..0b44bc8a 100644 --- a/snowflake/cortex/_extract_answer.py +++ b/snowflake/cortex/_extract_answer.py @@ -25,7 +25,7 @@ def ExtractAnswer( A column of strings containing answers. 
""" - return _extract_answer_impl("snowflake.ml.extract_answer", from_text, question, session=session) + return _extract_answer_impl("snowflake.cortex.extract_answer", from_text, question, session=session) def _extract_answer_impl( diff --git a/snowflake/cortex/_sentiment.py b/snowflake/cortex/_sentiment.py index 8843f6fc..e75e6609 100644 --- a/snowflake/cortex/_sentiment.py +++ b/snowflake/cortex/_sentiment.py @@ -22,7 +22,7 @@ def Sentiment( A column of floats. 1 represents positive sentiment, -1 represents negative sentiment. """ - return _sentiment_impl("snowflake.ml.sentiment", text, session=session) + return _sentiment_impl("snowflake.cortex.sentiment", text, session=session) def _sentiment_impl( diff --git a/snowflake/cortex/_summarize.py b/snowflake/cortex/_summarize.py index 54c47050..fd0e11f8 100644 --- a/snowflake/cortex/_summarize.py +++ b/snowflake/cortex/_summarize.py @@ -23,7 +23,7 @@ def Summarize( A column of string summaries. """ - return _summarize_impl("snowflake.ml.summarize", text, session=session) + return _summarize_impl("snowflake.cortex.summarize", text, session=session) def _summarize_impl( diff --git a/snowflake/cortex/_translate.py b/snowflake/cortex/_translate.py index 2c18671c..27f71249 100644 --- a/snowflake/cortex/_translate.py +++ b/snowflake/cortex/_translate.py @@ -27,7 +27,7 @@ def Translate( A column of string translations. 
""" - return _translate_impl("snowflake.ml.translate", text, from_language, to_language, session=session) + return _translate_impl("snowflake.cortex.translate", text, from_language, to_language, session=session) def _translate_impl( diff --git a/snowflake/ml/_internal/BUILD.bazel b/snowflake/ml/_internal/BUILD.bazel index 4c012bb3..89d559ed 100644 --- a/snowflake/ml/_internal/BUILD.bazel +++ b/snowflake/ml/_internal/BUILD.bazel @@ -43,6 +43,7 @@ py_library( ":env", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:retryable_http", ], ) diff --git a/snowflake/ml/_internal/env_utils.py b/snowflake/ml/_internal/env_utils.py index d78806b5..2ba8a16d 100644 --- a/snowflake/ml/_internal/env_utils.py +++ b/snowflake/ml/_internal/env_utils.py @@ -4,6 +4,7 @@ import re import textwrap import warnings +from enum import Enum from importlib import metadata as importlib_metadata from typing import Any, DefaultDict, Dict, List, Optional, Tuple @@ -18,10 +19,22 @@ ) from snowflake.ml._internal.utils import query_result_checker from snowflake.snowpark import session +from snowflake.snowpark._internal import utils as snowpark_utils + + +class CONDA_OS(Enum): + LINUX_64 = "linux-64" + LINUX_AARCH64 = "linux-aarch64" + OSX_64 = "osx-64" + OSX_ARM64 = "osx-arm64" + WIN_64 = "win-64" + NO_ARCH = "noarch" + _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake" _NODEFAULTS = "nodefaults" _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: Optional[bool] = None +_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {} _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {} DEFAULT_CHANNEL_NAME = "" @@ -217,6 +230,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement warnings.warn( f"Package requirement {str(pip_req)} specified, while version {local_dist_version} is installed. 
" "Local version will be ignored to conform to package requirement.", + stacklevel=2, category=UserWarning, ) return pip_req @@ -265,10 +279,58 @@ def _check_runtime_version_column_existence(session: session.Session) -> bool: return result == 1 -def validate_requirements_in_snowflake_conda_channel( +def get_matched_package_versions_in_snowflake_conda_channel( + req: requirements.Requirement, + python_version: str = snowml_env.PYTHON_VERSION, + conda_os: CONDA_OS = CONDA_OS.LINUX_64, +) -> List[version.Version]: + """Search the snowflake anaconda channel for packages that matches the specifier. Note that this will be the + source of truth for checking whether a package indeed exists in Snowflake conda channel. + + Given that a package comes in different architectures, we only check for the Linux x86_64 architecture and assume + the package exists in other architectures. If such an assumption does not hold true for a certain package, the + caller should specify the architecture to search for. + + Args: + req: Requirement specifier. + python_version: A string of python version where model is run. + conda_os: Specified platform to search availability of the package. + + Returns: + List of package versions that meet the requirement specifier. + """ + # Move the retryable_http import here as when UDF import this file, it won't have the "requests" dependency. 
+ from snowflake.ml._internal.utils import retryable_http + + assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call] + + url = f"{_SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json" + + if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE: + http_client = retryable_http.get_http_client() + parsed_python_version = version.Version(python_version) + python_version_build_str = f"py{parsed_python_version.major}{parsed_python_version.minor}" + repodata = http_client.get(url).json() + assert isinstance(repodata, dict) + packages_info = repodata["packages"] + assert isinstance(packages_info, dict) + version_list = [ + version.parse(package_info["version"]) + for package_info in packages_info.values() + if package_info["name"] == req.name and python_version_build_str in package_info["build"] + ] + _SNOWFLAKE_CONDA_PACKAGE_CACHE[req.name] = version_list + + matched_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, [])))) + return matched_versions + + +def validate_requirements_in_information_schema( session: session.Session, reqs: List[requirements.Requirement], python_version: str ) -> Optional[List[str]]: - """Search the snowflake anaconda channel for packages with version meet the specifier. + """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake + Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might + exist in the information_schema table but has not yet become available in the Snowflake Conda channel. Args: session: Snowflake connection session. 
@@ -285,7 +347,7 @@ def validate_requirements_in_snowflake_conda_channel( ret_list = [] reqs_to_request = [] for req in reqs: - if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE: + if req.name not in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: reqs_to_request.append(req) if reqs_to_request: pkg_names_str = " OR ".join( @@ -326,13 +388,13 @@ def validate_requirements_in_snowflake_conda_channel( for row in result: req_name = row["PACKAGE_NAME"] req_ver = version.parse(row["VERSION"]) - cached_req_ver_list = _SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req_name, []) + cached_req_ver_list = _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req_name, []) cached_req_ver_list.append(req_ver) - _SNOWFLAKE_CONDA_PACKAGE_CACHE[req_name] = cached_req_ver_list + _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE[req_name] = cached_req_ver_list except snowflake.connector.DataError: return None for req in reqs: - available_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, [])))) + available_versions = list(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, [])))) if not available_versions: return None else: diff --git a/snowflake/ml/_internal/env_utils_test.py b/snowflake/ml/_internal/env_utils_test.py index 44249633..176ddb0d 100644 --- a/snowflake/ml/_internal/env_utils_test.py +++ b/snowflake/ml/_internal/env_utils_test.py @@ -294,7 +294,7 @@ def test_relax_requirement_version(self) -> None: self.assertEqual(env_utils.relax_requirement_version(r), requirements.Requirement("python-package")) self.assertIsNot(env_utils.relax_requirement_version(r), r) - def test_validate_requirements_in_snowflake_conda_channel(self) -> None: + def test_validate_requirements_in_information_schema(self) -> None: m_session = mock_session.MockSession(conn=None, test_case=self) m_session.add_mock_sql( query=textwrap.dedent( @@ -326,7 +326,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: c_session = cast(session.Session, m_session) 
self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, @@ -336,7 +336,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: # Test cache self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, @@ -345,7 +345,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: ) # clear cache - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} query = textwrap.dedent( """ @@ -366,7 +366,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: c_session = cast(session.Session, m_session) self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost")], python_version=snowml_env.PYTHON_VERSION, @@ -376,7 +376,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: # Test cache self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost")], python_version=snowml_env.PYTHON_VERSION, @@ -401,7 +401,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: c_session = cast(session.Session, m_session) self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), 
requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, @@ -411,7 +411,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: # Test cache self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, @@ -420,7 +420,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: ) # clear cache - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} query = textwrap.dedent( """ @@ -440,7 +440,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: c_session = cast(session.Session, m_session) self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.3")], python_version=snowml_env.PYTHON_VERSION, @@ -450,7 +450,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: # Test cache self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.3")], python_version=snowml_env.PYTHON_VERSION, @@ -459,7 +459,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: ) self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost>=1.7,<1.8")], python_version=snowml_env.PYTHON_VERSION, @@ -468,7 +468,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: ) self.assertIsNone( - env_utils.validate_requirements_in_snowflake_conda_channel( + 
env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.1, ==1.7.3")], python_version=snowml_env.PYTHON_VERSION, @@ -476,13 +476,13 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: ) # clear cache - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.*")], python_version=snowml_env.PYTHON_VERSION, @@ -492,7 +492,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: # Test cache self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.*")], python_version=snowml_env.PYTHON_VERSION, @@ -501,13 +501,13 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: ) # clear cache - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) self.assertIsNone( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.3.*")], python_version=snowml_env.PYTHON_VERSION, @@ -516,7 +516,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: # Test cache self.assertIsNone( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, 
reqs=[requirements.Requirement("xgboost==1.3.*")], python_version=snowml_env.PYTHON_VERSION, @@ -524,7 +524,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: ) # clear cache - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} query = textwrap.dedent( """ @@ -541,7 +541,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: c_session = cast(session.Session, m_session) self.assertIsNone( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("python-package")], python_version=snowml_env.PYTHON_VERSION, @@ -582,7 +582,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: c_session = cast(session.Session, m_session) self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, @@ -592,7 +592,7 @@ def test_validate_requirements_in_snowflake_conda_channel(self) -> None: # Test cache self.assertListEqual( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, diff --git a/snowflake/ml/_internal/file_utils.py b/snowflake/ml/_internal/file_utils.py index 37fc23c8..fb7387e9 100644 --- a/snowflake/ml/_internal/file_utils.py +++ b/snowflake/ml/_internal/file_utils.py @@ -28,6 +28,7 @@ from snowflake import snowpark from snowflake.ml._internal.exceptions import exceptions +from snowflake.snowpark import exceptions as snowpark_exceptions GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi") @@ -286,8 +287,16 @@ def 
stage_file_exists( return False +def _retry_on_sql_error(exception: Exception) -> bool: + return isinstance(exception, snowpark_exceptions.SnowparkSQLException) + + def upload_directory_to_stage( - session: snowpark.Session, local_path: pathlib.Path, stage_path: pathlib.PurePosixPath + session: snowpark.Session, + local_path: pathlib.Path, + stage_path: pathlib.PurePosixPath, + *, + statement_params: Optional[Dict[str, Any]] = None, ) -> None: """Upload a local folder recursively to a stage and keep the structure. @@ -295,7 +304,10 @@ def upload_directory_to_stage( session: Snowpark Session. local_path: Local path to upload. stage_path: Base path in the stage. + statement_params: Statement Params. """ + import retrying + file_operation = snowpark.FileOperation(session=session) for root, _, filenames in os.walk(local_path): @@ -305,16 +317,26 @@ def upload_directory_to_stage( stage_dir_path = ( stage_path / pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix()).parent ) - file_operation.put( + retrying.retry( + retry_on_exception=_retry_on_sql_error, + stop_max_attempt_number=5, + wait_exponential_multiplier=100, + wait_exponential_max=10000, + )(file_operation.put)( str(local_file_path), str(stage_dir_path), auto_compress=False, overwrite=False, + statement_params=statement_params, ) def download_directory_from_stage( - session: snowpark.Session, stage_path: pathlib.PurePosixPath, local_path: pathlib.Path + session: snowpark.Session, + stage_path: pathlib.PurePosixPath, + local_path: pathlib.Path, + *, + statement_params: Optional[Dict[str, Any]] = None, ) -> None: """Upload a folder in stage recursively to a folder in local and keep the structure. @@ -322,7 +344,10 @@ def download_directory_from_stage( session: Snowpark Session. stage_path: Stage path to download from. local_path: Local path as the base of destination. + statement_params: Statement Params. 
""" + import retrying + file_operation = file_operation = snowpark.FileOperation(session=session) file_list = [ pathlib.PurePosixPath(stage_path.parts[0], *pathlib.PurePosixPath(row.name).parts[1:]) @@ -331,4 +356,9 @@ def download_directory_from_stage( for stage_file_path in file_list: local_file_dir = local_path / stage_file_path.relative_to(stage_path).parent local_file_dir.mkdir(parents=True, exist_ok=True) - file_operation.get(str(stage_file_path), str(local_file_dir)) + retrying.retry( + retry_on_exception=_retry_on_sql_error, + stop_max_attempt_number=5, + wait_exponential_multiplier=100, + wait_exponential_max=10000, + )(file_operation.get)(str(stage_file_path), str(local_file_dir), statement_params=statement_params) diff --git a/snowflake/ml/_internal/telemetry.py b/snowflake/ml/_internal/telemetry.py index d797931c..30ceeb2b 100644 --- a/snowflake/ml/_internal/telemetry.py +++ b/snowflake/ml/_internal/telemetry.py @@ -42,6 +42,7 @@ class TelemetryField(enum.Enum): NAME = "name" # types of telemetry TYPE_FUNCTION_USAGE = "function_usage" + TYPE_SNOWML_SPCS_USAGE = "snowml_spcs_usage" # message keys for telemetry KEY_PROJECT = "project" KEY_SUBPROJECT = "subproject" @@ -207,6 +208,23 @@ def wrapper(*args: Any, **kwargs: Any) -> None: return wrapper +def send_custom_usage( + project: str, + *, + telemetry_type: str, + subproject: Optional[str] = None, + data: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> None: + active_session = next(iter(session._get_active_sessions())) + assert active_session, "Missing active session object" + + client = _SourceTelemetryClient(conn=active_session._conn._conn, project=project, subproject=subproject) + common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type) + data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs} + client._send(msg=data) + + def send_api_usage_telemetry( project: str, subproject: Optional[str] = None, @@ -228,7 +246,8 @@ def send_api_usage_telemetry( 
custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None, ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]: """ - Decorator that sends API usage telemetry. + Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by + the function. Args: project: Project. @@ -253,6 +272,51 @@ def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnVal def wrap(*args: Any, **kwargs: Any) -> _ReturnValue: params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None + api_calls: List[Union[Dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = [] + if api_calls_extractor: + extracted_api_calls = api_calls_extractor(args[0]) + for api_call in extracted_api_calls: + if isinstance(api_call, str): + api_calls.append({TelemetryField.NAME.value: api_call}) + elif callable(api_call): + api_calls.append({TelemetryField.NAME.value: _get_full_func_name(api_call)}) + else: + api_calls.append(api_call) + api_calls.append({TelemetryField.NAME.value: _get_full_func_name(func)}) + + sfqids = None + if sfqids_extractor: + sfqids = sfqids_extractor(args[0]) + + statement_params = get_function_usage_statement_params( + project=project, + subproject=subproject, + function_category=TelemetryField.FUNC_CAT_USAGE.value, + function_name=_get_full_func_name(func), + function_parameters=params, + api_calls=api_calls, + custom_tags=custom_tags, + ) + + def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[str, Any]) -> _ReturnValue: + """ + Update SnowML function usage statement parameters to the object if it is a Snowpark DataFrame. + Used to track APIs returning a Snowpark DataFrame. + + Args: + obj: Object to check and update. + statement_params: Statement parameters. + + Returns: + Updated object. 
+ """ + if isinstance(obj, dataframe.DataFrame): + if hasattr(obj, "_statement_params") and obj._statement_params: + obj._statement_params.update(statement_params) + else: + obj._statement_params = statement_params # type: ignore[assignment] + return obj + # prioritize `conn_attr_name` over the active session if conn_attr_name: # raise AttributeError if conn attribute does not exist in `self` @@ -266,36 +330,20 @@ def wrap(*args: Any, **kwargs: Any) -> _ReturnValue: # server no default session except snowpark_exceptions.SnowparkSessionException: try: - return func(*args, **kwargs) + return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params) except Exception as e: if isinstance(e, snowml_exceptions.SnowflakeMLException): - e = e.original_exception + raise e.original_exception.with_traceback(e.__traceback__) from None # suppress SnowparkSessionException from telemetry in the stack trace raise e from None conn = active_session._conn._conn if (not active_session.telemetry_enabled) or (conn is None): try: - return func(*args, **kwargs) + return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params) except snowml_exceptions.SnowflakeMLException as e: raise e.original_exception from e - api_calls: List[Dict[str, Any]] = [] - if api_calls_extractor: - extracted_api_calls = api_calls_extractor(args[0]) - for api_call in extracted_api_calls: - if isinstance(api_call, str): - api_calls.append({TelemetryField.NAME.value: api_call}) - elif callable(api_call): - api_calls.append({TelemetryField.NAME.value: _get_full_func_name(api_call)}) - else: - api_calls.append(api_call) - api_calls.append({TelemetryField.NAME.value: _get_full_func_name(func)}) - - sfqids = None - if sfqids_extractor: - sfqids = sfqids_extractor(args[0]) - # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton. 
telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject) telemetry_args = dict( @@ -314,22 +362,24 @@ def wrap(*args: Any, **kwargs: Any) -> _ReturnValue: if hasattr(e, "_snowflake_ml_handled") and e._snowflake_ml_handled: raise e if isinstance(e, snowpark_exceptions.SnowparkClientException): - e = snowml_exceptions.SnowflakeMLException( + me = snowml_exceptions.SnowflakeMLException( error_code=error_codes.INTERNAL_SNOWPARK_ERROR, original_exception=e ) else: - e = snowml_exceptions.SnowflakeMLException( + me = snowml_exceptions.SnowflakeMLException( error_code=error_codes.UNDEFINED, original_exception=e ) - telemetry_args["error"] = repr(e) - telemetry_args["error_code"] = e.error_code - e.original_exception._snowflake_ml_handled = True # type: ignore[attr-defined] - if e.suppress_source_trace: - raise e.original_exception from None else: - raise e.original_exception from e + me = e + telemetry_args["error"] = repr(me) + telemetry_args["error_code"] = me.error_code + me.original_exception._snowflake_ml_handled = True # type: ignore[attr-defined] + if me.suppress_source_trace: + raise me.original_exception from None + else: + raise me.original_exception from e else: - return res + return update_stmt_params_if_snowpark_df(res, statement_params) finally: telemetry.send_function_usage_telemetry(**telemetry_args) global _log_counter @@ -343,68 +393,6 @@ def wrap(*args: Any, **kwargs: Any) -> _ReturnValue: return decorator -def add_stmt_params_to_df( - project: str, - subproject: Optional[str] = None, - *, - function_category: str = TelemetryField.FUNC_CAT_USAGE.value, - func_params_to_log: Optional[Iterable[str]] = None, - api_calls: Optional[ - List[ - Union[ - Dict[str, Union[Callable[..., Any], str]], - Union[Callable[..., Any], str], - ] - ] - ] = None, - custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None, -) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]: - """ - Decorator that adds 
function usage statement parameters to the dataframe returned by the function. - - Args: - project: Project. - subproject: Subproject. - function_category: Function category. - func_params_to_log: Function parameters to log. - api_calls: API calls in the function. - custom_tags: Custom tags. - - Returns: - Decorator that adds function usage statement parameters to the dataframe returned by the decorated function. - """ - - def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]: - @functools.wraps(func) - def wrap(*args: Any, **kwargs: Any) -> _ReturnValue: - params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None - statement_params = get_function_usage_statement_params( - project=project, - subproject=subproject, - function_category=function_category, - function_name=_get_full_func_name(func), - function_parameters=params, - api_calls=api_calls, - custom_tags=custom_tags, - ) - - try: - res = func(*args, **kwargs) - if isinstance(res, dataframe.DataFrame): - if hasattr(res, "_statement_params") and res._statement_params: - res._statement_params.update(statement_params) - else: - res._statement_params = statement_params # type: ignore[assignment] - except Exception: - raise - else: - return res - - return cast(Callable[_Args, _ReturnValue], wrap) - - return decorator - - def _get_full_func_name(func: Callable[..., Any]) -> str: """ Get the full function name with module and qualname. 
diff --git a/snowflake/ml/_internal/telemetry_test.py b/snowflake/ml/_internal/telemetry_test.py index 0e373809..95c7b403 100644 --- a/snowflake/ml/_internal/telemetry_test.py +++ b/snowflake/ml/_internal/telemetry_test.py @@ -1,5 +1,6 @@ import inspect import time +import traceback from typing import Any, Dict, Optional from unittest import mock @@ -289,13 +290,23 @@ def foo(self, ex: bool = False) -> None: {"params": {"default_stmt_params": {}}}, {"params": {"default_stmt_params": {"default": 0}}}, ) - def test_add_stmt_params_to_df(self, params: Dict[str, Any]) -> None: + @mock.patch("snowflake.snowpark.session._get_active_sessions") + def test_add_stmt_params_to_df(self, mock_get_active_sessions: mock.MagicMock, params: Dict[str, Any]) -> None: + mock_get_active_sessions.return_value = {self.mock_session} + + def extract_api_calls(captured: Any) -> Any: + assert isinstance(captured, DummyObject) + return captured.api_calls + class DummyObject: - @utils_telemetry.add_stmt_params_to_df( + def __init__(self) -> None: + self.api_calls = [time.time, time.sleep] + + @utils_telemetry.send_api_usage_telemetry( project=_PROJECT, subproject=_SUBPROJECT, func_params_to_log=["default_stmt_params"], - api_calls=[time.time, time.sleep], + api_calls_extractor=extract_api_calls, custom_tags={"custom_tag": "tag"}, ) def foo(self, default_stmt_params: Optional[Dict[str, Any]] = None) -> dataframe.DataFrame: @@ -304,7 +315,7 @@ def foo(self, default_stmt_params: Optional[Dict[str, Any]] = None) -> dataframe mock_df._statement_params = default_stmt_params.copy() # type: ignore[assignment] return mock_df - @utils_telemetry.add_stmt_params_to_df( + @utils_telemetry.send_api_usage_telemetry( project=_PROJECT, ) def foo2(self) -> "DummyObject": @@ -315,8 +326,10 @@ def foo2(self) -> "DummyObject": actual_statement_params = returned_df._statement_params full_func_name_time = utils_telemetry._get_full_func_name(time.time) full_func_name_sleep = 
utils_telemetry._get_full_func_name(time.sleep) + full_func_name_foo = utils_telemetry._get_full_func_name(DummyObject.foo) api_call_time = {utils_telemetry.TelemetryField.NAME.value: full_func_name_time} api_call_sleep = {utils_telemetry.TelemetryField.NAME.value: full_func_name_sleep} + api_call_foo = {utils_telemetry.TelemetryField.NAME.value: full_func_name_foo} expected_statement_params = { connector_telemetry.TelemetryField.KEY_SOURCE.value: _SOURCE, utils_telemetry.TelemetryField.KEY_PROJECT.value: _PROJECT, @@ -329,7 +342,7 @@ def foo2(self) -> "DummyObject": utils_telemetry.TelemetryField.KEY_FUNC_PARAMS.value: { "default_stmt_params": repr(params["default_stmt_params"]) }, - utils_telemetry.TelemetryField.KEY_API_CALLS.value: [api_call_time, api_call_sleep], + utils_telemetry.TelemetryField.KEY_API_CALLS.value: [api_call_time, api_call_sleep, api_call_foo], utils_telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: {"custom_tag": "tag"}, } self.assertIsNotNone(actual_statement_params) @@ -401,27 +414,102 @@ def nested_foo(self) -> None: self.assertIn(error_codes.INTERNAL_TEST, str(ex.exception)) self.assertNotIn(error_codes.UNDEFINED, str(ex.exception)) + @mock.patch("snowflake.snowpark.session._get_active_sessions") + def test_snowml_nested_error_tb_1(self, mock_get_active_sessions: mock.MagicMock) -> None: + mock_get_active_sessions.return_value = {self.mock_session} + + class DummyObject: + @utils_telemetry.send_api_usage_telemetry( + project=_PROJECT, + ) + def foo(self) -> None: + self.nested_foo() + + def nested_foo(self) -> None: + raise RuntimeError("foo error") + + test_obj = DummyObject() + try: + test_obj.foo() + except RuntimeError: + self.assertIn("nested_foo", traceback.format_exc()) + + @mock.patch("snowflake.snowpark.session._get_active_sessions") + def test_snowml_nested_error_tb_2(self, mock_get_active_sessions: mock.MagicMock) -> None: + mock_get_active_sessions.return_value = {self.mock_session} + + class DummyObject: + 
@utils_telemetry.send_api_usage_telemetry( + project=_PROJECT, + ) + def foo(self) -> None: + self.nested_foo() + + def nested_foo(self) -> None: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INTERNAL_TEST, + original_exception=RuntimeError("foo error"), + ) + + test_obj = DummyObject() + try: + test_obj.foo() + except RuntimeError: + self.assertIn("nested_foo", traceback.format_exc()) + @mock.patch("snowflake.snowpark.session._get_active_sessions") def test_disable_telemetry(self, mock_get_active_sessions: mock.MagicMock) -> None: mock_session = absltest.mock.MagicMock(spec=session.Session) - mock_server_conn = absltest.mock.MagicMock(spec=server_connection.ServerConnection) - mock_session._conn = mock_server_conn - mock_server_conn._conn = None - mock_get_active_sessions.return_value = {mock_session} + mock_session._conn = self.mock_server_conn mock_session.telemetry_enabled = False + mock_get_active_sessions.return_value = {mock_session} class DummyObject: @utils_telemetry.send_api_usage_telemetry( project=_PROJECT, ) - def foo(self) -> None: - raise exceptions.SnowflakeMLException(error_codes.INTERNAL_TEST, RuntimeError("Message")) + def foo(self) -> dataframe.DataFrame: + return absltest.mock.MagicMock(spec=dataframe.DataFrame) # type: ignore[no-any-return] test_obj = DummyObject() - with self.assertRaises(RuntimeError): - test_obj.foo() + returned_df = test_obj.foo() + actual_statement_params = returned_df._statement_params + # No client telemetry sent. self.mock_telemetry.try_add_log_to_batch.assert_not_called() + assert actual_statement_params is not None # mypy + # Statement parameters updated to the returned dataframe. 
+ self.assertEqual(actual_statement_params[connector_telemetry.TelemetryField.KEY_SOURCE.value], env.SOURCE) + + @mock.patch("snowflake.snowpark.session._get_active_sessions") + def test_send_custom_usage(self, mock_get_active_sessions: mock.MagicMock) -> None: + mock_get_active_sessions.return_value = {self.mock_session} + + project = "m_project" + subproject = "m_subproject" + telemetry_type = "m_telemetry_type" + tag = "m_tag" + data = {"k1": "v1", "k2": {"nested_k2": "nested_v2"}} + kwargs = {utils_telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: tag} + + with mock.patch.object(utils_telemetry._SourceTelemetryClient, "_send", return_value=None) as m_send: + utils_telemetry.send_custom_usage( + project=project, telemetry_type=telemetry_type, subproject=subproject, data=data, **kwargs + ) + + m_send.assert_called_once_with( + msg={ + "source": "SnowML", + "project": project, + "subproject": subproject, + "version": _VERSION, + "python_version": _PYTHON_VERSION, + "operating_system": _OS, + "type": telemetry_type, + "data": data, + "custom_tags": tag, + } + ) if __name__ == "__main__": diff --git a/snowflake/ml/_internal/utils/BUILD.bazel b/snowflake/ml/_internal/utils/BUILD.bazel index cf981880..6aca28de 100644 --- a/snowflake/ml/_internal/utils/BUILD.bazel +++ b/snowflake/ml/_internal/utils/BUILD.bazel @@ -235,3 +235,22 @@ py_test( "//snowflake/ml/test_utils:mock_session", ], ) + +py_library( + name = "spcs_attribution_utils", + srcs = ["spcs_attribution_utils.py"], + deps = [ + ":query_result_checker", + "//snowflake/ml/_internal:telemetry", + ], +) + +py_test( + name = "spcs_attribution_utils_test", + srcs = ["spcs_attribution_utils_test.py"], + deps = [ + ":spcs_attribution_utils", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/_internal/utils/retryable_http.py b/snowflake/ml/_internal/utils/retryable_http.py index 433e7ff6..9c37f104 100644 --- 
a/snowflake/ml/_internal/utils/retryable_http.py +++ b/snowflake/ml/_internal/utils/retryable_http.py @@ -5,11 +5,23 @@ from urllib3.util import retry -def get_http_client() -> requests.Session: - # Set up a retry policy for requests +def get_http_client(total_retries: int = 5, backoff_factor: float = 0.1) -> requests.Session: + """Construct retryable http client. + + Args: + total_retries: Total number of retries to allow. + backoff_factor: A backoff factor to apply between attempts after the second try. Time to sleep is calculated by + {backoff factor} * (2 ** ({number of previous retries})). For example, with default retries of 5 and backoff + factor set to 0.1, each subsequent retry will sleep [0.2s, 0.4s, 0.8s, 1.6s, 3.2s] respectively. + + Returns: + requests.Session object. + + """ + retry_strategy = retry.Retry( - total=3, # total number of retries - backoff_factor=0.1, # 100ms initial delay + total=total_retries, + backoff_factor=backoff_factor, status_forcelist=[ http.HTTPStatus.TOO_MANY_REQUESTS, http.HTTPStatus.INTERNAL_SERVER_ERROR, diff --git a/snowflake/ml/_internal/utils/spcs_attribution_utils.py b/snowflake/ml/_internal/utils/spcs_attribution_utils.py new file mode 100644 index 00000000..7e7018c4 --- /dev/null +++ b/snowflake/ml/_internal/utils/spcs_attribution_utils.py @@ -0,0 +1,122 @@ +import logging +from datetime import datetime +from typing import Any, Dict, Optional + +from snowflake import snowpark +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import query_result_checker + +logger = logging.getLogger(__name__) + +_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z" +_COMPUTE_POOL = "compute_pool" +_CREATED_ON = "created_on" +_INSTANCE_FAMILY = "instance_family" +_NAME = "name" +_TELEMETRY_PROJECT = "MLOps" +_TELEMETRY_SUBPROJECT = "SpcsDeployment" +_SERVICE_START = "SPCS_SERVICE_START" +_SERVICE_END = "SPCS_SERVICE_END" + + +def _desc_compute_pool(session: snowpark.Session, compute_pool_name: str) -> Dict[str, 
Any]: + sql = f"DESC COMPUTE POOL {compute_pool_name}" + result = ( + query_result_checker.SqlResultValidator( + session=session, + query=sql, + ) + .has_column(_INSTANCE_FAMILY) + .has_column(_NAME) + .has_dimensions(expected_rows=1) + .validate() + ) + return result[0].as_dict() + + +def _desc_service(session: snowpark.Session, fully_qualified_name: str) -> Dict[str, Any]: + sql = f"DESC SERVICE {fully_qualified_name}" + result = ( + query_result_checker.SqlResultValidator( + session=session, + query=sql, + ) + .has_column(_COMPUTE_POOL) + .has_dimensions(expected_rows=1) + .validate() + ) + return result[0].as_dict() + + +def _get_current_time() -> datetime: + """ + This method exists to make it easier to mock datetime in test. + + Returns: + current datetime + """ + return datetime.now() + + +def _send_service_telemetry( + fully_qualified_name: Optional[str] = None, + compute_pool_name: Optional[str] = None, + service_details: Optional[Dict[str, Any]] = None, + compute_pool_details: Optional[Dict[str, Any]] = None, + duration_in_seconds: Optional[int] = None, + kwargs: Optional[Dict[str, Any]] = None, +) -> None: + try: + telemetry.send_custom_usage( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + telemetry_type=telemetry.TelemetryField.TYPE_SNOWML_SPCS_USAGE.value, + data={ + "service_name": fully_qualified_name, + "compute_pool_name": compute_pool_name, + "service_details": service_details, + "compute_pool_details": compute_pool_details, + "duration_in_seconds": duration_in_seconds, + }, + kwargs=kwargs, + ) + except Exception as e: + logger.error(f"Failed to send service telemetry: {e}") + + +def record_service_start(session: snowpark.Session, fully_qualified_name: str) -> None: + service_details = _desc_service(session, fully_qualified_name) + compute_pool_name = service_details[_COMPUTE_POOL] + compute_pool_details = _desc_compute_pool(session, compute_pool_name) + + _send_service_telemetry( + fully_qualified_name=fully_qualified_name, 
+ compute_pool_name=compute_pool_name, + service_details=service_details, + compute_pool_details=compute_pool_details, + kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_START}, + ) + + logger.info(f"Service {fully_qualified_name} created with compute pool {compute_pool_name}.") + + +def record_service_end(session: snowpark.Session, fully_qualified_name: str) -> None: + service_details = _desc_service(session, fully_qualified_name) + compute_pool_details = _desc_compute_pool(session, service_details[_COMPUTE_POOL]) + compute_pool_name = service_details[_COMPUTE_POOL] + + created_on_datetime: datetime = service_details[_CREATED_ON] + current_time: datetime = _get_current_time() + current_time = current_time.replace(tzinfo=created_on_datetime.tzinfo) + duration_in_seconds = int((current_time - created_on_datetime).total_seconds()) + + _send_service_telemetry( + fully_qualified_name=fully_qualified_name, + compute_pool_name=compute_pool_name, + service_details=service_details, + compute_pool_details=compute_pool_details, + duration_in_seconds=duration_in_seconds, + kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_END}, + ) + + logger.info(f"Service {fully_qualified_name} deleted from compute pool {compute_pool_name}") diff --git a/snowflake/ml/_internal/utils/spcs_attribution_utils_test.py b/snowflake/ml/_internal/utils/spcs_attribution_utils_test.py new file mode 100644 index 00000000..482c5d0b --- /dev/null +++ b/snowflake/ml/_internal/utils/spcs_attribution_utils_test.py @@ -0,0 +1,135 @@ +import datetime +from typing import Any, Dict, cast +from unittest import mock + +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import spcs_attribution_utils +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import session + + +class SpcsAttributionUtilsTest(absltest.TestCase): + def setUp(self) -> None: + 
super().setUp() + self._m_session = mock_session.MockSession(conn=None, test_case=self) + self._fully_qualified_service_name = "db.schema.my_service" + self._m_compute_pool_name = "my_pool" + self._service_created_on = datetime.datetime.strptime( + "2023-11-16 13:01:00.062 -0800", spcs_attribution_utils._DATETIME_FORMAT + ) + + mock_service_detail = self._get_mock_service_details() + self._m_session.add_mock_sql( + query=f"DESC SERVICE {self._fully_qualified_service_name}", + result=mock_data_frame.MockDataFrame(collect_result=[snowpark.Row(**mock_service_detail)]), + ) + + mock_compute_pool_detail = self._get_mock_compute_pool_details() + self._m_session.add_mock_sql( + query=f"DESC COMPUTE POOL {self._m_compute_pool_name}", + result=mock_data_frame.MockDataFrame(collect_result=[snowpark.Row(**mock_compute_pool_detail)]), + ) + + def _get_mock_service_details(self) -> Dict[str, Any]: + return { + "name": "my_service", + "database_name": "my_db", + "schema_name": "my_schema", + "owner": "Engineer", + "compute_pool": self._m_compute_pool_name, + "spec": "--- spec:", + "dns_name": "service-dummy.my-schema.my-db.snowflakecomputing.internal", + "public_endpoints": {"predict": "dummy.snowflakecomputing.app"}, + "min_instances": 1, + "max_instances": 1, + "created_on": self._service_created_on, + "updated_on": "2023-11-16 13:01:00.595 -0800", + "comment": None, + } + + def _get_mock_compute_pool_details(self) -> Dict[str, Any]: + return { + "name": self._m_compute_pool_name, + "state": "Active", + "min_nodes": 1, + "max_nodes": 1, + "instance_family": "STANDARD_2", + "num_services": 1, + "num_jobs": 2, + "active_nodes": 1, + "idle_nodes": 1, + "created_on": "2023-09-21 09:17:39.627 -0700", + "resumed_on": "2023-09-21 09:17:39.628 -0700", + "updated_on": "2023-11-27 15:08:55.725 -0800", + "owner": "ACCOUNTADMIN", + "comment": None, + } + + def test_record_service_start(self) -> None: + with mock.patch.object(spcs_attribution_utils, "_send_service_telemetry", 
return_value=None) as m_telemetry: + with self.assertLogs(level="INFO") as cm: + spcs_attribution_utils.record_service_start( + cast(session.Session, self._m_session), self._fully_qualified_service_name + ) + + assert len(cm.output) == 1, "there should only be 1 log" + log = cm.output[0] + + service_details = self._get_mock_service_details() + compute_pool_details = self._get_mock_compute_pool_details() + + self.assertEqual( + log, + f"INFO:snowflake.ml._internal.utils.spcs_attribution_utils:Service " + f"{self._fully_qualified_service_name} created with compute pool {self._m_compute_pool_name}.", + ) + m_telemetry.assert_called_once_with( + fully_qualified_name=self._fully_qualified_service_name, + compute_pool_name=self._m_compute_pool_name, + service_details=service_details, + compute_pool_details=compute_pool_details, + kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: spcs_attribution_utils._SERVICE_START}, + ) + + def test_record_service_end(self) -> None: + current_datetime = self._service_created_on + datetime.timedelta(days=2, hours=1, minutes=30, seconds=20) + expected_duration = 178220 # 2 days 1 hour 30 minutes and 20 seconds. 
+ + with mock.patch( + "snowflake.ml._internal.utils.spcs_attribution_utils._get_current_time" + ) as mock_datetime_now, mock.patch.object( + spcs_attribution_utils, "_send_service_telemetry", return_value=None + ) as m_telemetry: + with self.assertLogs(level="INFO") as cm: + mock_datetime_now.return_value = current_datetime + + spcs_attribution_utils.record_service_end( + cast(session.Session, self._m_session), self._fully_qualified_service_name + ) + assert len(cm.output) == 1, "there should only be 1 log" + log = cm.output[0] + + service_details = self._get_mock_service_details() + compute_pool_details = self._get_mock_compute_pool_details() + + self.assertEqual( + log, + f"INFO:snowflake.ml._internal.utils.spcs_attribution_utils:Service " + f"{self._fully_qualified_service_name} deleted from compute pool {self._m_compute_pool_name}", + ) + + m_telemetry.assert_called_once_with( + fully_qualified_name=self._fully_qualified_service_name, + compute_pool_name=self._m_compute_pool_name, + service_details=service_details, + compute_pool_details=compute_pool_details, + duration_in_seconds=expected_duration, + kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: spcs_attribution_utils._SERVICE_END}, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/dataset/dataset.py b/snowflake/ml/dataset/dataset.py index c683b523..b51e6e84 100644 --- a/snowflake/ml/dataset/dataset.py +++ b/snowflake/ml/dataset/dataset.py @@ -140,7 +140,7 @@ def to_json(self) -> str: @classmethod def from_json(cls, json_str: str, session: Session) -> "Dataset": - json_dict = json.loads(json_str) + json_dict = json.loads(json_str, strict=False) json_dict["df"] = session.sql(json_dict.pop("df_query")) fs_meta_json = json_dict["feature_store_metadata"] diff --git a/snowflake/ml/feature_store/entity.py b/snowflake/ml/feature_store/entity.py index a1b444a2..44b0ece0 100644 --- a/snowflake/ml/feature_store/entity.py +++ b/snowflake/ml/feature_store/entity.py @@ -5,11 +5,11 
@@ to_sql_identifiers, ) -ENTITY_NAME_LENGTH_LIMIT = 32 -FEATURE_VIEW_ENTITY_TAG_DELIMITER = "," -ENTITY_JOIN_KEY_DELIMITER = "," +_ENTITY_NAME_LENGTH_LIMIT = 32 +_FEATURE_VIEW_ENTITY_TAG_DELIMITER = "," +_ENTITY_JOIN_KEY_DELIMITER = "," # join key length limit is the length limit of TAG value -ENTITY_JOIN_KEY_LENGTH_LIMIT = 256 +_ENTITY_JOIN_KEY_LENGTH_LIMIT = 256 class Entity: @@ -35,18 +35,18 @@ def __init__(self, name: str, join_keys: List[str], desc: str = "") -> None: self.desc = desc def _validate(self, name: str, join_keys: List[str]) -> None: - if len(name) > ENTITY_NAME_LENGTH_LIMIT: - raise ValueError(f"Entity name `{name}` exceeds maximum length: {ENTITY_NAME_LENGTH_LIMIT}") - if FEATURE_VIEW_ENTITY_TAG_DELIMITER in name: - raise ValueError(f"Entity name contains invalid char: `{FEATURE_VIEW_ENTITY_TAG_DELIMITER}`") + if len(name) > _ENTITY_NAME_LENGTH_LIMIT: + raise ValueError(f"Entity name `{name}` exceeds maximum length: {_ENTITY_NAME_LENGTH_LIMIT}") + if _FEATURE_VIEW_ENTITY_TAG_DELIMITER in name: + raise ValueError(f"Entity name contains invalid char: `{_FEATURE_VIEW_ENTITY_TAG_DELIMITER}`") if len(set(join_keys)) != len(join_keys): raise ValueError(f"Duplicate join keys detected in: {join_keys}") - if len(FEATURE_VIEW_ENTITY_TAG_DELIMITER.join(join_keys)) > ENTITY_JOIN_KEY_LENGTH_LIMIT: - raise ValueError(f"Total length of join keys exceeded maximum length: {ENTITY_JOIN_KEY_LENGTH_LIMIT}") + if len(_FEATURE_VIEW_ENTITY_TAG_DELIMITER.join(join_keys)) > _ENTITY_JOIN_KEY_LENGTH_LIMIT: + raise ValueError(f"Total length of join keys exceeded maximum length: {_ENTITY_JOIN_KEY_LENGTH_LIMIT}") for k in join_keys: - if ENTITY_JOIN_KEY_DELIMITER in k: - raise ValueError(f"Invalid char `{ENTITY_JOIN_KEY_DELIMITER}` detected in join key {k}") + if _ENTITY_JOIN_KEY_DELIMITER in k: + raise ValueError(f"Invalid char `{_ENTITY_JOIN_KEY_DELIMITER}` detected in join key {k}") def _to_dict(self) -> Dict[str, str]: entity_dict = self.__dict__.copy() diff --git 
a/snowflake/ml/feature_store/feature_store.py b/snowflake/ml/feature_store/feature_store.py index 7258837c..24cb1b17 100644 --- a/snowflake/ml/feature_store/feature_store.py +++ b/snowflake/ml/feature_store/feature_store.py @@ -25,42 +25,40 @@ ) from snowflake.ml.dataset.dataset import Dataset, FeatureStoreMetadata from snowflake.ml.feature_store.entity import ( - ENTITY_JOIN_KEY_DELIMITER, - ENTITY_NAME_LENGTH_LIMIT, - FEATURE_VIEW_ENTITY_TAG_DELIMITER, + _ENTITY_JOIN_KEY_DELIMITER, + _ENTITY_NAME_LENGTH_LIMIT, + _FEATURE_VIEW_ENTITY_TAG_DELIMITER, Entity, ) from snowflake.ml.feature_store.feature_view import ( - FEATURE_OBJ_TYPE, - FEATURE_VIEW_NAME_DELIMITER, - TIMESTAMP_COL_PLACEHOLDER, + _FEATURE_OBJ_TYPE, + _FEATURE_VIEW_NAME_DELIMITER, + _TIMESTAMP_COL_PLACEHOLDER, FeatureView, FeatureViewSlice, FeatureViewStatus, + FeatureViewVersion, ) from snowflake.snowpark import DataFrame, Row, Session, functions as F from snowflake.snowpark._internal import type_utils, utils as snowpark_utils +from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.types import StructField logger = logging.getLogger(__name__) -ENTITY_TAG_PREFIX = "SNOWML_FEATURE_STORE_ENTITY_" -FEATURE_VIEW_ENTITY_TAG = "SNOWML_FEATURE_STORE_FV_ENTITIES" -FEATURE_VIEW_TS_COL_TAG = "SNOWML_FEATURE_STORE_FV_TS_COL" -FEATURE_STORE_OBJECT_TAG = "SNOWML_FEATURE_STORE_OBJECT" -PROJECT = "FeatureStore" - -# TODO: Enable when ASOF join is released. 
https://snowflakecomputing.atlassian.net/browse/SNOW-780702 -_ENABLE_ASOF_JOIN = False - -DT_OR_VIEW_QUERY_PATTERN = re.compile( +_ENTITY_TAG_PREFIX = "SNOWML_FEATURE_STORE_ENTITY_" +_FEATURE_VIEW_ENTITY_TAG = "SNOWML_FEATURE_STORE_FV_ENTITIES" +_FEATURE_VIEW_TS_COL_TAG = "SNOWML_FEATURE_STORE_FV_TS_COL" +_FEATURE_STORE_OBJECT_TAG = "SNOWML_FEATURE_STORE_OBJECT" +_PROJECT = "FeatureStore" +_DT_OR_VIEW_QUERY_PATTERN = re.compile( r"""CREATE\ (?P(DYNAMIC\ TABLE|VIEW))\ .* COMMENT\ =\ '(?P.*)'\s* TAG.*?{entity_tag}\ =\ '(?P.*?)',\n .*?{ts_col_tag}\ =\ '(?P.*?)',?.*? AS\ (?P.*) """.format( - entity_tag=FEATURE_VIEW_ENTITY_TAG, ts_col_tag=FEATURE_VIEW_TS_COL_TAG + entity_tag=_FEATURE_VIEW_ENTITY_TAG, ts_col_tag=_FEATURE_VIEW_TS_COL_TAG ), flags=re.DOTALL | re.IGNORECASE | re.X, ) @@ -97,7 +95,7 @@ def wrapper(self: FeatureStore, *args: Any, **kargs: Any) -> Any: def dispatch_decorator(prpr_version: str) -> Callable[..., Any]: def decorator(f: Callable[..., Any]) -> Callable[..., Any]: - @telemetry.send_api_usage_telemetry(project=PROJECT) + @telemetry.send_api_usage_telemetry(project=_PROJECT) @snowpark_utils.private_preview(version=prpr_version) @switch_warehouse @functools.wraps(f) @@ -114,7 +112,7 @@ class FeatureStore: FeatureStore provides APIs to create, materialize, retrieve and manage feature pipelines. """ - @telemetry.send_api_usage_telemetry(project=PROJECT) + @telemetry.send_api_usage_telemetry(project=_PROJECT) @snowpark_utils.private_preview(version="1.0.8") def __init__( self, @@ -143,12 +141,13 @@ def __init__( database = SqlIdentifier(database) name = SqlIdentifier(name) - self._telemetry_stmp = telemetry.get_function_usage_statement_params(PROJECT) + self._telemetry_stmp = telemetry.get_function_usage_statement_params(_PROJECT) self._session: Session = session self._config = _FeatureStoreConfig( database=database, schema=name, ) + self._asof_join_enabled = None # A dict from object name to tuple of search space and object domain. 
# search space used in query "SHOW LIKE IN " @@ -182,9 +181,9 @@ def __init__( ) for tag in to_sql_identifiers( [ - FEATURE_VIEW_ENTITY_TAG, - FEATURE_VIEW_TS_COL_TAG, - FEATURE_STORE_OBJECT_TAG, + _FEATURE_VIEW_ENTITY_TAG, + _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_STORE_OBJECT_TAG, ] ): self._session.sql(f"CREATE TAG IF NOT EXISTS {self._get_fully_qualified_name(tag)}").collect( @@ -199,7 +198,7 @@ def __init__( logger.info(f"Successfully connected to feature store: {self._config.full_schema_path}.") - @telemetry.send_api_usage_telemetry(project=PROJECT) + @telemetry.send_api_usage_telemetry(project=_PROJECT) @snowpark_utils.private_preview(version="1.0.12") def update_default_warehouse(self, warehouse_name: str) -> None: """Update default warehouse for feature store. @@ -241,7 +240,7 @@ def register_entity(self, entity: Entity) -> None: suppress_source_trace=True, ) - join_keys_str = ENTITY_JOIN_KEY_DELIMITER.join(entity.join_keys) + join_keys_str = _ENTITY_JOIN_KEY_DELIMITER.join(entity.join_keys) full_tag_name = self._get_fully_qualified_name(tag_name) self._session.sql(f"CREATE TAG IF NOT EXISTS {full_tag_name} COMMENT = '{entity.desc}'").collect( statement_params=self._telemetry_stmp @@ -269,7 +268,7 @@ def register_feature_view( Args: feature_view: FeatureView instance to materialize. version: version of the registered FeatureView. - NOTE: `$` is not a valid char for the version identifier. Also version will be capitalized. + NOTE: Version only accepts letters, numbers and underscore. Also version will be capitalized. block: Specify whether the FeatureView backend materialization should be blocking or not. If blocking then the API will wait until the initial FeatureView data is generated. @@ -284,16 +283,15 @@ def register_feature_view( SnowflakeMLException: [RuntimeError] Failed to create dynamic table, task, or view. SnowflakeMLException: [RuntimeError] Failed to find resources. 
""" - version = SqlIdentifier(version) + version = FeatureViewVersion(version) if feature_view.status != FeatureViewStatus.DRAFT: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.OBJECT_ALREADY_EXISTS, original_exception=ValueError( - f"FeatureView {feature_view.name} with version {feature_view.version} has already been registered." + f"FeatureView {feature_view.name}/{feature_view.version} has already been registered." ), ) - self._validate_version_identifier(version) # TODO: ideally we should move this to FeatureView creation time for e in feature_view.entities: @@ -309,18 +307,16 @@ def register_feature_view( if len(dynamic_table_results) > 0 or len(view_results) > 0: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.OBJECT_ALREADY_EXISTS, - original_exception=ValueError( - f"FeatureView {feature_view.name} with version {version} already exists." - ), + original_exception=ValueError(f"FeatureView {feature_view.name}/{version} already exists."), suppress_source_trace=True, ) fully_qualified_name = self._get_fully_qualified_name(feature_view_name) - entities = FEATURE_VIEW_ENTITY_TAG_DELIMITER.join([e.name for e in feature_view.entities]) + entities = _FEATURE_VIEW_ENTITY_TAG_DELIMITER.join([e.name for e in feature_view.entities]) timestamp_col = ( feature_view.timestamp_col if feature_view.timestamp_col is not None - else SqlIdentifier(TIMESTAMP_COL_PLACEHOLDER) + else SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER) ) def create_col_desc(col: StructField) -> str: @@ -349,9 +345,9 @@ def create_col_desc(col: StructField) -> str: query = f"""CREATE VIEW {fully_qualified_name} ({column_descs}) COMMENT = '{feature_view.desc}' TAG ( - {FEATURE_VIEW_ENTITY_TAG} = '{entities}', - {FEATURE_VIEW_TS_COL_TAG} = '{timestamp_col}', - {FEATURE_STORE_OBJECT_TAG} = '' + {_FEATURE_VIEW_ENTITY_TAG} = '{entities}', + {_FEATURE_VIEW_TS_COL_TAG} = '{timestamp_col}', + {_FEATURE_STORE_OBJECT_TAG} = '' ) AS {feature_view.query} """ @@ -362,8 +358,8 @@ 
def create_col_desc(col: StructField) -> str: original_exception=RuntimeError(f"Create view {fully_qualified_name} [\n{query}\n] failed: {e}"), ) from e - logger.info(f"Registered FeatureView {feature_view.name} with version {version}.") - return self.get_feature_view(feature_view.name, version) # type: ignore[no-any-return] + logger.info(f"Registered FeatureView {feature_view.name}/{version}.") + return self.get_feature_view(feature_view.name, str(version)) # type: ignore[no-any-return] @dispatch_decorator(prpr_version="1.1.0") def update_feature_view(self, feature_view: FeatureView) -> None: @@ -482,14 +478,14 @@ def get_feature_view(self, name: str, version: str) -> FeatureView: or incurred exception when reconstructing the FeatureView object. """ name = SqlIdentifier(name) - version = SqlIdentifier(version) + version = FeatureViewVersion(version) fv_name = FeatureView._get_physical_name(name, version) results = self._get_backend_representations(fv_name) if len(results) != 1: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.NOT_FOUND, - original_exception=ValueError(f"Failed to find FeatureView {name} with version {version}: {results}"), + original_exception=ValueError(f"Failed to find FeatureView {name}/{version}: {results}"), ) return self._compose_feature_view(results[0]) @@ -611,7 +607,8 @@ def resume_feature_view(self, feature_view: FeatureView) -> FeatureView: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.SNOWML_UPDATE_FAILED, original_exception=ValueError( - f"FeatureView {feature_view.name} is not in suspended status. Actual status: {feature_view.status}" + f"FeatureView {feature_view.name}/{feature_view.version} is not in suspended status. 
" + f"Actual status: {feature_view.status}" ), ) @@ -636,7 +633,8 @@ def suspend_feature_view(self, feature_view: FeatureView) -> FeatureView: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.SNOWML_UPDATE_FAILED, original_exception=ValueError( - f"FeatureView {feature_view.name} is not in running status. Actual status: {feature_view.status}" + f"FeatureView {feature_view.name}/{feature_view.version} is not in running status. " + f"Actual status: {feature_view.status}" ), ) return self._update_feature_view_status(feature_view, "SUSPEND") @@ -672,7 +670,7 @@ def delete_feature_view(self, feature_view: FeatureView) -> None: statement_params=self._telemetry_stmp ) - logger.info(f"Deleted FeatureView {feature_view.name} with version {feature_view.version}.") + logger.info(f"Deleted FeatureView {feature_view.name}/{feature_view.version}.") @dispatch_decorator(prpr_version="1.0.8") def list_entities(self) -> DataFrame: @@ -682,10 +680,10 @@ def list_entities(self) -> DataFrame: Returns: Snowpark DataFrame containing the results. 
""" - prefix_len = len(ENTITY_TAG_PREFIX) + 1 + prefix_len = len(_ENTITY_TAG_PREFIX) + 1 tag_values_df = self._session.sql( f""" - SELECT SUBSTR(TAG_NAME,{prefix_len},{ENTITY_NAME_LENGTH_LIMIT}) AS NAME, + SELECT SUBSTR(TAG_NAME,{prefix_len},{_ENTITY_NAME_LENGTH_LIMIT}) AS NAME, TAG_VALUE AS JOIN_KEYS FROM TABLE( {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES( @@ -693,16 +691,16 @@ def list_entities(self) -> DataFrame: 'SCHEMA' ) ) - WHERE TAG_NAME LIKE '{ENTITY_TAG_PREFIX}%' + WHERE TAG_NAME LIKE '{_ENTITY_TAG_PREFIX}%' """ ) tag_metadata_df = self._session.sql( - f"SHOW TAGS LIKE '{ENTITY_TAG_PREFIX}%' IN SCHEMA {self._config.full_schema_path}" + f"SHOW TAGS LIKE '{_ENTITY_TAG_PREFIX}%' IN SCHEMA {self._config.full_schema_path}" ) return cast( DataFrame, tag_values_df.join( - right=tag_metadata_df.with_column("NAME", F.substr('"name"', prefix_len, ENTITY_NAME_LENGTH_LIMIT)) + right=tag_metadata_df.with_column("NAME", F.substr('"name"', prefix_len, _ENTITY_NAME_LENGTH_LIMIT)) .with_column_renamed('"comment"', "DESC") .select("NAME", "DESC"), on=["NAME"], @@ -729,7 +727,7 @@ def get_entity(self, name: str) -> Entity: name = SqlIdentifier(name) full_entity_tag_name = self._get_entity_name(name) - prefix_len = len(ENTITY_TAG_PREFIX) + 1 + prefix_len = len(_ENTITY_TAG_PREFIX) + 1 found_tags = self._find_object("TAGS", full_entity_tag_name) if len(found_tags) == 0: @@ -744,7 +742,7 @@ def get_entity(self, name: str) -> Entity: qrc.SqlResultValidator( self._session, f""" - SELECT SUBSTR(TAG_NAME,{prefix_len},{ENTITY_NAME_LENGTH_LIMIT}) AS NAME, + SELECT SUBSTR(TAG_NAME,{prefix_len},{_ENTITY_NAME_LENGTH_LIMIT}) AS NAME, TAG_VALUE AS JOIN_KEYS FROM TABLE( {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES( @@ -773,7 +771,7 @@ def get_entity(self, name: str) -> Entity: return Entity( name=tag_values[0]["NAME"], - join_keys=tag_values[0]["JOIN_KEYS"].split(ENTITY_JOIN_KEY_DELIMITER), + join_keys=tag_values[0]["JOIN_KEYS"].split(_ENTITY_JOIN_KEY_DELIMITER), 
desc=found_tags[0]["comment"], ) @@ -826,6 +824,7 @@ def retrieve_feature_values( spine_df: DataFrame, features: Union[List[Union[FeatureView, FeatureViewSlice]], List[str]], spine_timestamp_col: Optional[str] = None, + exclude_columns: Optional[List[str]] = None, ) -> DataFrame: """ Enrich spine dataframe with feature values. Mainly used to generate inference data input. @@ -836,6 +835,7 @@ def retrieve_feature_values( features: List of features to join into the spine_df. Can be a list of FeatureView or FeatureViewSlice, or a list of serialized feature objects from Dataset. spine_timestamp_col: Timestamp column in spine_df for point-in-time feature value lookup. + exclude_columns: Column names to exclude from the result dataframe. Returns: Snowpark DataFrame containing the joined results. @@ -856,6 +856,10 @@ def retrieve_feature_values( cast(List[Union[FeatureView, FeatureViewSlice]], features), spine_timestamp_col, ) + + if exclude_columns is not None: + df = self._exclude_columns(df, exclude_columns) + return df @dispatch_decorator(prpr_version="1.0.8") @@ -908,8 +912,6 @@ def generate_dataset( spine_timestamp_col = SqlIdentifier(spine_timestamp_col) if spine_label_cols is not None: spine_label_cols = to_sql_identifiers(spine_label_cols) # type: ignore[assignment] - if exclude_columns is not None: - exclude_columns = to_sql_identifiers(exclude_columns) # type: ignore[assignment] allowed_save_mode = {"errorifexists", "merge"} if save_mode.lower() not in allowed_save_mode: @@ -967,16 +969,7 @@ def generate_dataset( result_df = self._session.sql(f"SELECT * FROM {snapshot_table}") if exclude_columns is not None: - dataset_cols = to_sql_identifiers(result_df.columns) - for col in exclude_columns: - if col not in dataset_cols: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - f"{col} in exclude_columns not exists in generated dataset columns: {dataset_cols}" - ), - ) - result_df = 
result_df.drop(exclude_columns) + result_df = self._exclude_columns(result_df, exclude_columns) fs_meta = FeatureStoreMetadata( spine_query=spine_df.queries["queries"][0], @@ -1044,11 +1037,11 @@ def clear(self) -> None: self._session.sql(f"DROP {obj_type[:-1]} {obj_name}").collect() logger.info(f"Deleted {obj_type[:-1]}: {obj_name}.") - entity_tags = self._find_object("TAGS", SqlIdentifier(ENTITY_TAG_PREFIX), prefix_match=True) + entity_tags = self._find_object("TAGS", SqlIdentifier(_ENTITY_TAG_PREFIX), prefix_match=True) all_tags = [ - FEATURE_VIEW_ENTITY_TAG, - FEATURE_VIEW_TS_COL_TAG, - FEATURE_STORE_OBJECT_TAG, + _FEATURE_VIEW_ENTITY_TAG, + _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_STORE_OBJECT_TAG, ] + [SqlIdentifier(row["name"], case_sensitive=True) for row in entity_tags] for tag_name in all_tags: obj_name = self._get_fully_qualified_name(tag_name) @@ -1080,9 +1073,9 @@ def _create_dynamic_table( TARGET_LAG = '{'DOWNSTREAM' if schedule_task else feature_view.refresh_freq}' COMMENT = '{feature_view.desc}' TAG ( - {self._get_fully_qualified_name(FEATURE_VIEW_ENTITY_TAG)} = '{entities}', - {self._get_fully_qualified_name(FEATURE_VIEW_TS_COL_TAG)} = '{timestamp_col}', - {self._get_fully_qualified_name(FEATURE_STORE_OBJECT_TAG)} = '' + {self._get_fully_qualified_name(_FEATURE_VIEW_ENTITY_TAG)} = '{entities}', + {self._get_fully_qualified_name(_FEATURE_VIEW_TS_COL_TAG)} = '{timestamp_col}', + {self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} = '' ) WAREHOUSE = {warehouse} AS {feature_view.query} @@ -1103,7 +1096,7 @@ def _create_dynamic_table( self._session.sql( f""" ALTER TASK {fully_qualified_name} - SET TAG {self._get_fully_qualified_name(FEATURE_STORE_OBJECT_TAG)} = '' + SET TAG {self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} = '' """ ).collect(statement_params=self._telemetry_stmp) self._session.sql(f"ALTER TASK {fully_qualified_name} RESUME").collect( @@ -1128,9 +1121,9 @@ def _create_dynamic_table( ) if found_dts[0]["refresh_mode"] != 
"INCREMENTAL": warnings.warn( - f"Dynamic table: `{fully_qualified_name}` will not refresh in INCREMENTAL mode. " - + "It will likely incurr bigger computation cost. " - + f"The reason is: {found_dts[0]['refresh_mode_reason']}", + "Your pipeline won't be incrementally refreshed due to: " + + f"\"{found_dts[0]['refresh_mode_reason']}\". " + + "It will likely incurr higher cost.", stacklevel=2, category=UserWarning, ) @@ -1155,7 +1148,7 @@ def _dump_dataset( self._session.sql( f"""CREATE TABLE IF NOT EXISTS {fully_qualified_name} ({schema}) CLUSTER BY ({', '.join(join_keys)}) - TAG ({self._get_fully_qualified_name(FEATURE_STORE_OBJECT_TAG)} = '') + TAG ({self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} = '') """ ).collect(block=True, statement_params=self._telemetry_stmp) except Exception as e: @@ -1187,15 +1180,6 @@ def _dump_dataset( original_exception=RuntimeError(f"Failed to create dataset {fully_qualified_name} with merge: {e}."), ) from e - def _validate_version_identifier(self, version: SqlIdentifier) -> None: - if FEATURE_VIEW_NAME_DELIMITER in version: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - f"Version identifier `{version}` contains invalid character `{FEATURE_VIEW_NAME_DELIMITER}`." 
- ), - ) - def _validate_entity_exists(self, name: SqlIdentifier) -> bool: full_entity_tag_name = self._get_entity_name(name) found_rows = self._find_object("TAGS", full_entity_tag_name) @@ -1232,6 +1216,9 @@ def _join_features( ), ) + if self._asof_join_enabled is None: + self._asof_join_enabled = self._is_asof_join_enabled() + # TODO: leverage Snowpark dataframe for more concise syntax once it supports AsOfJoin query = spine_df.queries["queries"][0] layer = 0 @@ -1248,7 +1235,7 @@ def _join_features( join_table_name = f.fully_qualified_name() if spine_timestamp_col is not None and f.timestamp_col is not None: - if _ENABLE_ASOF_JOIN: + if self._asof_join_enabled: query = f""" SELECT l_{layer}.*, @@ -1268,7 +1255,6 @@ def _join_features( s_ts_col=spine_timestamp_col, f_df=f.feature_df, f_table_name=join_table_name, - f_cols=cols, f_ts_col=f.timestamp_col, join_keys=join_keys, ) @@ -1288,6 +1274,30 @@ def _join_features( return self._session.sql(query), join_keys + def _is_asof_join_enabled(self) -> bool: + result = None + try: + result = self._session.sql( + """ + WITH + spine AS ( + SELECT "ID", "TS" FROM ( SELECT $1 AS "ID", $2 AS "TS" FROM VALUES (1 :: INT, 100 :: INT)) + ), + feature AS ( + SELECT "ID", "TS" FROM ( SELECT $1 AS "ID", $2 AS "TS" FROM VALUES (1 :: INT, 100 :: INT)) + ) + SELECT * FROM spine + ASOF JOIN feature + MATCH_CONDITION ( spine.ts >= feature.ts ) + ON spine.id = feature.id; + """ + ).collect() + except SnowparkSQLException: + return False + return result is not None and len(result) == 1 + + # Visualize how the query works: + # https://docs.google.com/presentation/d/15fT2F34OFp5RPv2-hZirHw6wliPRVRlPHvoCMIB00oY/edit#slide=id.g25ab53e6c8d_0_32 def _composed_union_window_join_query( self, layer: int, @@ -1295,7 +1305,6 @@ def _composed_union_window_join_query( s_ts_col: SqlIdentifier, f_df: DataFrame, f_table_name: str, - f_cols: List[SqlIdentifier], f_ts_col: SqlIdentifier, join_keys: List[SqlIdentifier], ) -> str: @@ -1352,7 +1361,7 @@ def 
join_cols(cols: List[SqlIdentifier], end_comma: bool, rename: bool, prefix: + f""" ,last_value({f_col}) IGNORE NULLS OVER ( PARTITION BY {join_keys_str} - ORDER BY {s_ts_col} + ORDER BY {s_ts_col} ASC, {temp_prefix}src ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW ) AS {temp_prefix}{f_col}""" ) @@ -1383,7 +1392,7 @@ def join_cols(cols: List[SqlIdentifier], end_comma: bool, rename: bool, prefix: return complete_query def _get_entity_name(self, raw_name: SqlIdentifier) -> SqlIdentifier: - return SqlIdentifier(identifier.concat_names([ENTITY_TAG_PREFIX, raw_name])) + return SqlIdentifier(identifier.concat_names([_ENTITY_TAG_PREFIX, raw_name])) def _get_fully_qualified_name(self, name: Union[SqlIdentifier, str]) -> str: return f"{self._config.full_schema_path}.{name}" @@ -1419,7 +1428,7 @@ def _update_feature_view_status(self, feature_view: FeatureView, operation: str) ) from e feature_view._status = self.get_feature_view(feature_view.name, feature_view.version).status - logger.info(f"Successfully {operation} FeatureView {feature_view.name} with version {feature_view.version}.") + logger.info(f"Successfully {operation} FeatureView {feature_view.name}/{feature_view.version}.") return feature_view def _find_feature_views( @@ -1447,7 +1456,7 @@ def _find_feature_views( ) ) WHERE LEVEL = 'TABLE' - AND TAG_NAME = '{FEATURE_VIEW_ENTITY_TAG}' + AND TAG_NAME = '{_FEATURE_VIEW_ENTITY_TAG}' """ for fv_name in all_fv_names ] @@ -1461,27 +1470,25 @@ def _find_feature_views( outputs = [] for r in results: if entity_name == SqlIdentifier(r["TAG_VALUE"], case_sensitive=True): - fv_name, version = to_sql_identifiers( - r["OBJECT_NAME"].split(FEATURE_VIEW_NAME_DELIMITER), case_sensitive=True - ) + fv_name, version = r["OBJECT_NAME"].split(_FEATURE_VIEW_NAME_DELIMITER) + fv_name = SqlIdentifier(fv_name, case_sensitive=True) if feature_view_name is not None: if fv_name == feature_view_name: outputs.append(self.get_feature_view(fv_name, version)) else: continue else: - 
outputs.append(self.get_feature_view(fv_name.identifier(), version.identifier())) + outputs.append(self.get_feature_view(fv_name.identifier(), version)) return outputs def _compose_feature_view(self, row: Row) -> FeatureView: - name, version = to_sql_identifiers(row["name"].split(FEATURE_VIEW_NAME_DELIMITER), case_sensitive=True) - m = re.match(DT_OR_VIEW_QUERY_PATTERN, row["text"]) + name, version = row["name"].split(_FEATURE_VIEW_NAME_DELIMITER) + name = SqlIdentifier(name, case_sensitive=True) + m = re.match(_DT_OR_VIEW_QUERY_PATTERN, row["text"]) if m is None: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INTERNAL_SNOWML_ERROR, - original_exception=RuntimeError( - f"Failed to parse query text for FeatureView {name} with version {version}: {row}." - ), + original_exception=RuntimeError(f"Failed to parse query text for FeatureView {name}/{version}: {row}."), ) if m.group("obj_type") == "DYNAMIC TABLE": @@ -1489,9 +1496,9 @@ def _compose_feature_view(self, row: Row) -> FeatureView: df = self._session.sql(query) desc = m.group("comment") entity_names = m.group("entities") - entities = [self.get_entity(n) for n in entity_names.split(FEATURE_VIEW_ENTITY_TAG_DELIMITER)] + entities = [self.get_entity(n) for n in entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] ts_col = m.group("ts_col") - timestamp_col = ts_col if ts_col != TIMESTAMP_COL_PLACEHOLDER else None + timestamp_col = ts_col if ts_col != _TIMESTAMP_COL_PLACEHOLDER else None fv = FeatureView._construct_feature_view( name=name, @@ -1517,9 +1524,9 @@ def _compose_feature_view(self, row: Row) -> FeatureView: df = self._session.sql(query) desc = m.group("comment") entity_names = m.group("entities") - entities = [self.get_entity(n) for n in entity_names.split(FEATURE_VIEW_ENTITY_TAG_DELIMITER)] + entities = [self.get_entity(n) for n in entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] ts_col = m.group("ts_col") - timestamp_col = ts_col if ts_col != TIMESTAMP_COL_PLACEHOLDER else 
None + timestamp_col = ts_col if ts_col != _TIMESTAMP_COL_PLACEHOLDER else None fv = FeatureView._construct_feature_view( name=name, @@ -1597,7 +1604,7 @@ def _find_object( '{obj_domain}' ) ) - WHERE TAG_NAME = '{FEATURE_STORE_OBJECT_TAG}' + WHERE TAG_NAME = '{_FEATURE_STORE_OBJECT_TAG}' AND TAG_SCHEMA = '{self._config.schema.resolved()}' """ for row in all_rows @@ -1626,7 +1633,7 @@ def _load_serialized_feature_objects( results: List[Union[FeatureView, FeatureViewSlice]] = [] for obj in serialized_feature_objs: try: - obj_type = json.loads(obj)[FEATURE_OBJ_TYPE] + obj_type = json.loads(obj)[_FEATURE_OBJ_TYPE] except Exception as e: raise ValueError(f"Malformed serialized feature object: {obj}") from e @@ -1637,3 +1644,16 @@ def _load_serialized_feature_objects( else: raise ValueError(f"Unsupported feature object type: {obj_type}") return results + + def _exclude_columns(self, df: DataFrame, exclude_columns: List[str]) -> DataFrame: + exclude_columns = to_sql_identifiers(exclude_columns) # type: ignore[assignment] + df_cols = to_sql_identifiers(df.columns) + for col in exclude_columns: + if col not in df_cols: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=ValueError( + f"{col} in exclude_columns not exists in dataframe columns: {df_cols}" + ), + ) + return cast(DataFrame, df.drop(exclude_columns)) diff --git a/snowflake/ml/feature_store/feature_view.py b/snowflake/ml/feature_store/feature_view.py index 3672ab7e..ccceff88 100644 --- a/snowflake/ml/feature_store/feature_view.py +++ b/snowflake/ml/feature_store/feature_view.py @@ -1,11 +1,16 @@ from __future__ import annotations import json +import re from collections import OrderedDict from dataclasses import dataclass from enum import Enum from typing import Dict, List, Optional +from snowflake.ml._internal.exceptions import ( + error_codes, + exceptions as snowml_exceptions, +) from snowflake.ml._internal.utils.identifier import concat_names from 
snowflake.ml._internal.utils.sql_identifier import ( SqlIdentifier, @@ -21,9 +26,25 @@ _NumericType, ) -FEATURE_VIEW_NAME_DELIMITER = "$" -TIMESTAMP_COL_PLACEHOLDER = "FS_TIMESTAMP_COL_PLACEHOLDER_VAL" -FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE" +_FEATURE_VIEW_NAME_DELIMITER = "$" +_TIMESTAMP_COL_PLACEHOLDER = "FS_TIMESTAMP_COL_PLACEHOLDER_VAL" +_FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE" +_FEATURE_VIEW_VERSION_RE = re.compile("^([A-Za-z0-9_]*)$") + + +class FeatureViewVersion(str): + def __new__(cls, version: str) -> FeatureViewVersion: + if not _FEATURE_VIEW_VERSION_RE.match(version): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=ValueError( + f"`{version}` is not a valid feature view version. Only letter, number and underscore is allowed." + ), + ) + return super().__new__(cls, version.upper()) + + def __init__(self, version: str) -> None: + return super().__init__() class FeatureViewStatus(Enum): @@ -52,16 +73,16 @@ def to_json(self) -> str: fvs_dict = { "feature_view_ref": self.feature_view_ref.to_json(), "names": self.names, - FEATURE_OBJ_TYPE: self.__class__.__name__, + _FEATURE_OBJ_TYPE: self.__class__.__name__, } return json.dumps(fvs_dict) @classmethod def from_json(cls, json_str: str, session: Session) -> FeatureViewSlice: json_dict = json.loads(json_str) - if FEATURE_OBJ_TYPE not in json_dict or json_dict[FEATURE_OBJ_TYPE] != cls.__name__: + if _FEATURE_OBJ_TYPE not in json_dict or json_dict[_FEATURE_OBJ_TYPE] != cls.__name__: raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}") - del json_dict[FEATURE_OBJ_TYPE] + del json_dict[_FEATURE_OBJ_TYPE] json_dict["feature_view_ref"] = FeatureView.from_json(json_dict["feature_view_ref"], session) return cls(**json_dict) @@ -108,7 +129,7 @@ def __init__( ) self._desc: str = desc self._query: str = self._get_query() - self._version: Optional[SqlIdentifier] = None + self._version: Optional[FeatureViewVersion] = None self._status: 
FeatureViewStatus = FeatureViewStatus.DRAFT self._feature_desc: OrderedDict[SqlIdentifier, str] = OrderedDict((f, "") for f in self._get_feature_names()) self._refresh_freq: Optional[str] = refresh_freq @@ -211,7 +232,7 @@ def query(self) -> str: return self._query @property - def version(self) -> Optional[SqlIdentifier]: + def version(self) -> Optional[FeatureViewVersion]: return self._version @property @@ -280,9 +301,9 @@ def _get_query(self) -> str: return str(self._feature_df.queries["queries"][0]) def _validate(self) -> None: - if FEATURE_VIEW_NAME_DELIMITER in self._name: + if _FEATURE_VIEW_NAME_DELIMITER in self._name: raise ValueError( - f"FeatureView name `{self._name}` contains invalid character `{FEATURE_VIEW_NAME_DELIMITER}`." + f"FeatureView name `{self._name}` contains invalid character `{_FEATURE_VIEW_NAME_DELIMITER}`." ) unescaped_df_cols = to_sql_identifiers(self._feature_df.columns) @@ -295,8 +316,8 @@ def _validate(self) -> None: if self._timestamp_col is not None: ts_col = self._timestamp_col - if ts_col == SqlIdentifier(TIMESTAMP_COL_PLACEHOLDER): - raise ValueError(f"Invalid timestamp_col name, cannot be {TIMESTAMP_COL_PLACEHOLDER}.") + if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER): + raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.") if ts_col not in to_sql_identifiers(self._feature_df.columns): raise ValueError(f"timestamp_col {ts_col} is not found in input dataframe.") @@ -364,13 +385,13 @@ def to_df(self, session: Session) -> DataFrame: def to_json(self) -> str: state_dict = self._to_dict() - state_dict[FEATURE_OBJ_TYPE] = self.__class__.__name__ + state_dict[_FEATURE_OBJ_TYPE] = self.__class__.__name__ return json.dumps(state_dict) @classmethod def from_json(cls, json_str: str, session: Session) -> FeatureView: json_dict = json.loads(json_str) - if FEATURE_OBJ_TYPE not in json_dict or json_dict[FEATURE_OBJ_TYPE] != cls.__name__: + if _FEATURE_OBJ_TYPE not in json_dict or 
json_dict[_FEATURE_OBJ_TYPE] != cls.__name__: raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}") return FeatureView._construct_feature_view( @@ -391,13 +412,13 @@ def from_json(cls, json_str: str, session: Session) -> FeatureView: ) @staticmethod - def _get_physical_name(fv_name: SqlIdentifier, fv_version: SqlIdentifier) -> SqlIdentifier: + def _get_physical_name(fv_name: SqlIdentifier, fv_version: FeatureViewVersion) -> SqlIdentifier: return SqlIdentifier( concat_names( [ - fv_name, - FEATURE_VIEW_NAME_DELIMITER, - fv_version, + str(fv_name), + _FEATURE_VIEW_NAME_DELIMITER, + str(fv_version), ] ) ) @@ -426,7 +447,7 @@ def _construct_feature_view( timestamp_col=timestamp_col, desc=desc, ) - fv._version = SqlIdentifier(version) if version is not None else None + fv._version = FeatureViewVersion(version) if version is not None else None fv._status = status fv._refresh_freq = refresh_freq fv._database = SqlIdentifier(database) if database is not None else None diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb index e79daf24..6098577e 100644 --- a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb @@ -5,9 +5,9 @@ "id": "0bb54abc", "metadata": {}, "source": [ - "Version: 0.3.0\n", - "\n", - "Updated date: 12/02/2023" + "- snowflake-ml-python version: 1.1.0\n", + "- Feature Store PrPr Version: 0.3.1\n", + "- Updated date: 12/11/2023" ] }, { @@ -571,7 +571,6 @@ "outputs": [], "source": [ "from snowflake.ml.registry import model_registry\n", - "import time\n", "\n", "registry = model_registry.ModelRegistry(\n", " session=session, \n", @@ -595,10 +594,13 @@ "metadata": {}, "outputs": [], "source": [ + "DATASET_NAME = \"MY_DATASET\"\n", + "DATASET_VERSION = \"V1\"\n", + "\n", "my_dataset = registry.log_artifact(\n", " 
artifact=training_data,\n", - " name=\"MY_COOL_DATASET\",\n", - " version=\"V1\",\n", + " name=DATASET_NAME,\n", + " version=DATASET_VERSION,\n", ")" ] }, @@ -617,7 +619,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_name = f\"MY_RANDOM_FOREST_REGRESSOR_{time.time()}\"\n", + "model_name = \"MY_MODEL\"\n", "\n", "model_ref = registry.log_model(\n", " model_name=model_name,\n", @@ -649,8 +651,8 @@ "from snowflake.ml.dataset.dataset import Dataset\n", "\n", "registered_dataset = registry.get_artifact(\n", - " my_dataset.name, \n", - " my_dataset.version)\n", + " DATASET_NAME, \n", + " DATASET_VERSION)\n", "test_df = spine_df.limit(3).select(\"WINE_ID\")\n", "\n", "enriched_df = fs.retrieve_feature_values(\n", diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb index cd65c70d..19b163c8 100644 --- a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb @@ -5,9 +5,9 @@ "id": "4f029c96", "metadata": {}, "source": [ - "Notebook version: 0.3.0\n", - "\n", - "Updated date: 12/02/2023" + "- snowflake-ml-python version: 1.1.0\n", + "- Feature Store PrPr version: 0.3.2\n", + "- Updated date: 12/11/2023" ] }, { @@ -600,25 +600,6 @@ "Now let's predict with the model and the feature values retrieved from feature store. 
" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "e41f0552", - "metadata": {}, - "outputs": [], - "source": [ - "# Prepare some source prediction data \n", - "\n", - "pred_df = training_data.df.to_pandas().sample(3, random_state=996)[ \n", - " ['PULOCATIONID', 'DOLOCATIONID', 'PICKUP_TS']] \n", - "pred_df = session.create_dataframe(pred_df) \n", - "pred_df = pred_df.select( \n", - " 'PULOCATIONID', \n", - " 'DOLOCATIONID', \n", - " F.cast(pred_df.PICKUP_TS / 1000000, TimestampType())\n", - " .alias('PICKUP_TS'))" - ] - }, { "cell_type": "code", "execution_count": null, @@ -626,6 +607,9 @@ "metadata": {}, "outputs": [], "source": [ + "pred_df = training_data.df.sample(0.01).select(\n", + " ['PULOCATIONID', 'DOLOCATIONID', 'PICKUP_TS'])\n", + "\n", "enriched_df = fs.retrieve_feature_values(\n", " spine_df=pred_df, \n", " features=training_data.load_features(), \n", @@ -653,7 +637,6 @@ "outputs": [], "source": [ "from snowflake.ml.registry import model_registry\n", - "import time\n", "\n", "registry = model_registry.ModelRegistry(\n", " session=session, \n", @@ -677,10 +660,13 @@ "metadata": {}, "outputs": [], "source": [ + "DATASET_NAME = \"MY_DATASET\"\n", + "DATASET_VERSION = \"V1\"\n", + "\n", "my_dataset = registry.log_artifact(\n", " artifact=training_data,\n", - " name=\"MY_COOL_DATASET\",\n", - " version=\"V1\",\n", + " name=DATASET_NAME,\n", + " version=DATASET_VERSION,\n", ")" ] }, @@ -699,7 +685,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_name = f\"MY_MODEL_{time.time()}\"\n", + "model_name = \"MY_MODEL\"\n", "\n", "model_ref = registry.log_model(\n", " model_name=model_name,\n", @@ -729,8 +715,8 @@ "from snowflake.ml.dataset.dataset import Dataset\n", "\n", "registered_dataset = registry.get_artifact(\n", - " my_dataset.name, \n", - " my_dataset.version)\n", + " DATASET_NAME, \n", + " DATASET_VERSION)\n", "\n", "enriched_df = fs.retrieve_feature_values(\n", " spine_df=pred_df, \n", diff --git 
a/snowflake/ml/feature_store/notebooks/internal_demo/Time_Series_Feature_Demo.ipynb b/snowflake/ml/feature_store/notebooks/internal_demo/Time_Series_Feature_Demo.ipynb index 8777e5b6..8e9fc82f 100644 --- a/snowflake/ml/feature_store/notebooks/internal_demo/Time_Series_Feature_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/internal_demo/Time_Series_Feature_Demo.ipynb @@ -51,10 +51,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "id": "da1a922d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "# Scale cell width with the browser window to accommodate .show() commands for wider tables.\n", "from IPython.display import display, HTML\n", @@ -74,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 31, "id": "11935b50", "metadata": {}, "outputs": [], @@ -123,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 33, "id": "f39a3f77", "metadata": {}, "outputs": [], @@ -139,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 34, "id": "e665bd41", "metadata": {}, "outputs": [], @@ -149,10 +170,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 35, "id": "75bfcfd1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-------------------------------------------------------------------------------------------------------------------------------------\n", + "|\"TRIP_DISTANCE\" |\"FARE_AMOUNT\" |\"PASSENGER_COUNT\" |\"PULOCATIONID\" |\"DOLOCATIONID\" |\"PICKUP_TS\" |\"DROPOFF_TS\" |\n", + 
"-------------------------------------------------------------------------------------------------------------------------------------\n", + "|3.2 |14.0 |1 |48 |262 |2016-01-01 00:12:22 |2016-01-01 00:29:14 |\n", + "|1.0 |9.5 |2 |162 |48 |2016-01-01 00:41:31 |2016-01-01 00:55:10 |\n", + "|0.9 |6.0 |1 |246 |90 |2016-01-01 00:53:37 |2016-01-01 00:59:57 |\n", + "|0.8 |5.0 |1 |170 |162 |2016-01-01 00:13:28 |2016-01-01 00:18:07 |\n", + "|1.8 |11.0 |1 |161 |140 |2016-01-01 00:33:04 |2016-01-01 00:47:14 |\n", + "|2.3 |11.0 |1 |141 |137 |2016-01-01 00:49:47 |2016-01-01 01:04:44 |\n", + "|13.8 |43.0 |1 |100 |53 |2016-01-01 00:41:58 |2016-01-01 01:22:06 |\n", + "|3.46 |20.0 |5 |48 |79 |2016-01-01 00:25:28 |2016-01-01 00:55:46 |\n", + "|0.83 |5.5 |4 |79 |107 |2016-01-01 00:56:57 |2016-01-01 01:02:24 |\n", + "|0.87 |7.0 |1 |164 |164 |2016-01-01 00:10:08 |2016-01-01 00:23:05 |\n", + "-------------------------------------------------------------------------------------------------------------------------------------\n", + "\n" + ] + } + ], "source": [ "source_df = session.table(\"SNOWML_FEATURE_STORE_TEST_DB.TEST_DATASET.yellow_tripdata_2016_01\")\n", "\n", @@ -183,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 36, "id": "6c37a635", "metadata": {}, "outputs": [], @@ -212,10 +255,64 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 37, "id": "70609920", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
NAMEJOIN_KEYSDESC
0TRIP_DROPOFFDOLOCATIONID
1TRIP_PICKUPPULOCATIONID
\n", + "
" + ], + "text/plain": [ + " NAME JOIN_KEYS DESC\n", + "0 TRIP_DROPOFF DOLOCATIONID \n", + "1 TRIP_PICKUP PULOCATIONID " + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "trip_pickup = Entity(name=\"trip_pickup\", join_keys=[\"PULOCATIONID\"])\n", "trip_dropoff = Entity(name=\"trip_dropoff\", join_keys=[\"DOLOCATIONID\"])\n", @@ -251,7 +348,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 38, "id": "995b4bcd", "metadata": {}, "outputs": [], @@ -263,6 +360,7 @@ " packages=[\"numpy\", \"pandas\", \"pytimeparse\"],\n", " replace=True,\n", " session=session,\n", + " immutable=True,\n", ")\n", "def vec_window_end_compute(\n", " x: T.PandasSeries[datetime.datetime],\n", @@ -289,10 +387,47 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "id": "7d0c4339", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------------------------\n", + "|\"PULOCATIONID\" |\"TS\" |\"MEAN_FARE_2_HR\" |\"MEAN_FARE_5_HR\" |\n", + "----------------------------------------------------------------------------------\n", + "|98 |2016-01-01 04:45:00 |26.0 |26.0 |\n", + "|98 |2016-01-01 14:00:00 |19.75 |19.75 |\n", + "|98 |2016-01-02 22:30:00 |156.5 |156.5 |\n", + "|225 |2016-01-01 00:30:00 |9.6 |9.6 |\n", + "|225 |2016-01-01 00:45:00 |11.833333333333334 |11.833333333333334 |\n", + "|225 |2016-01-01 01:00:00 |15.045454545454545 |15.045454545454545 |\n", + "|225 |2016-01-01 01:15:00 |13.928571428571429 |13.928571428571429 |\n", + "|225 |2016-01-01 01:30:00 |12.717948717948717 |12.717948717948717 |\n", + "|225 |2016-01-01 01:45:00 |13.169811320754716 |13.169811320754716 |\n", + "|225 |2016-01-01 02:00:00 |12.607142857142858 |12.607142857142858 |\n", + "----------------------------------------------------------------------------------\n", + "\n", + 
"--------------------------------------------------------------------------------\n", + "|\"DOLOCATIONID\" |\"TS\" |\"COUNT_TRIP_2_HR\" |\"COUNT_TRIP_5_HR\" |\n", + "--------------------------------------------------------------------------------\n", + "|227 |2016-01-01 00:30:00 |2 |2 |\n", + "|227 |2016-01-01 00:45:00 |5 |5 |\n", + "|227 |2016-01-01 01:00:00 |12 |12 |\n", + "|227 |2016-01-01 01:15:00 |16 |16 |\n", + "|227 |2016-01-01 01:30:00 |21 |21 |\n", + "|227 |2016-01-01 01:45:00 |25 |25 |\n", + "|227 |2016-01-01 02:00:00 |33 |33 |\n", + "|227 |2016-01-01 02:15:00 |43 |43 |\n", + "|227 |2016-01-01 02:30:00 |48 |50 |\n", + "|227 |2016-01-01 02:45:00 |53 |58 |\n", + "--------------------------------------------------------------------------------\n", + "\n" + ] + } + ], "source": [ "from snowflake.snowpark import Window\n", "from snowflake.snowpark.functions import col\n", @@ -370,24 +505,45 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "id": "f0cd2075", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/tbao/Desktop/Snowflake/snowml/snowflake/ml/feature_store/feature_store.py:334: UserWarning: Your pipeline won't be incrementally refreshed due to: \"Query contains the function 'VEC_WINDOW_END', but change tracking is not supported on queries with non-IMMUTABLE user-defined functions.\". 
It will likely incurr higher cost.\n", + " self._create_dynamic_table(\n" + ] + } + ], "source": [ - "pickup_fv = FeatureView(name=\"trip_pickup_time_series_features\", entities=[trip_pickup], feature_df=pickup_df, timestamp_col=\"ts\")\n", - "pickup_fv = fs.register_feature_view(feature_view=pickup_fv, version=\"v1\", refresh_freq=\"1 minute\", block=True)" + "pickup_fv = FeatureView(\n", + " name=\"trip_pickup_features\", \n", + " entities=[trip_pickup], \n", + " feature_df=pickup_df, \n", + " timestamp_col=\"ts\",\n", + " refresh_freq=\"1 minute\",\n", + ").attach_feature_desc({\"MEAN_FARE_2_HR\": \"avg fare over past 2hr\"})\n", + "pickup_fv = fs.register_feature_view(feature_view=pickup_fv, version=\"v1\", block=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 42, "id": "d8960b0e", "metadata": {}, "outputs": [], "source": [ - "dropoff_fv = FeatureView(name=\"trip_dropoff_time_series_features\", entities=[trip_dropoff], feature_df=dropoff_df, timestamp_col=\"ts\")\n", - "fs.register_feature_view(feature_view=dropoff_fv, version=\"v1\", refresh_freq=\"1 minute\", block=True)" + "dropoff_fv = FeatureView(\n", + " name=\"trip_dropoff_features\", \n", + " entities=[trip_dropoff], \n", + " feature_df=dropoff_df, \n", + " timestamp_col=\"ts\",\n", + " refresh_freq=\"1 minute\",\n", + ").attach_feature_desc({\"COUNT_TRIP_2_HR\": \"trip count over past 2hr\"})\n", + "dropoff_fv = fs.register_feature_view(feature_view=dropoff_fv, version=\"v1\", block=True)" ] }, { @@ -403,12 +559,42 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 43, "id": "bc93de79", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------------------------------------------------------------------------------------------------------------------\n", + "|\"NAME\" |\"VERSION\" |\"ENTITIES\" |\"FEATURE_DESC\" |\n", + 
"---------------------------------------------------------------------------------------------------------------------\n", + "|TRIP_DROPOFF_FEATURES |V1 |[ |{ |\n", + "| | | { | \"COUNT_TRIP_2_HR\": \"trip count over past 2hr\", |\n", + "| | | \"desc\": \"\", | \"COUNT_TRIP_5_HR\": \"\" |\n", + "| | | \"join_keys\": [ |} |\n", + "| | | \"DOLOCATIONID\" | |\n", + "| | | ], | |\n", + "| | | \"name\": \"TRIP_DROPOFF\" | |\n", + "| | | } | |\n", + "| | |] | |\n", + "|TRIP_PICKUP_FEATURES |V1 |[ |{ |\n", + "| | | { | \"MEAN_FARE_2_HR\": \"\", |\n", + "| | | \"desc\": \"\", | \"MEAN_FARE_5_HR\": \"\" |\n", + "| | | \"join_keys\": [ |} |\n", + "| | | \"PULOCATIONID\" | |\n", + "| | | ], | |\n", + "| | | \"name\": \"TRIP_PICKUP\" | |\n", + "| | | } | |\n", + "| | |] | |\n", + "---------------------------------------------------------------------------------------------------------------------\n", + "\n" + ] + } + ], "source": [ - "fs.list_feature_views(entity_name=\"trip_pickup\").select([\"NAME\", \"VERSION\", \"ENTITIES\", \"FEATURE_DESC\"]).show()" + "fs.list_feature_views().select([\"NAME\", \"VERSION\", \"ENTITIES\", \"FEATURE_DESC\"]).show()" ] }, { @@ -422,10 +608,43 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "id": "a4e3376c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-----------------------------------------------------------------------------------------------------------------------------------------------------------\n", + "|\"DOLOCATIONID\" |\"PICKUP_TS\" |\"PULOCATIONID\" |\"FARE_AMOUNT\" |\"MEAN_FARE_2_HR\" |\"MEAN_FARE_5_HR\" |\"COUNT_TRIP_2_HR\" |\"COUNT_TRIP_5_HR\" |\n", + "-----------------------------------------------------------------------------------------------------------------------------------------------------------\n", + "|262 |2016-01-01 00:12:22 |48 |14.0 |NULL |NULL |NULL |NULL |\n", + "|48 |2016-01-01 00:41:31 |162 |9.5 
|11.451428571428572 |11.451428571428572 |137 |137 |\n", + "|90 |2016-01-01 00:53:37 |246 |6.0 |13.765232974910393 |13.765232974910393 |214 |214 |\n", + "|162 |2016-01-01 00:13:28 |170 |5.0 |NULL |NULL |NULL |NULL |\n", + "|140 |2016-01-01 00:33:04 |161 |11.0 |13.203869047619047 |13.203869047619047 |83 |83 |\n", + "|137 |2016-01-01 00:49:47 |141 |11.0 |10.352534562211982 |10.352534562211982 |244 |244 |\n", + "|53 |2016-01-01 00:41:58 |100 |43.0 |15.816091954022989 |15.816091954022989 |NULL |NULL |\n", + "|79 |2016-01-01 00:25:28 |48 |20.0 |15.685714285714285 |15.685714285714285 |43 |43 |\n", + "|79 |2016-01-01 00:25:28 |48 |20.0 |15.685714285714285 |15.685714285714285 |43 |43 |\n", + "|79 |2016-01-01 00:25:28 |48 |20.0 |15.685714285714285 |15.685714285714285 |43 |43 |\n", + "-----------------------------------------------------------------------------------------------------------------------------------------------------------\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'queries': ['SELECT * FROM FS_TIME_SERIES_EXAMPLE.AWESOME_FS.yellow_tripdata_2016_01_training_data_2023_12_12_14_10_32'],\n", + " 'post_actions': []}" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "spine_df = source_df.select([\"PULOCATIONID\", \"DOLOCATIONID\", \"PICKUP_TS\", \"FARE_AMOUNT\"])\n", "training_data = fs.generate_dataset(\n", @@ -434,18 +653,115 @@ " materialized_table=\"yellow_tripdata_2016_01_training_data\",\n", " spine_timestamp_col=\"PICKUP_TS\",\n", " spine_label_cols = [\"FARE_AMOUNT\"]\n", - ")\n", - "\n", - "training_data.df.show()\n", - "training_data.df.queries" + ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 45, "id": "6bced5e5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DOLOCATIONIDPULOCATIONIDMEAN_FARE_2_HRMEAN_FARE_5_HRCOUNT_TRIP_2_HRCOUNT_TRIP_5_HR
359595902498.8069859.258179404.0995.0
3605621707910.24211110.517555821.02251.0
681540501079.4160969.226157394.0956.0
9510794815110.30851110.2789401289.03380.0
477164792499.5620169.5541241124.02827.0
\n", + "
" + ], + "text/plain": [ + " DOLOCATIONID PULOCATIONID MEAN_FARE_2_HR MEAN_FARE_5_HR \\\n", + "359595 90 249 8.806985 9.258179 \n", + "360562 170 79 10.242111 10.517555 \n", + "681540 50 107 9.416096 9.226157 \n", + "951079 48 151 10.308511 10.278940 \n", + "477164 79 249 9.562016 9.554124 \n", + "\n", + " COUNT_TRIP_2_HR COUNT_TRIP_5_HR \n", + "359595 404.0 995.0 \n", + "360562 821.0 2251.0 \n", + "681540 394.0 956.0 \n", + "951079 1289.0 3380.0 \n", + "477164 1124.0 2827.0 " + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", @@ -460,10 +776,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 46, "id": "8f0e6902", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "31.1498254012058 %\n", + "Mean squared error: 91.42\n" + ] + } + ], "source": [ "from sklearn.impute import SimpleImputer\n", "from sklearn.pipeline import make_pipeline\n", @@ -491,10 +816,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 47, "id": "c57a81e2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:The database \"my_cool_registry\" already exists. Skipping creation.\n", + "WARNING:absl:The schema \"my_cool_registry\"._SYSTEM_MODEL_REGISTRY_SCHEMA already exists. Skipping creation.\n" + ] + } + ], "source": [ "from snowflake.ml.registry import model_registry, artifact\n", "import time\n", @@ -504,33 +838,53 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 48, "id": "4caab287", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:snowflake.snowpark:ModelRegistry.log_artifact() is in private preview since 1.0.10. Do not use it in production. 
\n" + ] + } + ], "source": [ - "artifact_ref = registry.log_artifact(\n", - " artifact_type=artifact.ArtifactType.DATASET,\n", - " artifact_name=\"MY_COOL_DATASET\",\n", - " artifact_spec=training_data.to_json(),\n", - " artifact_version=\"V1\",\n", + "DATASET_NAME = \"MY_DATASET\"\n", + "DATASET_VERSION = f\"V1_{time.time()}\"\n", + "\n", + "my_dataset = registry.log_artifact(\n", + " artifact=training_data,\n", + " name=DATASET_NAME,\n", + " version=DATASET_VERSION,\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 50, "id": "a935926a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/anaconda3/envs/feature_store_demo/lib/python3.8/site-packages/snowflake/ml/model/model_signature.py:55: UserWarning: The sample input has 959457 rows, thus a truncation happened before inferring signature. This might cause inaccurate signature inference. If that happens, consider specifying signature manually.\n", + " warnings.warn(\n" + ] + } + ], "source": [ - "model_name = f\"my_model_{time.time()}\"\n", + "model_name = \"MY_MODEL\"\n", + "model_version = f\"V1_{time.time()}\"\n", "\n", "model_ref = registry.log_model(\n", " model_name=model_name,\n", - " model_version=\"v1\",\n", + " model_version=model_version,\n", " model=estimator,\n", - " artifacts=[artifact_ref],\n", + " artifacts=[my_dataset],\n", ")" ] }, @@ -545,37 +899,61 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "999a633d", - "metadata": {}, - "outputs": [], - "source": [ - "# Prepare some source prediction data\n", - "pred_df = training_pd.sample(3, random_state=996)[['PULOCATIONID', 'DOLOCATIONID', 'PICKUP_TS']]\n", - "pred_df = session.create_dataframe(pred_df)\n", - "pred_df = pred_df.select('PULOCATIONID', 'DOLOCATIONID', F.cast(pred_df.PICKUP_TS / 1000000, TimestampType()).alias('PICKUP_TS'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 51, 
"id": "0a18a5ea", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:snowflake.snowpark:FeatureStore.retrieve_feature_values() is in private preview since 1.0.8. Do not use it in production. \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 9.71003863 13.95909809 13.95909809 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 
11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268 11.70994268\n", + " 9.7267722 10.61725586 10.61725586 12.77915638 12.77915638 12.77915638\n", + " 15.89874567 13.24272348 13.24272348 13.24272348 13.24272348]\n" + ] + } + ], "source": [ - "# Enrich source prediction data with features\n", - "from snowflake.ml.dataset.dataset import Dataset\n", - "\n", - "registered_artifact = registry.get_artifact(\n", - " artifact_ref.name, \n", - " artifact_ref.version)\n", - "registered_dataset = Dataset.from_json(registered_artifact._spec, session)\n", + "pred_df = training_data.df.sample(0.01).select(\n", + " ['PULOCATIONID', 'DOLOCATIONID', 'PICKUP_TS'])\n", "\n", "enriched_df = fs.retrieve_feature_values(\n", " spine_df=pred_df, \n", - " features=registered_dataset.load_features(), \n", + " features=training_data.load_features(), \n", " spine_timestamp_col='PICKUP_TS'\n", - ").drop(['PICKUP_TS']).to_pandas()" + ").drop(['PICKUP_TS']).to_pandas()\n", + "\n", + "pred = estimator.predict(enriched_df)\n", + "print(pred)" ] }, { @@ -588,7 +966,7 @@ "model_ref = model_registry.ModelReference(\n", " registry=registry, \n", " model_name=model_name, \n", - " model_version=\"v1\"\n", + " model_version=model_version,\n", ").load_model()\n", "\n", "pred = model_ref.predict(enriched_df)\n", @@ -670,9 +1048,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python [conda env:feature_store_demo]", "language": "python", - "name": "python3" + "name": "conda-env-feature_store_demo-py" }, "language_info": { "codemirror_mode": { @@ -684,7 +1062,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.8.18" } }, "nbformat": 4, diff --git a/snowflake/ml/feature_store/tests/feature_store_case_sensitivity_test.py 
b/snowflake/ml/feature_store/tests/feature_store_case_sensitivity_test.py index 6a43c218..4788c65d 100644 --- a/snowflake/ml/feature_store/tests/feature_store_case_sensitivity_test.py +++ b/snowflake/ml/feature_store/tests/feature_store_case_sensitivity_test.py @@ -279,11 +279,11 @@ def test_join_keys_and_ts_col(self, equi_names: List[str], diff_names: List[str] @parameterized.parameters( [ ( - [("foo", "bar"), ("foo", "BAR"), ("FOO", "BAR"), ('"FOO"', '"BAR"')], - [('"foo"', "bar"), ("foo", '"bar"')], + [("foo", "bar"), ("foo", "BAR"), ("FOO", "BAR"), ('"FOO"', "BAR")], + [('"foo"', "bar")], ), ( - [('"abc"', "def"), ('"abc"', "DEF"), ('"abc"', '"DEF"')], + [('"abc"', "def"), ('"abc"', "DEF")], [("abc", "def")], ), ] @@ -314,7 +314,7 @@ def test_feature_view_names_and_versions_combination( for equi_full_name in equi_full_names: fv_name = equi_full_name[0] version = equi_full_name[1] - with self.assertRaisesRegex(ValueError, "FeatureView .* with version .* already exists"): + with self.assertRaisesRegex(ValueError, "FeatureView .* already exists"): fv = FeatureView(name=fv_name, entities=[e], feature_df=df) fs.register_feature_view(fv, version, block=True) @@ -383,6 +383,46 @@ def test_find_objects(self, equi_names: List[str], diff_names: List[str]) -> Non self.assertEqual(len(fs._find_object("SCHEMAS", SqlIdentifier(name))), 0) self._session.sql(f"DROP SCHEMA IF EXISTS {FS_INTEG_TEST_DB}.{equi_names[0]}").collect() + def test_feature_view_version(self) -> None: + current_schema = create_random_schema(self._session, "TEST_FEATURE_VIEW_VERSION") + fs = FeatureStore( + self._session, + FS_INTEG_TEST_DB, + current_schema, + default_warehouse=self._test_warehouse_name, + creation_mode=CreationMode.CREATE_IF_NOT_EXIST, + ) + self._active_fs.append(fs) + + df = self._session.create_dataframe([1, 2, 3], schema=["a"]) + e = Entity(name="MY_COOL_ENTITY", join_keys=["a"]) + fs.register_entity(e) + fv = FeatureView(name="MY_FV", entities=[e], feature_df=df) + + # 1: register 
with lowercase, get it back with lowercase/uppercase + fs.register_feature_view(fv, "a1", block=True) + fs.get_feature_view("MY_FV", "A1") + fs.get_feature_view("MY_FV", "a1") + + # 2: register with uppercase, get it back with lowercase/uppercase + fs.register_feature_view(fv, "B2", block=True) + fs.get_feature_view("MY_FV", "b2") + fs.get_feature_view("MY_FV", "B2") + + # 3. register with valid characters + fs.register_feature_view(fv, "V2_1", block=True) + fs.get_feature_view("MY_FV", "v2_1") + fs.register_feature_view(fv, "3", block=True) + fs.get_feature_view("MY_FV", "3") + + # 4: register with invalid characters + with self.assertRaisesRegex(ValueError, ".* is not a valid feature view version.*"): + fs.register_feature_view(fv, "abc$", block=True) + with self.assertRaisesRegex(ValueError, ".* is not a valid feature view version.*"): + fs.register_feature_view(fv, "abc#", block=True) + with self.assertRaisesRegex(ValueError, ".* is not a valid feature view version.*"): + fs.register_feature_view(fv, '"abc"', block=True) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/feature_store/tests/feature_store_object_test.py b/snowflake/ml/feature_store/tests/feature_store_object_test.py index ed7fefd9..425f3c54 100644 --- a/snowflake/ml/feature_store/tests/feature_store_object_test.py +++ b/snowflake/ml/feature_store/tests/feature_store_object_test.py @@ -9,8 +9,8 @@ FeatureViewStatus, ) from snowflake.ml.feature_store.feature_view import ( - FEATURE_OBJ_TYPE, - TIMESTAMP_COL_PLACEHOLDER, + _FEATURE_OBJ_TYPE, + _TIMESTAMP_COL_PLACEHOLDER, ) from snowflake.ml.utils.connection_params import SnowflakeLoginOptions from snowflake.snowpark import Session @@ -75,7 +75,7 @@ def test_invalid_timestamp_col(self) -> None: e = Entity(name="foo", join_keys=["a"]) with self.assertRaisesRegex(ValueError, "Invalid timestamp_col name.*"): - FeatureView(name="my_fv", entities=[e], feature_df=df, timestamp_col=TIMESTAMP_COL_PLACEHOLDER) + FeatureView(name="my_fv", 
entities=[e], feature_df=df, timestamp_col=_TIMESTAMP_COL_PLACEHOLDER) with self.assertRaisesRegex(ValueError, "timestamp_col.*is not found in input dataframe.*"): FeatureView(name="my_fv", entities=[e], feature_df=df, timestamp_col="d") @@ -97,7 +97,7 @@ def test_feature_view_serde(self) -> None: serialized = fv.to_json() self.assertEqual(fv, FeatureView.from_json(serialized, self._session)) - malformed = json.dumps({FEATURE_OBJ_TYPE: "foobar"}) + malformed = json.dumps({_FEATURE_OBJ_TYPE: "foobar"}) with self.assertRaisesRegex(ValueError, "Invalid json str for FeatureView.*"): FeatureView.from_json(malformed, self._session) @@ -109,7 +109,7 @@ def test_feature_view_slice_serde(self) -> None: serialized = fv_slice.to_json() self.assertEqual(fv_slice, FeatureViewSlice.from_json(serialized, self._session)) - malformed = json.dumps({FEATURE_OBJ_TYPE: "foobar"}) + malformed = json.dumps({_FEATURE_OBJ_TYPE: "foobar"}) with self.assertRaisesRegex(ValueError, "Invalid json str for FeatureViewSlice.*"): FeatureViewSlice.from_json(malformed, self._session) diff --git a/snowflake/ml/feature_store/tests/feature_store_test.py b/snowflake/ml/feature_store/tests/feature_store_test.py index 5a345a30..63859b01 100644 --- a/snowflake/ml/feature_store/tests/feature_store_test.py +++ b/snowflake/ml/feature_store/tests/feature_store_test.py @@ -23,10 +23,10 @@ FeatureViewStatus, ) from snowflake.ml.feature_store.feature_store import ( - ENTITY_TAG_PREFIX, - FEATURE_STORE_OBJECT_TAG, - FEATURE_VIEW_ENTITY_TAG, - FEATURE_VIEW_TS_COL_TAG, + _ENTITY_TAG_PREFIX, + _FEATURE_STORE_OBJECT_TAG, + _FEATURE_VIEW_ENTITY_TAG, + _FEATURE_VIEW_TS_COL_TAG, ) from snowflake.ml.utils.connection_params import SnowflakeLoginOptions from snowflake.snowpark import Session, exceptions as snowpark_exceptions @@ -697,6 +697,18 @@ def test_retrieve_feature_values(self) -> None: sort_cols=["ID"], ) + df = fs.retrieve_feature_values( + spine_df=spine_df, features=[fv1.slice(["name"]), fv2], 
exclude_columns=["NAME"] + ) + compare_dataframe( + actual_df=df.to_pandas(), + target_data={ + "ID": [1, 2], + "AGE": [20, 30], + }, + sort_cols=["ID"], + ) + # test retrieve_feature_values with serialized feature objects fv1_slice = fv1.slice(["name"]) dataset = fs.generate_dataset(spine_df, features=[fv1_slice, fv2]) @@ -1017,14 +1029,14 @@ def test_create_and_cleanup_tags(self) -> None: self.assertIsNotNone(fs) res = self._session.sql( - f"SHOW TAGS LIKE '{FEATURE_VIEW_ENTITY_TAG}' IN SCHEMA {fs._config.full_schema_path}" + f"SHOW TAGS LIKE '{_FEATURE_VIEW_ENTITY_TAG}' IN SCHEMA {fs._config.full_schema_path}" ).collect() self.assertEqual(len(res), 1) self._session.sql(f"DROP SCHEMA IF EXISTS {FS_INTEG_TEST_DB}.{current_schema}").collect() row_list = self._session.sql( - f"SHOW TAGS LIKE '{FEATURE_VIEW_ENTITY_TAG}' IN DATABASE {fs._config.database}" + f"SHOW TAGS LIKE '{_FEATURE_VIEW_ENTITY_TAG}' IN DATABASE {fs._config.database}" ).collect() for row in row_list: self.assertNotEqual(row["schema_name"], current_schema) @@ -1048,7 +1060,7 @@ def test_generate_dataset(self) -> None: refresh_freq="DOWNSTREAM", ) fv2 = fs.register_feature_view(feature_view=fv2, version="v1", block=True) - spine_df = self._session.create_dataframe([(1, 101)], schema=["id", "ts"]) + spine_df = self._session.create_dataframe([(1, 100), (1, 101)], schema=["id", "ts"]) # Generate dataset the first time ds1 = fs.generate_dataset( @@ -1060,13 +1072,13 @@ def test_generate_dataset(self) -> None: compare_dataframe( actual_df=ds1.df.to_pandas(), target_data={ - "ID": [1], - "TS": [101], - "NAME": ["jonh"], - "TITLE": ["boss"], - "AGE": [20], + "ID": [1, 1], + "TS": [100, 101], + "NAME": ["jonh", "jonh"], + "TITLE": ["boss", "boss"], + "AGE": [20, 20], }, - sort_cols=["ID"], + sort_cols=["ID", "TS"], ) self.assertEqual([fv1, fv2], fs.load_feature_views_from_dataset(ds1)) @@ -1081,13 +1093,13 @@ def test_generate_dataset(self) -> None: compare_dataframe( actual_df=ds2.df.to_pandas(), 
target_data={ - "ID": [1], - "TS": [101], - "NAME": ["jonh"], - "TITLE": ["boss"], - "AGE": [20], + "ID": [1, 1], + "TS": [100, 101], + "NAME": ["jonh", "jonh"], + "TITLE": ["boss", "boss"], + "AGE": [20, 20], }, - sort_cols=["ID"], + sort_cols=["ID", "TS"], ) # New data should properly appear @@ -1102,37 +1114,37 @@ def test_generate_dataset(self) -> None: compare_dataframe( actual_df=ds3.df.to_pandas(), target_data={ - "ID": [1, 2], - "TS": [101, 202], - "NAME": ["jonh", "porter"], - "TITLE": ["boss", "manager"], - "AGE": [20, 30], + "ID": [1, 1, 2], + "TS": [100, 101, 202], + "NAME": ["jonh", "jonh", "porter"], + "TITLE": ["boss", "boss", "manager"], + "AGE": [20, 20, 30], }, - sort_cols=["ID"], + sort_cols=["ID", "TS"], ) # Snapshot should remain the same compare_dataframe( actual_df=self._session.sql(f"SELECT * FROM {ds1.snapshot_table}").to_pandas(), target_data={ - "ID": [1], - "TS": [101], - "NAME": ["jonh"], - "TITLE": ["boss"], - "AGE": [20], + "ID": [1, 1], + "TS": [100, 101], + "NAME": ["jonh", "jonh"], + "TITLE": ["boss", "boss"], + "AGE": [20, 20], }, - sort_cols=["ID"], + sort_cols=["ID", "TS"], ) compare_dataframe( actual_df=self._session.sql(f"SELECT * FROM {ds3.snapshot_table}").to_pandas(), target_data={ - "ID": [1, 2], - "TS": [101, 202], - "NAME": ["jonh", "porter"], - "TITLE": ["boss", "manager"], - "AGE": [20, 30], + "ID": [1, 1, 2], + "TS": [100, 101, 202], + "NAME": ["jonh", "jonh", "porter"], + "TITLE": ["boss", "boss", "manager"], + "AGE": [20, 20, 30], }, - sort_cols=["ID"], + sort_cols=["ID", "TS"], ) # Generate dataset with exclude_columns and check both materialization and non-materialization path @@ -1246,10 +1258,10 @@ def check_fs_objects(expected_count: int) -> None: result = self._session.sql(f"SHOW TASKS LIKE 'FV$V1' IN SCHEMA {full_schema_path}").collect() self.assertEqual(len(result), expected_count) expected_tags = [ - FEATURE_VIEW_ENTITY_TAG, - FEATURE_VIEW_TS_COL_TAG, - FEATURE_STORE_OBJECT_TAG, - f"{ENTITY_TAG_PREFIX}foo", 
+ _FEATURE_VIEW_ENTITY_TAG, + _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_STORE_OBJECT_TAG, + f"{_ENTITY_TAG_PREFIX}foo", ] for tag in expected_tags: result = self._session.sql(f"SHOW TAGS LIKE '{tag}' in {full_schema_path}").collect() @@ -1295,7 +1307,7 @@ def minus_one(x: int) -> int: df = self._session.table(self._mock_table).select(call_udf(udf_name, col("id")).alias("uid"), "name") fv = FeatureView(name="fv", entities=[entity], feature_df=df, refresh_freq="1h") - with self.assertWarnsRegex(UserWarning, "Dynamic table: `.*` will not refresh in INCREMENTAL mode"): + with self.assertWarnsRegex(UserWarning, "Your pipeline won't be incrementally refreshed due to:"): fs.register_feature_view(feature_view=fv, version="V1") def test_switch_warehouse(self) -> None: diff --git a/snowflake/ml/model/BUILD.bazel b/snowflake/ml/model/BUILD.bazel index dbb1238e..1ce48b8f 100644 --- a/snowflake/ml/model/BUILD.bazel +++ b/snowflake/ml/model/BUILD.bazel @@ -26,6 +26,7 @@ py_library( "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:formatting", "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model/_deploy_client/warehouse:infer_template", "//snowflake/ml/model/_signatures:base_handler", "//snowflake/ml/model/_signatures:builtins_handler", @@ -55,7 +56,6 @@ py_library( ":model_signature", ":type_hints", "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/model/_deploy_client/snowservice:deploy", "//snowflake/ml/model/_deploy_client/warehouse:deploy", "//snowflake/ml/model/_deploy_client/warehouse:infer_template", diff --git a/snowflake/ml/model/_api.py b/snowflake/ml/model/_api.py index 66a4580c..f69bfb34 100644 --- a/snowflake/ml/model/_api.py +++ b/snowflake/ml/model/_api.py @@ -7,7 +7,6 @@ error_codes, exceptions as snowml_exceptions, ) -from snowflake.ml._internal.utils import identifier from snowflake.ml.model import ( deploy_platforms, 
model_signature, @@ -188,6 +187,10 @@ def save_model( Returns: Model """ + if options is None: + options = {} + options["_legacy_save"] = True + m = model_composer.ModelComposer(session=session, stage_path=stage_path) m.save( name=name, @@ -481,6 +484,7 @@ def predict( # Get options INTERMEDIATE_OBJ_NAME = "tmp_result" sig = deployment["signature"] + identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED # Validate and prepare input if not isinstance(X, SnowparkDataFrame): @@ -491,7 +495,7 @@ def predict( else: keep_order = False output_with_input_features = True - model_signature._validate_snowpark_data(X, sig.inputs) + identifier_rule = model_signature._validate_snowpark_data(X, sig.inputs) s_df = X if statement_params: @@ -500,10 +504,14 @@ def predict( else: s_df._statement_params = statement_params # type: ignore[assignment] + original_cols = s_df.columns + # Infer and get intermediate result input_cols = [] - for col_name in s_df.columns: - literal_col_name = identifier.get_unescaped_names(col_name) + for input_feature in sig.inputs: + literal_col_name = input_feature.name + col_name = identifier_rule.get_identifier_from_feature(input_feature.name) + input_cols.extend( [ F.lit(literal_col_name), @@ -511,29 +519,28 @@ def predict( ] ) - # TODO[shchen]: SNOW-870032, For SnowService, external function name cannot be double quoted, else it results in - # external function no found. 
udf_name = deployment["name"] - output_obj = F.call_udf(udf_name, F.object_construct(*input_cols)) - - if output_with_input_features: - df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj) - else: - df_res = s_df.select(output_obj.alias(INTERMEDIATE_OBJ_NAME)) + output_obj = F.call_udf(udf_name, F.object_construct_keep_null(*input_cols)) + df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj) if keep_order: df_res = df_res.order_by( - F.col(INTERMEDIATE_OBJ_NAME)[infer_template._KEEP_ORDER_COL_NAME], + F.col(infer_template._KEEP_ORDER_COL_NAME), ascending=True, ) + if not output_with_input_features: + df_res = df_res.drop(*original_cols) + # Prepare the output output_cols = [] + output_col_names = [] for output_feature in sig.outputs: output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature.name].astype(output_feature.as_snowpark_type())) + output_col_names.append(identifier_rule.get_identifier_from_feature(output_feature.name)) df_res = df_res.with_columns( - [identifier.get_inferred_name(output_feature.name) for output_feature in sig.outputs], + output_col_names, output_cols, ).drop(INTERMEDIATE_OBJ_NAME) diff --git a/snowflake/ml/model/_client/BUILD.bazel b/snowflake/ml/model/_client/BUILD.bazel new file mode 100644 index 00000000..e69de29b diff --git a/snowflake/ml/model/_client/model/BUILD.bazel b/snowflake/ml/model/_client/model/BUILD.bazel new file mode 100644 index 00000000..1fc4f5a7 --- /dev/null +++ b/snowflake/ml/model/_client/model/BUILD.bazel @@ -0,0 +1,59 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "model_impl", + srcs = ["model_impl.py"], + deps = [ + ":model_version_impl", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_client/ops:model_ops", + ], +) + +py_test( + name = "model_impl_test", + srcs = ["model_impl_test.py"], + deps = [ + ":model_impl", + 
":model_version_impl", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/test_utils:mock_session", + ], +) + +py_library( + name = "model_version_impl", + srcs = ["model_version_impl.py"], + deps = [ + ":model_method_info", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model/_client/ops:model_ops", + ], +) + +py_test( + name = "model_version_impl_test", + srcs = ["model_version_impl_test.py"], + deps = [ + ":model_version_impl", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model/_client/ops:metadata_ops", + "//snowflake/ml/model/_client/ops:model_ops", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) + +py_library( + name = "model_method_info", + srcs = ["model_method_info.py"], + deps = [ + "//snowflake/ml/model:model_signature", + ], +) diff --git a/snowflake/ml/model/_client/model/model_impl.py b/snowflake/ml/model/_client/model/model_impl.py new file mode 100644 index 00000000..cf781d26 --- /dev/null +++ b/snowflake/ml/model/_client/model/model_impl.py @@ -0,0 +1,176 @@ +from typing import List, Union + +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.model._client.ops import model_ops + +_TELEMETRY_PROJECT = "MLOps" +_TELEMETRY_SUBPROJECT = "ModelManagement" + + +class Model: + """Model Object containing multiple versions. Mapping to SQL's MODEL object.""" + + _model_ops: model_ops.ModelOperator + _model_name: sql_identifier.SqlIdentifier + + def __init__(self) -> None: + raise RuntimeError("Model's initializer is not meant to be used. 
Use `get_model` from registry instead.") + + @classmethod + def _ref( + cls, + model_ops: model_ops.ModelOperator, + *, + model_name: sql_identifier.SqlIdentifier, + ) -> "Model": + self: "Model" = object.__new__(cls) + self._model_ops = model_ops + self._model_name = model_name + return self + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, Model): + return False + return self._model_ops == __value._model_ops and self._model_name == __value._model_name + + @property + def name(self) -> str: + return self._model_name.identifier() + + @property + def fully_qualified_name(self) -> str: + return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name) + + @property + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def description(self) -> str: + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + return self._model_ops.get_comment( + model_name=self._model_name, + statement_params=statement_params, + ) + + @description.setter + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def description(self, description: str) -> None: + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + return self._model_ops.set_comment( + comment=description, + model_name=self._model_name, + statement_params=statement_params, + ) + + @property + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def default(self) -> model_version_impl.ModelVersion: + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + class_name=self.__class__.__name__, + ) + default_version_name = self._model_ops._model_version_client.get_default_version( + model_name=self._model_name, 
statement_params=statement_params + ) + return self.version(default_version_name) + + @default.setter + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def default(self, version: Union[str, model_version_impl.ModelVersion]) -> None: + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + class_name=self.__class__.__name__, + ) + if isinstance(version, str): + version_name = sql_identifier.SqlIdentifier(version) + else: + version_name = version._version_name + self._model_ops._model_version_client.set_default_version( + model_name=self._model_name, version_name=version_name, statement_params=statement_params + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def version(self, version_name: str) -> model_version_impl.ModelVersion: + """Get a model version object given a version name in the model. + + Args: + version_name: The name of version + + Raises: + ValueError: Raised when the version requested does not exist. + + Returns: + The model version object. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + version_id = sql_identifier.SqlIdentifier(version_name) + if self._model_ops.validate_existence( + model_name=self._model_name, + version_name=version_id, + statement_params=statement_params, + ): + return model_version_impl.ModelVersion._ref( + self._model_ops, + model_name=self._model_name, + version_name=version_id, + ) + else: + raise ValueError( + f"Unable to find version with name {version_id.identifier()} in model {self.fully_qualified_name}" + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def list_versions(self) -> List[model_version_impl.ModelVersion]: + """List all versions in the model. 
+ + Returns: + A list of ModelVersion objects representing all versions in the model. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + version_names = self._model_ops.list_models_or_versions( + model_name=self._model_name, + statement_params=statement_params, + ) + return [ + model_version_impl.ModelVersion._ref( + self._model_ops, + model_name=self._model_name, + version_name=version_name, + ) + for version_name in version_names + ] + + def delete_version(self, version_name: str) -> None: + raise NotImplementedError("Deleting version has not been supported yet.") diff --git a/snowflake/ml/model/_client/model/model_impl_test.py b/snowflake/ml/model/_client/model/model_impl_test.py new file mode 100644 index 00000000..1fca7576 --- /dev/null +++ b/snowflake/ml/model/_client/model/model_impl_test.py @@ -0,0 +1,133 @@ +from typing import cast +from unittest import mock + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.model import model_impl, model_version_impl +from snowflake.ml.model._client.ops import model_ops +from snowflake.ml.model._client.sql import model_version +from snowflake.ml.test_utils import mock_session +from snowflake.snowpark import Session + + +class ModelImplTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.c_session = cast(Session, self.m_session) + self.m_model = model_impl.Model._ref( + model_ops.ModelOperator( + self.c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ), + model_name=sql_identifier.SqlIdentifier("MODEL"), + ) + + def test_property(self) -> None: + self.assertEqual(self.m_model.name, "MODEL") + self.assertEqual(self.m_model.fully_qualified_name, 'TEMP."test".MODEL') + + def test_version_1(self) -> None: +
m_mv = model_version_impl.ModelVersion._ref( + self.m_model._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + ) + with mock.patch.object( + self.m_model._model_ops, "validate_existence", return_value=True + ) as mock_validate_existence: + mv = self.m_model.version("v1") + self.assertEqual(mv, m_mv) + mock_validate_existence.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + + def test_version_2(self) -> None: + with mock.patch.object( + self.m_model._model_ops, "validate_existence", return_value=False + ) as mock_validate_existence: + with self.assertRaisesRegex(ValueError, 'Unable to find version with name V1 in model TEMP."test"'): + self.m_model.version("v1") + mock_validate_existence.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + + def test_list_versions(self) -> None: + m_mv_1 = model_version_impl.ModelVersion._ref( + self.m_model._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + ) + m_mv_2 = model_version_impl.ModelVersion._ref( + self.m_model._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + ) + with mock.patch.object( + self.m_model._model_ops, + "list_models_or_versions", + return_value=[sql_identifier.SqlIdentifier("V1"), sql_identifier.SqlIdentifier("v1", case_sensitive=True)], + ) as mock_list_models_or_versions: + mv_list = self.m_model.list_versions() + self.assertListEqual(mv_list, [m_mv_1, m_mv_2]) + mock_list_models_or_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_description_getter(self) -> None: + with mock.patch.object( 
+ self.m_model._model_ops, "get_comment", return_value="this is a comment" + ) as mock_get_comment: + self.assertEqual("this is a comment", self.m_model.description) + mock_get_comment.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_description_setter(self) -> None: + with mock.patch.object(self.m_model._model_ops, "set_comment") as mock_set_comment: + self.m_model.description = "this is a comment" + mock_set_comment.assert_called_once_with( + comment="this is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_default_getter(self) -> None: + mock_model_ops = absltest.mock.MagicMock(spec=model_ops.ModelOperator) + mock_model_version_client = absltest.mock.MagicMock(spec=model_version.ModelVersionSQLClient) + self.m_model._model_ops = mock_model_ops + mock_model_ops._session = self.m_session + mock_model_ops._model_version_client = mock_model_version_client + mock_model_version_client.get_default_version.return_value = "V1" + + default_model_version = self.m_model.default + self.assertEqual(default_model_version.version_name, "V1") + mock_model_version_client.get_default_version.assert_called() + + def test_default_setter(self) -> None: + mock_model_version_client = absltest.mock.MagicMock(spec=model_version.ModelVersionSQLClient) + self.m_model._model_ops._model_version_client = mock_model_version_client + + # str + self.m_model.default = "V1" # type: ignore[assignment] + mock_model_version_client.set_default_version.assert_called() + + # ModelVersion + mv = model_version_impl.ModelVersion._ref( + self.m_model._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + ) + self.m_model.default = mv + mock_model_version_client.set_default_version.assert_called() + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_client/model/model_method_info.py 
b/snowflake/ml/model/_client/model/model_method_info.py new file mode 100644 index 00000000..013eace5 --- /dev/null +++ b/snowflake/ml/model/_client/model/model_method_info.py @@ -0,0 +1,19 @@ +from typing import TypedDict + +from typing_extensions import Required + +from snowflake.ml.model import model_signature + + +class ModelMethodInfo(TypedDict): + """Method information. + + Attributes: + name: Name of the method to be called via SQL. + target_method: actual target method name to be called. + signature: The signature of the model method. + """ + + name: Required[str] + target_method: Required[str] + signature: Required[model_signature.ModelSignature] diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py new file mode 100644 index 00000000..df353f60 --- /dev/null +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -0,0 +1,291 @@ +import re +from typing import Any, Callable, Dict, List, Optional, Union + +import pandas as pd + +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import model_signature +from snowflake.ml.model._client.model import model_method_info +from snowflake.ml.model._client.ops import metadata_ops, model_ops +from snowflake.snowpark import dataframe + +_TELEMETRY_PROJECT = "MLOps" +_TELEMETRY_SUBPROJECT = "ModelManagement" + + +class ModelVersion: + """Model Version Object representing a specific version of the model that could be run.""" + + _model_ops: model_ops.ModelOperator + _model_name: sql_identifier.SqlIdentifier + _version_name: sql_identifier.SqlIdentifier + + def __init__(self) -> None: + raise RuntimeError("ModelVersion's initializer is not meant to be used. 
Use `version` from model instead.") + + @classmethod + def _ref( + cls, + model_ops: model_ops.ModelOperator, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + ) -> "ModelVersion": + self: "ModelVersion" = object.__new__(cls) + self._model_ops = model_ops + self._model_name = model_name + self._version_name = version_name + return self + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, ModelVersion): + return False + return ( + self._model_ops == __value._model_ops + and self._model_name == __value._model_name + and self._version_name == __value._version_name + ) + + @property + def model_name(self) -> str: + return self._model_name.identifier() + + @property + def version_name(self) -> str: + return self._version_name.identifier() + + @property + def fully_qualified_model_name(self) -> str: + return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name) + + @property + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def description(self) -> str: + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + return self._model_ops.get_comment( + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + + @description.setter + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def description(self, description: str) -> None: + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + return self._model_ops.set_comment( + comment=description, + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def 
list_metrics(self) -> Dict[str, Any]: + """Show all metrics logged with the model version. + + Returns: + A dictionary showing the metrics + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + return self._model_ops._metadata_ops.load( + model_name=self._model_name, version_name=self._version_name, statement_params=statement_params + )["metrics"] + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def get_metric(self, metric_name: str) -> Any: + """Get the value of a specific metric. + + Args: + metric_name: The name of the metric + + Raises: + KeyError: Raised when the requested metric name does not exist. + + Returns: + The value of the metric. + """ + metrics = self.list_metrics() + if metric_name not in metrics: + raise KeyError(f"Cannot find metric with name {metric_name}.") + return metrics[metric_name] + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def set_metric(self, metric_name: str, value: Any) -> None: + """Set the value of a specific metric name + + Args: + metric_name: The name of the metric + value: The value of the metric. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + metrics = self.list_metrics() + metrics[metric_name] = value + self._model_ops._metadata_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics=metrics), + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def delete_metric(self, metric_name: str) -> None: + """Delete a metric from metric storage. + + Args: + metric_name: The name of the metric to be deleted. 
+ + Raises: + KeyError: Raised when the requested metric name does not exist. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + metrics = self.list_metrics() + if metric_name not in metrics: + raise KeyError(f"Cannot find metric with name {metric_name}.") + del metrics[metric_name] + self._model_ops._metadata_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics=metrics), + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def list_methods(self) -> List[model_method_info.ModelMethodInfo]: + """List all method information in a model version that is callable. + + Returns: + A list of ModelMethodInfo object containing the following information: + - name: The name of the method to be called (both in SQL and in Python SDK). + - target_method: The original method name in the logged Python object. + - Signature: Python signature of the original method. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + # TODO(SNOW-986673, SNOW-986675): Avoid parsing manifest and meta file and put Python signature into user_data. + manifest = self._model_ops.get_model_version_manifest( + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + model_meta = self._model_ops.get_model_version_native_packing_meta( + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + return_methods_info: List[model_method_info.ModelMethodInfo] = [] + for method in manifest["methods"]: + # Method's name is resolved so we need to use case_sensitive as True to get the user-facing identifier. 
+ method_name = sql_identifier.SqlIdentifier(method["name"], case_sensitive=True).identifier() + # Method's handler is `functions..infer` + assert re.match( + r"^functions\.([^\d\W]\w*)\.infer$", method["handler"] + ), f"Get unexpected handler name {method['handler']}" + target_method = method["handler"].split(".")[1] + signature_dict = model_meta["signatures"][target_method] + method_info = model_method_info.ModelMethodInfo( + name=method_name, + target_method=target_method, + signature=model_signature.ModelSignature.from_dict(signature_dict), + ) + return_methods_info.append(method_info) + + return return_methods_info + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def run( + self, + X: Union[pd.DataFrame, dataframe.DataFrame], + *, + method_name: Optional[str] = None, + ) -> Union[pd.DataFrame, dataframe.DataFrame]: + """Invoke a method in a model version object + + Args: + X: The input data. Could be pandas DataFrame or Snowpark DataFrame + method_name: The method name to run. It is the name you will use to call a method in SQL. Defaults to None. + It can only be None if there is only 1 method. + + Raises: + ValueError: No method with the corresponding name is available. + ValueError: There are more than 1 target methods available in the model but no method name specified. + + Returns: + The prediction data. 
+ """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + + methods: List[model_method_info.ModelMethodInfo] = self.list_methods() + if method_name: + req_method_name = sql_identifier.SqlIdentifier(method_name).identifier() + find_method: Callable[[model_method_info.ModelMethodInfo], bool] = ( + lambda method: method["name"] == req_method_name + ) + target_method_info = next( + filter(find_method, methods), + None, + ) + if target_method_info is None: + raise ValueError( + f"There is no method with name {method_name} available in the model" + f" {self.fully_qualified_model_name} version {self.version_name}" + ) + elif len(methods) != 1: + raise ValueError( + f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}" + f" version {self.version_name}. Please specify a `method_name` when calling the `run` method." + ) + else: + target_method_info = methods[0] + return self._model_ops.invoke_method( + method_name=sql_identifier.SqlIdentifier(target_method_info["name"]), + signature=target_method_info["signature"], + X=X, + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py new file mode 100644 index 00000000..84f30fcf --- /dev/null +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -0,0 +1,351 @@ +import textwrap +from typing import cast +from unittest import mock + +import yaml +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import model_signature +from snowflake.ml.model._client.model import model_version_impl +from snowflake.ml.model._client.ops import metadata_ops, model_ops +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Session + 
+_DUMMY_SIG = { + "predict": model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input"), + ], + outputs=[model_signature.FeatureSpec(name="output", dtype=model_signature.DataType.FLOAT)], + ) +} + + +class ModelVersionImplTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.c_session = cast(Session, self.m_session) + self.m_mv = model_version_impl.ModelVersion._ref( + model_ops.ModelOperator( + self.c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + ) + + def test_property(self) -> None: + self.assertEqual(self.m_mv.model_name, "MODEL") + self.assertEqual(self.m_mv.fully_qualified_model_name, 'TEMP."test".MODEL') + self.assertEqual(self.m_mv.version_name, '"v1"') + + def test_list_metrics(self) -> None: + m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={}) + with mock.patch.object(self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata) as mock_load: + self.assertDictEqual({}, self.m_mv.list_metrics()) + mock_load.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_get_metric_1(self) -> None: + m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={"a": 1}) + with mock.patch.object(self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata) as mock_load: + self.assertEqual(1, self.m_mv.get_metric("a")) + mock_load.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_get_metric_2(self) -> 
None: + m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={"a": 1}) + with mock.patch.object(self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata) as mock_load: + with self.assertRaisesRegex(KeyError, "Cannot find metric with name b"): + self.assertEqual(1, self.m_mv.get_metric("b")) + mock_load.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_set_metric_1(self) -> None: + m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={"a": 1}) + with mock.patch.object( + self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata + ) as mock_load, mock.patch.object(self.m_mv._model_ops._metadata_ops, "save") as mock_save: + self.m_mv.set_metric("a", 2) + mock_load.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + metadata_ops.ModelVersionMetadataSchema(metrics={"a": 2}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_set_metric_2(self) -> None: + m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={"a": 1}) + with mock.patch.object( + self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata + ) as mock_load, mock.patch.object(self.m_mv._model_ops._metadata_ops, "save") as mock_save: + self.m_mv.set_metric("b", 2) + mock_load.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + metadata_ops.ModelVersionMetadataSchema(metrics={"a": 1, "b": 2}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + 
version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_delete_metric_1(self) -> None: + m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={"a": 1}) + with mock.patch.object( + self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata + ) as mock_load, mock.patch.object(self.m_mv._model_ops._metadata_ops, "save") as mock_save: + self.m_mv.delete_metric("a") + mock_load.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + metadata_ops.ModelVersionMetadataSchema(metrics={}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_delete_metric_2(self) -> None: + m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={"a": 1}) + with mock.patch.object( + self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata + ) as mock_load, mock.patch.object(self.m_mv._model_ops._metadata_ops, "save") as mock_save: + with self.assertRaisesRegex(KeyError, "Cannot find metric with name b"): + self.m_mv.delete_metric("b") + mock_load.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_save.assert_not_called() + + def test_list_methods(self) -> None: + m_manifest = { + "manifest_version": "1.0", + "runtimes": { + "python_runtime": { + "language": "PYTHON", + "version": "3.8", + "imports": ["model.zip", "runtimes/python_runtime/snowflake-ml-python.zip"], + "dependencies": {"conda": "runtimes/python_runtime/env/conda.yml"}, + } + }, + "methods": [ + { + "name": "predict", + "runtime": "python_runtime", + "type": "FUNCTION", + "handler": "functions.predict.infer", + "inputs": 
[{"name": "input", "type": "FLOAT"}], + "outputs": [{"type": "OBJECT"}], + }, + { + "name": "__CALL__", + "runtime": "python_runtime", + "type": "FUNCTION", + "handler": "functions.__call__.infer", + "inputs": [{"name": "INPUT", "type": "FLOAT"}], + "outputs": [{"type": "OBJECT"}], + }, + ], + } + m_meta_yaml = yaml.safe_load( + textwrap.dedent( + """ + creation_timestamp: '2023-11-20 18:14:06.357187' + env: + conda: env/conda.yml + cuda_version: null + pip: env/requirements.txt + python_version: '3.8' + snowpark_ml_version: 1.0.13+ca79e1b0720d35abd021c33707de789dc63918cc + metadata: null + min_snowpark_ml_version: 1.0.12 + model_type: sklearn + models: + SKLEARN_MODEL: + artifacts: {} + handler_version: '2023-12-01' + model_type: sklearn + name: SKLEARN_MODEL + options: {} + path: model.pkl + name: SKLEARN_MODEL + signatures: + predict: + inputs: + - name: input + type: FLOAT + outputs: + - name: output + type: FLOAT + __call__: + inputs: + - name: input + type: FLOAT + outputs: + - name: output + type: FLOAT + version: '2023-12-01' + """ + ) + ) + with mock.patch.object( + self.m_mv._model_ops, "get_model_version_manifest", return_value=m_manifest + ) as mock_get_model_version_manifest, mock.patch.object( + self.m_mv._model_ops, "get_model_version_native_packing_meta", return_value=m_meta_yaml + ) as mock_get_model_version_native_packing_meta: + methods = self.m_mv.list_methods() + mock_get_model_version_manifest.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + mock_get_model_version_native_packing_meta.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + self.assertEqual( + methods, + [ + { + "name": '"predict"', + "target_method": "predict", + "signature": _DUMMY_SIG["predict"], + }, + { + "name": 
"__CALL__", + "target_method": "__call__", + "signature": _DUMMY_SIG["predict"], + }, + ], + ) + + def test_run(self) -> None: + m_df = mock_data_frame.MockDataFrame() + m_methods = [ + { + "name": '"predict"', + "target_method": "predict", + "signature": _DUMMY_SIG["predict"], + }, + { + "name": "__CALL__", + "target_method": "__call__", + "signature": _DUMMY_SIG["predict"], + }, + ] + with mock.patch.object(self.m_mv, "list_methods", return_value=m_methods) as mock_list_methods: + with self.assertRaisesRegex(ValueError, "There is no method with name PREDICT available in the model"): + self.m_mv.run(m_df, method_name="PREDICT") + mock_list_methods.assert_called_once_with() + + with mock.patch.object(self.m_mv, "list_methods", return_value=m_methods) as mock_list_methods: + with self.assertRaisesRegex(ValueError, "There are more than 1 target methods available in the model"): + self.m_mv.run(m_df) + mock_list_methods.assert_called_once_with() + + with mock.patch.object( + self.m_mv, "list_methods", return_value=m_methods + ) as mock_list_methods, mock.patch.object( + self.m_mv._model_ops, "invoke_method", return_value=m_df + ) as mock_invoke_method: + self.m_mv.run(m_df, method_name='"predict"') + mock_list_methods.assert_called_once_with() + mock_invoke_method.assert_called_once_with( + method_name='"predict"', + signature=_DUMMY_SIG["predict"], + X=m_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + with mock.patch.object( + self.m_mv, "list_methods", return_value=m_methods + ) as mock_list_methods, mock.patch.object( + self.m_mv._model_ops, "invoke_method", return_value=m_df + ) as mock_invoke_method: + self.m_mv.run(m_df, method_name="__call__") + mock_list_methods.assert_called_once_with() + mock_invoke_method.assert_called_once_with( + method_name="__CALL__", + signature=_DUMMY_SIG["predict"], + X=m_df, + 
model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_run_without_method_name(self) -> None: + m_df = mock_data_frame.MockDataFrame() + m_methods = [ + { + "name": '"predict"', + "target_method": "predict", + "signature": _DUMMY_SIG["predict"], + }, + ] + + with mock.patch.object( + self.m_mv, "list_methods", return_value=m_methods + ) as mock_list_methods, mock.patch.object( + self.m_mv._model_ops, "invoke_method", return_value=m_df + ) as mock_invoke_method: + self.m_mv.run(m_df) + mock_list_methods.assert_called_once_with() + mock_invoke_method.assert_called_once_with( + method_name='"predict"', + signature=_DUMMY_SIG["predict"], + X=m_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_description_getter(self) -> None: + with mock.patch.object( + self.m_mv._model_ops, "get_comment", return_value="this is a comment" + ) as mock_get_comment: + self.assertEqual("this is a comment", self.m_mv.description) + mock_get_comment.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_description_setter(self) -> None: + with mock.patch.object(self.m_mv._model_ops, "set_comment") as mock_set_comment: + self.m_mv.description = "this is a comment" + mock_set_comment.assert_called_once_with( + comment="this is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_client/ops/BUILD.bazel b/snowflake/ml/model/_client/ops/BUILD.bazel new file mode 100644 index 00000000..4775fa4d --- /dev/null +++ 
b/snowflake/ml/model/_client/ops/BUILD.bazel @@ -0,0 +1,59 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = [ + "//snowflake/ml/model/_client/model:__pkg__", + "//snowflake/ml/registry:__pkg__", +]) + +py_library( + name = "model_ops", + srcs = ["model_ops.py"], + deps = [ + ":metadata_ops", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model:type_hints", + "//snowflake/ml/model/_client/sql:model", + "//snowflake/ml/model/_client/sql:model_version", + "//snowflake/ml/model/_client/sql:stage", + "//snowflake/ml/model/_model_composer:model_composer", + "//snowflake/ml/model/_model_composer/model_manifest", + "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", + "//snowflake/ml/model/_packager/model_meta", + "//snowflake/ml/model/_packager/model_meta:model_meta_schema", + "//snowflake/ml/model/_signatures:snowpark_handler", + ], +) + +py_test( + name = "model_ops_test", + srcs = ["model_ops_test.py"], + deps = [ + ":model_ops", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model/_signatures:snowpark_handler", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) + +py_library( + name = "metadata_ops", + srcs = ["metadata_ops.py"], + deps = [ + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_client/sql:model", + "//snowflake/ml/model/_client/sql:model_version", + ], +) + +py_test( + name = "metadata_ops_test", + srcs = ["metadata_ops_test.py"], + deps = [ + ":metadata_ops", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/model/_client/ops/metadata_ops.py b/snowflake/ml/model/_client/ops/metadata_ops.py new file mode 100644 index 00000000..4ba7c11d --- /dev/null +++ b/snowflake/ml/model/_client/ops/metadata_ops.py @@ -0,0 +1,107 
@@ +import json +from typing import Any, Dict, Optional, TypedDict + +from typing_extensions import NotRequired + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.sql import ( + model as model_sql, + model_version as model_version_sql, +) +from snowflake.snowpark import session + +MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01" + + +class ModelVersionMetadataSchema(TypedDict): + metrics: NotRequired[Dict[str, Any]] + + +class MetadataOperator: + def __init__( + self, + session: session.Session, + *, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> None: + self._model_client = model_sql.ModelSQLClient( + session, + database_name=database_name, + schema_name=schema_name, + ) + self._model_version_client = model_version_sql.ModelVersionSQLClient( + session, + database_name=database_name, + schema_name=schema_name, + ) + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, MetadataOperator): + return False + return ( + self._model_client == __value._model_client and self._model_version_client == __value._model_version_client + ) + + @staticmethod + def _parse(metadata_dict: Dict[str, Any]) -> ModelVersionMetadataSchema: + loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None) + if loaded_metadata_schema_version is None: + return ModelVersionMetadataSchema(metrics={}) + elif ( + not isinstance(loaded_metadata_schema_version, str) + or loaded_metadata_schema_version != MODEL_VERSION_METADATA_SCHEMA_VERSION + ): + raise ValueError(f"Unsupported model metadata schema version {loaded_metadata_schema_version} confronted.") + loaded_metrics = metadata_dict.get("metrics", {}) + if not isinstance(loaded_metrics, dict): + raise ValueError(f"Metrics in the metadata is expected to be a dictionary, getting {loaded_metrics}") + return ModelVersionMetadataSchema(metrics=loaded_metrics) + + def _get_current_metadata_dict( + self, + *, + 
model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + version_info_list = self._model_client.show_versions( + model_name=model_name, version_name=version_name, statement_params=statement_params + ) + assert len(version_info_list) == 1 + version_info = version_info_list[0] + metadata_str = version_info.metadata + if not metadata_str: + return {} + res = json.loads(metadata_str) + if not isinstance(res, dict): + raise ValueError(f"Metadata is expected to be a dictionary, getting {res}") + return res + + def load( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> ModelVersionMetadataSchema: + metadata_dict = self._get_current_metadata_dict( + model_name=model_name, version_name=version_name, statement_params=statement_params + ) + return MetadataOperator._parse(metadata_dict) + + def save( + self, + metadata: ModelVersionMetadataSchema, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + metadata_dict = self._get_current_metadata_dict( + model_name=model_name, version_name=version_name, statement_params=statement_params + ) + metadata_dict.update({**metadata, "snowpark_ml_schema_version": MODEL_VERSION_METADATA_SCHEMA_VERSION}) + self._model_version_client.set_metadata( + metadata_dict, model_name=model_name, version_name=version_name, statement_params=statement_params + ) diff --git a/snowflake/ml/model/_client/ops/metadata_ops_test.py b/snowflake/ml/model/_client/ops/metadata_ops_test.py new file mode 100644 index 00000000..c545de62 --- /dev/null +++ b/snowflake/ml/model/_client/ops/metadata_ops_test.py @@ -0,0 +1,418 @@ +import json +from typing import Any, Dict, cast +from unittest import mock + +from absl.testing import absltest + 
+from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.ops import metadata_ops +from snowflake.ml.test_utils import mock_session +from snowflake.snowpark import Row, Session + + +class metadataOpsTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.m_statement_params = {"test": "1"} + self.c_session = cast(Session, self.m_session) + self.m_ops = metadata_ops.MetadataOperator( + self.c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ) + + def test_get_metadata_dict_1(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="Model", + metadata=None, + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + metadata_dict = self.m_ops._get_current_metadata_dict( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual(metadata_dict, {}) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_get_metadata_dict_2(self) -> None: + m_meta: Dict[str, Any] = {} + m_list_res = [ + Row( + create_on="06/01", + name="Model", + metadata=json.dumps(m_meta), + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + metadata_dict = self.m_ops._get_current_metadata_dict( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + 
self.assertDictEqual(metadata_dict, m_meta) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_get_metadata_dict_3(self) -> None: + m_meta = {"metrics": 1} + m_list_res = [ + Row( + create_on="06/01", + name="Model", + metadata=json.dumps(m_meta), + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + metadata_dict = self.m_ops._get_current_metadata_dict( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual(metadata_dict, m_meta) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_get_metadata_dict_4(self) -> None: + m_meta = "metrics" + m_list_res = [ + Row( + create_on="06/01", + name="Model", + metadata=json.dumps(m_meta), + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + with self.assertRaisesRegex(ValueError, "Metadata is expected to be a dictionary"): + self.m_ops._get_current_metadata_dict( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_load_1(self) -> None: + m_meta: Dict[str, Any] = {} + with mock.patch.object( + self.m_ops, 
"_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + loaded_meta = self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual( + loaded_meta, + metadata_ops.ModelVersionMetadataSchema(metrics={}), + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_load_2(self) -> None: + m_meta: Dict[str, Any] = {"metrics": {"a": 1}} + with mock.patch.object( + self.m_ops, "_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + loaded_meta = self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual( + loaded_meta, + metadata_ops.ModelVersionMetadataSchema(metrics={}), + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_load_3(self) -> None: + m_meta: Dict[str, Any] = {"snowpark_ml_schema_version": 1} + with mock.patch.object( + self.m_ops, "_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + with self.assertRaisesRegex(ValueError, "Unsupported model metadata schema version"): + self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def 
test_load_4(self) -> None: + m_meta: Dict[str, Any] = {"snowpark_ml_schema_version": "2023-12-01"} + with mock.patch.object( + self.m_ops, "_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + with self.assertRaisesRegex(ValueError, "Unsupported model metadata schema version"): + self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_load_5(self) -> None: + m_meta: Dict[str, Any] = {"snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION} + with mock.patch.object( + self.m_ops, "_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + loaded_meta = self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual( + loaded_meta, + metadata_ops.ModelVersionMetadataSchema(metrics={}), + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_load_6(self) -> None: + m_meta: Dict[str, Any] = { + "snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, + "metrics": 1, + } + with mock.patch.object( + self.m_ops, "_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + with self.assertRaisesRegex(ValueError, "Metrics in the metadata is expected to be a dictionary"): + self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + 
statement_params=self.m_statement_params, + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_load_7(self) -> None: + m_meta: Dict[str, Any] = { + "snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, + "metrics": {"a": 1}, + } + with mock.patch.object( + self.m_ops, "_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + loaded_meta = self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual( + loaded_meta, + metadata_ops.ModelVersionMetadataSchema( + metrics={"a": 1}, + ), + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_load_8(self) -> None: + m_meta: Dict[str, Any] = { + "snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, + "metrics": {"a": 1}, + "metrics_2": 2, + } + with mock.patch.object( + self.m_ops, "_get_current_metadata_dict", return_value=m_meta + ) as mock_get_current_metadata_dict: + loaded_meta = self.m_ops.load( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual( + loaded_meta, + metadata_ops.ModelVersionMetadataSchema( + metrics={"a": 1}, + ), + ) + mock_get_current_metadata_dict.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_save_1(self) -> None: + m_meta: Dict[str, Any] = {} + with mock.patch.object(self.m_ops, 
"_get_current_metadata_dict", return_value=m_meta), mock.patch.object( + self.m_ops._model_version_client, "set_metadata" + ) as mock_set_metadata: + self.m_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics={}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_set_metadata.assert_called_once_with( + {"snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, "metrics": {}}, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_save_2(self) -> None: + m_meta: Dict[str, Any] = {"metrics": 1} + with mock.patch.object(self.m_ops, "_get_current_metadata_dict", return_value=m_meta), mock.patch.object( + self.m_ops._model_version_client, "set_metadata" + ) as mock_set_metadata: + self.m_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics={}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_set_metadata.assert_called_once_with( + {"snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, "metrics": {}}, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_save_3(self) -> None: + m_meta: Dict[str, Any] = {"snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION} + with mock.patch.object(self.m_ops, "_get_current_metadata_dict", return_value=m_meta), mock.patch.object( + self.m_ops._model_version_client, "set_metadata" + ) as mock_set_metadata: + self.m_ops.save( + metadata_ops.ModelVersionMetadataSchema(), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) 
+ mock_set_metadata.assert_called_once_with( + {"snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION}, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_save_4(self) -> None: + m_meta: Dict[str, Any] = { + "snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, + } + with mock.patch.object(self.m_ops, "_get_current_metadata_dict", return_value=m_meta), mock.patch.object( + self.m_ops._model_version_client, "set_metadata" + ) as mock_set_metadata: + self.m_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics={}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_set_metadata.assert_called_once_with( + {"snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, "metrics": {}}, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_save_5(self) -> None: + m_meta: Dict[str, Any] = { + "snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, + "metrics_2": {}, + } + with mock.patch.object(self.m_ops, "_get_current_metadata_dict", return_value=m_meta), mock.patch.object( + self.m_ops._model_version_client, "set_metadata" + ) as mock_set_metadata: + self.m_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics={}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_set_metadata.assert_called_once_with( + { + "snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, + "metrics": {}, + "metrics_2": {}, + }, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), 
+ statement_params=self.m_statement_params, + ) + + def test_save_6(self) -> None: + m_meta: Dict[str, Any] = { + "snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, + "metrics": {"a": 1}, + } + with mock.patch.object(self.m_ops, "_get_current_metadata_dict", return_value=m_meta), mock.patch.object( + self.m_ops._model_version_client, "set_metadata" + ) as mock_set_metadata: + self.m_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics={}), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_set_metadata.assert_called_once_with( + {"snowpark_ml_schema_version": metadata_ops.MODEL_VERSION_METADATA_SCHEMA_VERSION, "metrics": {}}, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py new file mode 100644 index 00000000..40ce8914 --- /dev/null +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -0,0 +1,308 @@ +import pathlib +import tempfile +from typing import Any, Dict, List, Optional, Union, cast + +import yaml + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import model_signature, type_hints +from snowflake.ml.model._client.ops import metadata_ops +from snowflake.ml.model._client.sql import ( + model as model_sql, + model_version as model_version_sql, + stage as stage_sql, +) +from snowflake.ml.model._model_composer import model_composer +from snowflake.ml.model._model_composer.model_manifest import ( + model_manifest, + model_manifest_schema, +) +from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema +from snowflake.ml.model._signatures import snowpark_handler +from snowflake.snowpark import dataframe, session +from 
snowflake.snowpark._internal import utils as snowpark_utils + + +class ModelOperator: + def __init__( + self, + session: session.Session, + *, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> None: + # Ideally, we should only keep session object inside the client, however, some components other than client + # are requiring session object like ModelComposer and SnowparkDataFrameHandler. We currently cannot refactor + # them all but we should try to avoid using the _session object here unless there is no other choice. + self._session = session + self._stage_client = stage_sql.StageSQLClient( + session, + database_name=database_name, + schema_name=schema_name, + ) + self._model_client = model_sql.ModelSQLClient( + session, + database_name=database_name, + schema_name=schema_name, + ) + self._model_version_client = model_version_sql.ModelVersionSQLClient( + session, + database_name=database_name, + schema_name=schema_name, + ) + self._metadata_ops = metadata_ops.MetadataOperator( + session, + database_name=database_name, + schema_name=schema_name, + ) + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, ModelOperator): + return False + return ( + self._stage_client == __value._stage_client + and self._model_client == __value._model_client + and self._model_version_client == __value._model_version_client + ) + + def prepare_model_stage_path(self, *, statement_params: Optional[Dict[str, Any]] = None) -> str: + stage_name = sql_identifier.SqlIdentifier( + snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE) + ) + self._stage_client.create_tmp_stage(stage_name=stage_name, statement_params=statement_params) + return f"@{self._stage_client.fully_qualified_stage_name(stage_name)}/model" + + def create_from_stage( + self, + composed_model: model_composer.ModelComposer, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: 
Optional[Dict[str, Any]] = None, + ) -> None: + stage_path = str(composed_model.stage_path) + if self.validate_existence( + model_name=model_name, + statement_params=statement_params, + ): + if self.validate_existence( + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ): + raise ValueError( + f"Model {self._model_version_client.fully_qualified_model_name(model_name)} " + f"version {version_name} already existed." + ) + else: + self._model_version_client.add_version_from_stage( + stage_path=stage_path, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + else: + self._model_version_client.create_from_stage( + stage_path=stage_path, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + + def list_models_or_versions( + self, + *, + model_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[sql_identifier.SqlIdentifier]: + if model_name: + res = self._model_client.show_versions( + model_name=model_name, + statement_params=statement_params, + ) + else: + res = self._model_client.show_models( + statement_params=statement_params, + ) + return [sql_identifier.SqlIdentifier(row.name, case_sensitive=True) for row in res] + + def validate_existence( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> bool: + if version_name: + res = self._model_client.show_versions( + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + else: + res = self._model_client.show_models( + model_name=model_name, + statement_params=statement_params, + ) + return len(res) == 1 + + def get_comment( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: 
Optional[Dict[str, Any]] = None, + ) -> str: + if version_name: + res = self._model_client.show_versions( + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + else: + res = self._model_client.show_models( + model_name=model_name, + statement_params=statement_params, + ) + assert len(res) == 1 + return cast(str, res[0].comment) + + def set_comment( + self, + *, + comment: str, + model_name: sql_identifier.SqlIdentifier, + version_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + if version_name: + self._model_version_client.set_comment( + comment=comment, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + else: + self._model_client.set_comment( + comment=comment, + model_name=model_name, + statement_params=statement_params, + ) + + def get_model_version_manifest( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> model_manifest_schema.ModelManifestDict: + with tempfile.TemporaryDirectory() as tmpdir: + self._model_version_client.get_file( + model_name=model_name, + version_name=version_name, + file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH), + target_path=pathlib.Path(tmpdir), + statement_params=statement_params, + ) + mm = model_manifest.ModelManifest(pathlib.Path(tmpdir)) + return mm.load() + + def get_model_version_native_packing_meta( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> model_meta_schema.ModelMetadataDict: + with tempfile.TemporaryDirectory() as tmpdir: + model_meta_file_path = self._model_version_client.get_file( + model_name=model_name, + version_name=version_name, + file_path=pathlib.PurePosixPath( + 
model_composer.ModelComposer.MODEL_DIR_REL_PATH, model_meta.MODEL_METADATA_FILE + ), + target_path=pathlib.Path(tmpdir), + statement_params=statement_params, + ) + with open(model_meta_file_path, encoding="utf-8") as f: + raw_model_meta = yaml.safe_load(f) + return model_meta.ModelMetadata._validate_model_metadata(raw_model_meta) + + def invoke_method( + self, + *, + method_name: sql_identifier.SqlIdentifier, + signature: model_signature.ModelSignature, + X: Union[type_hints.SupportedDataType, dataframe.DataFrame], + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, str]] = None, + ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]: + identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED + + # Validate and prepare input + if not isinstance(X, dataframe.DataFrame): + keep_order = True + output_with_input_features = False + df = model_signature._convert_and_validate_local_data(X, signature.inputs) + s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, df, keep_order=keep_order) + else: + keep_order = False + output_with_input_features = True + identifier_rule = model_signature._validate_snowpark_data(X, signature.inputs) + s_df = X + + original_cols = s_df.columns + + # Compose input and output names + input_args = [] + for input_feature in signature.inputs: + col_name = identifier_rule.get_sql_identifier_from_feature(input_feature.name) + + input_args.append(col_name) + + returns = [] + for output_feature in signature.outputs: + output_name = identifier_rule.get_sql_identifier_from_feature(output_feature.name) + returns.append((output_feature.name, output_feature.as_snowpark_type(), output_name)) + # Avoid removing output cols when output_with_input_features is False + if output_name in original_cols: + original_cols.remove(output_name) + + df_res = self._model_version_client.invoke_method( + method_name=method_name, + input_df=s_df, + 
input_args=input_args, + returns=returns, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + + if keep_order: + df_res = df_res.sort( + "_ID", + ascending=True, + ) + + if not output_with_input_features: + df_res = df_res.drop(*original_cols) + + # Get final result + if not isinstance(X, dataframe.DataFrame): + return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs) + else: + return df_res + + def delete_model_or_version( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + # TODO: Delete version is not supported yet. + self._model_client.drop_model( + model_name=model_name, + statement_params=statement_params, + ) diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py new file mode 100644 index 00000000..317f39fd --- /dev/null +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -0,0 +1,604 @@ +import os +import pathlib +import tempfile +import textwrap +from typing import List, cast +from unittest import mock + +import numpy as np +import pandas as pd +import yaml +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import model_signature +from snowflake.ml.model._client.ops import model_ops +from snowflake.ml.model._signatures import snowpark_handler +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import DataFrame, Row, Session, types as spt +from snowflake.snowpark._internal import utils as snowpark_utils + +_DUMMY_SIG = { + "predict": model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input"), + ], + outputs=[model_signature.FeatureSpec(name="output", dtype=model_signature.DataType.FLOAT)], + ) +} + + 
+class ModelOpsTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.m_statement_params = {"test": "1"} + self.c_session = cast(Session, self.m_session) + self.m_ops = model_ops.ModelOperator( + self.c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ) + + def test_prepare_model_stage_path(self) -> None: + with mock.patch.object(self.m_ops._stage_client, "create_tmp_stage",) as mock_create_stage, mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_STAGE_ABCDEF0123" + ) as mock_random_name_for_temp_object: + stage_path = self.m_ops.prepare_model_stage_path( + statement_params=self.m_statement_params, + ) + self.assertEqual(stage_path, '@TEMP."test".SNOWPARK_TEMP_STAGE_ABCDEF0123/model') + mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.STAGE) + mock_create_stage.assert_called_once_with( + stage_name=sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + statement_params=self.m_statement_params, + ) + + def test_list_models_or_versions_1(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="MODEL", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ] + with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: + res = self.m_ops.list_models_or_versions( + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + [ + sql_identifier.SqlIdentifier("MODEL", case_sensitive=True), + sql_identifier.SqlIdentifier("Model", case_sensitive=True), + ], + ) + mock_show_models.assert_called_once_with( + 
statement_params=self.m_statement_params, + ) + + def test_list_models_or_versions_2(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + Row( + create_on="06/01", + name="V1", + comment="This is a comment", + model_name="MODEL", + is_default_version=False, + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.list_models_or_versions( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + [ + sql_identifier.SqlIdentifier("v1", case_sensitive=True), + sql_identifier.SqlIdentifier("V1", case_sensitive=True), + ], + ) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + + def test_validate_existence_1(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ] + with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: + res = self.m_ops.validate_existence( + model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertTrue(res) + mock_show_models.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_validate_existence_2(self) -> None: + m_list_res: List[Row] = [] + with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: + res = self.m_ops.validate_existence( + model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertFalse(res) 
+ mock_show_models.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_validate_existence_3(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.validate_existence( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertTrue(res) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_validate_existence_4(self) -> None: + m_list_res: List[Row] = [] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.validate_existence( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertFalse(res) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + + def test_get_model_version_manifest(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + m_manifest = { + "manifest_version": "1.0", + "runtimes": { + "python_runtime": { + "language": "PYTHON", + "version": "3.8", + "imports": ["model.zip", "runtimes/python_runtime/snowflake-ml-python.zip"], + "dependencies": {"conda": "runtimes/python_runtime/env/conda.yml"}, + } + }, + "methods": [ + { + 
"name": "predict", + "runtime": "python_runtime", + "type": "FUNCTION", + "handler": "functions.predict.infer", + "inputs": [{"name": "input", "type": "FLOAT"}], + "outputs": [{"type": "OBJECT"}], + }, + { + "name": "__CALL__", + "runtime": "python_runtime", + "type": "FUNCTION", + "handler": "functions.__call__.infer", + "inputs": [{"name": "INPUT", "type": "FLOAT"}], + "outputs": [{"type": "OBJECT"}], + }, + ], + } + m_manifest_path = os.path.join(tmpdir, "MANIFEST.yml") + with open(m_manifest_path, "w", encoding="utf-8") as f: + yaml.safe_dump(m_manifest, f) + with mock.patch.object(tempfile.TemporaryDirectory, "__enter__", return_value=tmpdir), mock.patch.object( + self.m_ops._model_version_client, "get_file", return_value=m_manifest_path + ) as mock_get_file: + manifest_res = self.m_ops.get_model_version_manifest( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + mock_get_file.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + file_path=pathlib.PurePosixPath("MANIFEST.yml"), + target_path=mock.ANY, + statement_params=self.m_statement_params, + ) + self.assertDictEqual(manifest_res, m_manifest) + + def test_get_model_version_native_packing_meta(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + m_meta_yaml = textwrap.dedent( + """ + creation_timestamp: '2023-11-20 18:14:06.357187' + env: + conda: env/conda.yml + cuda_version: null + pip: env/requirements.txt + python_version: '3.8' + snowpark_ml_version: 1.0.13+ca79e1b0720d35abd021c33707de789dc63918cc + metadata: null + min_snowpark_ml_version: 1.0.12 + model_type: sklearn + models: + SKLEARN_MODEL: + artifacts: {} + handler_version: '2023-12-01' + model_type: sklearn + name: SKLEARN_MODEL + options: {} + path: model.pkl + name: SKLEARN_MODEL + signatures: + predict: + inputs: 
+ - name: input_feature_0 + type: DOUBLE + outputs: + - name: output_feature_0 + type: BOOL + version: '2023-12-01' + """ + ) + m_meta_path = os.path.join(tmpdir, "model.yaml") + with open(m_meta_path, "w", encoding="utf-8") as f: + f.write(m_meta_yaml) + with mock.patch.object( + self.m_ops._model_version_client, "get_file", return_value=m_meta_path + ) as mock_get_file: + manifest_res = self.m_ops.get_model_version_native_packing_meta( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + mock_get_file.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + file_path=pathlib.PurePosixPath("model/model.yaml"), + target_path=mock.ANY, + statement_params=self.m_statement_params, + ) + self.assertDictEqual(manifest_res, yaml.safe_load(m_meta_yaml)) + + def test_create_from_stage_1(self) -> None: + mock_composer = mock.MagicMock() + mock_composer.stage_path = '@TEMP."test".MODEL/V1' + + with mock.patch.object( + self.m_ops._model_version_client, "create_from_stage", return_value='TEMP."test".MODEL' + ) as mock_create_from_stage, mock.patch.object( + self.m_ops._model_version_client, "add_version_from_stage", return_value='TEMP."test".MODEL' + ) as mock_add_version_from_stage, mock.patch.object( + self.m_ops._model_client, "show_models", return_value=[] + ): + self.m_ops.create_from_stage( + composed_model=mock_composer, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_create_from_stage.assert_called_once_with( + stage_path='@TEMP."test".MODEL/V1', + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_add_version_from_stage.assert_not_called() + + 
def test_create_from_stage_2(self) -> None: + mock_composer = mock.MagicMock() + mock_composer.stage_path = '@TEMP."test".MODEL/V1' + m_list_res = [ + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ] + with mock.patch.object( + self.m_ops._model_version_client, "create_from_stage", return_value='TEMP."test".MODEL' + ) as mock_create_from_stage, mock.patch.object( + self.m_ops._model_version_client, "add_version_from_stage", return_value='TEMP."test".MODEL' + ) as mock_add_version_from_stage, mock.patch.object( + self.m_ops._model_client, "show_models", return_value=m_list_res + ), mock.patch.object( + self.m_ops._model_client, attribute="show_versions", return_value=[] + ): + self.m_ops.create_from_stage( + composed_model=mock_composer, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_create_from_stage.assert_not_called() + mock_add_version_from_stage.assert_called_once_with( + stage_path='@TEMP."test".MODEL/V1', + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_create_from_stage_3(self) -> None: + mock_composer = mock.MagicMock() + mock_composer.stage_path = '@TEMP."test".MODEL/V1' + m_list_res_models = ( + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ) + m_list_res_versions = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + ] + with mock.patch.object( + self.m_ops._model_version_client, "create_from_stage", return_value='TEMP."test".MODEL' + ) as mock_create_from_stage, mock.patch.object( + self.m_ops._model_version_client, "add_version_from_stage", 
return_value='TEMP."test".MODEL' + ) as mock_add_version_from_stagel, mock.patch.object( + self.m_ops._model_client, "show_models", return_value=m_list_res_models + ), mock.patch.object( + self.m_ops._model_client, attribute="show_versions", return_value=m_list_res_versions + ): + with self.assertRaisesRegex(ValueError, 'Model TEMP."test".MODEL version V1 already existed.'): + self.m_ops.create_from_stage( + composed_model=mock_composer, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_create_from_stage.assert_not_called() + mock_add_version_from_stagel.assert_not_called() + + def test_invoke_method_1(self) -> None: + pd_df = pd.DataFrame([["1.0"]], columns=["input"], dtype=np.float32) + m_sig = _DUMMY_SIG["predict"] + m_df = mock_data_frame.MockDataFrame() + m_df.__setattr__("_statement_params", None) + m_df.__setattr__("columns", ["COL1", "COL2"]) + m_df.add_mock_sort("_ID", ascending=True).add_mock_drop("COL1", "COL2") + with mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_from_df", return_value=m_df + ) as mock_convert_from_df, mock.patch.object( + self.m_ops._model_version_client, "invoke_method", return_value=m_df + ) as mock_invoke_method, mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_to_df", return_value=pd_df + ) as mock_convert_to_df: + self.m_ops.invoke_method( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + signature=m_sig, + X=pd_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_from_df.assert_called_once_with(self.c_session, mock.ANY, keep_order=True) + mock_invoke_method.assert_called_once_with( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=m_df, + input_args=['"input"'], + returns=[("output", spt.FloatType(), '"output"')], + 
model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_to_df.assert_called_once_with(m_df, features=m_sig.outputs) + + def test_invoke_method_1_no_drop(self) -> None: + pd_df = pd.DataFrame([["1.0"]], columns=["input"], dtype=np.float32) + m_sig = _DUMMY_SIG["predict"] + m_df = mock_data_frame.MockDataFrame() + m_df.__setattr__("_statement_params", None) + m_df.__setattr__("columns", ["COL1", '"output"']) + m_df.add_mock_sort("_ID", ascending=True).add_mock_drop("COL1") + with mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_from_df", return_value=m_df + ) as mock_convert_from_df, mock.patch.object( + self.m_ops._model_version_client, "invoke_method", return_value=m_df + ) as mock_invoke_method, mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_to_df", return_value=pd_df + ) as mock_convert_to_df: + self.m_ops.invoke_method( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + signature=m_sig, + X=pd_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_from_df.assert_called_once_with(self.c_session, mock.ANY, keep_order=True) + mock_invoke_method.assert_called_once_with( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=m_df, + input_args=['"input"'], + returns=[("output", spt.FloatType(), '"output"')], + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_to_df.assert_called_once_with(m_df, features=m_sig.outputs) + + def test_invoke_method_2(self) -> None: + m_sig = _DUMMY_SIG["predict"] + m_df = mock_data_frame.MockDataFrame() + m_df.__setattr__("columns", ["COL1", "COL2"]) + with mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_from_df" + ) 
as mock_convert_from_df, mock.patch.object( + model_signature, "_validate_snowpark_data", return_value=model_signature.SnowparkIdentifierRule.NORMALIZED + ) as mock_validate_snowpark_data, mock.patch.object( + self.m_ops._model_version_client, "invoke_method", return_value=m_df + ) as mock_invoke_method, mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_to_df" + ) as mock_convert_to_df: + self.m_ops.invoke_method( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + signature=m_sig, + X=cast(DataFrame, m_df), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_from_df.assert_not_called() + mock_validate_snowpark_data.assert_called_once_with(m_df, m_sig.inputs) + + mock_invoke_method.assert_called_once_with( + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=m_df, + input_args=["INPUT"], + returns=[("output", spt.FloatType(), "OUTPUT")], + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_to_df.assert_not_called() + + def test_get_comment_1(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + ] + with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: + res = self.m_ops.get_comment( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + self.assertEqual(res, "This is a comment") + mock_show_models.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + + def test_get_comment_2(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="V1", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + 
schema_name="test", + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.get_comment( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + self.assertEqual(res, "This is a comment") + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_set_comment_1(self) -> None: + with mock.patch.object(self.m_ops._model_client, "set_comment") as mock_set_comment: + self.m_ops.set_comment( + comment="This is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + mock_set_comment.assert_called_once_with( + comment="This is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + + def test_set_comment_2(self) -> None: + with mock.patch.object(self.m_ops._model_version_client, "set_comment") as mock_set_comment: + self.m_ops.set_comment( + comment="This is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_set_comment.assert_called_once_with( + comment="This is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + + def test_delete_model_or_version(self) -> None: + with mock.patch.object( + self.m_ops._model_client, + "drop_model", + ) as mock_drop_model: + self.m_ops.delete_model_or_version( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + mock_drop_model.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + 
statement_params=self.m_statement_params, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_client/sql/BUILD.bazel b/snowflake/ml/model/_client/sql/BUILD.bazel new file mode 100644 index 00000000..465e2f24 --- /dev/null +++ b/snowflake/ml/model/_client/sql/BUILD.bazel @@ -0,0 +1,63 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = ["//snowflake/ml/model/_client/ops:__pkg__"]) + +py_library( + name = "model", + srcs = ["model.py"], + deps = [ + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:sql_identifier", + ], +) + +py_test( + name = "model_test", + srcs = ["model_test.py"], + deps = [ + ":model", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) + +py_library( + name = "model_version", + srcs = ["model_version.py"], + deps = [ + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:sql_identifier", + ], +) + +py_test( + name = "model_version_test", + srcs = ["model_version_test.py"], + deps = [ + ":model_version", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) + +py_library( + name = "stage", + srcs = ["stage.py"], + deps = [ + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:sql_identifier", + ], +) + +py_test( + name = "stage_test", + srcs = ["stage_test.py"], + deps = [ + ":stage", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/model/_client/sql/model.py b/snowflake/ml/model/_client/sql/model.py new file mode 100644 index 00000000..040b5dea --- /dev/null +++ b/snowflake/ml/model/_client/sql/model.py @@ -0,0 +1,75 @@ +from typing import Any, Dict, List, Optional + +from 
snowflake.ml._internal.utils import identifier, sql_identifier +from snowflake.snowpark import row, session + + +class ModelSQLClient: + def __init__( + self, + session: session.Session, + *, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> None: + self._session = session + self._database_name = database_name + self._schema_name = schema_name + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, ModelSQLClient): + return False + return self._database_name == __value._database_name and self._schema_name == __value._schema_name + + def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str: + return identifier.get_schema_level_object_identifier( + self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier() + ) + + def show_models( + self, + *, + model_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[row.Row]: + fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()]) + like_sql = "" + if model_name: + like_sql = f" LIKE '{model_name.resolved()}'" + res = self._session.sql(f"SHOW MODELS{like_sql} IN SCHEMA {fully_qualified_schema_name}") + + return res.collect(statement_params=statement_params) + + def show_versions( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[row.Row]: + like_sql = "" + if version_name: + like_sql = f" LIKE '{version_name.resolved()}'" + res = self._session.sql(f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}") + + return res.collect(statement_params=statement_params) + + def set_comment( + self, + *, + comment: str, + model_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + 
comment_sql = f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$" + self._session.sql(comment_sql).collect(statement_params=statement_params) + + def drop_model( + self, + *, + model_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._session.sql(f"DROP MODEL {self.fully_qualified_model_name(model_name)}").collect( + statement_params=statement_params + ) diff --git a/snowflake/ml/model/_client/sql/model_test.py b/snowflake/ml/model/_client/sql/model_test.py new file mode 100644 index 00000000..2d0c133a --- /dev/null +++ b/snowflake/ml/model/_client/sql/model_test.py @@ -0,0 +1,164 @@ +from typing import cast + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.sql import model as model_sql +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Row, Session + + +class ModelSQLTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + + def test_show_models_1(self) -> None: + m_statement_params = {"test": "1"} + m_df_final = mock_data_frame.MockDataFrame( + collect_result=[ + Row( + create_on="06/01", + name="MODEL", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ], + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql("""SHOW MODELS IN SCHEMA TEMP."test" """, m_df_final) + c_session = cast(Session, self.m_session) + model_sql.ModelSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).show_models( + statement_params=m_statement_params, + ) + + def 
test_show_models_2(self) -> None: + m_statement_params = {"test": "1"} + m_df_final = mock_data_frame.MockDataFrame( + collect_result=[ + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + ), + ], + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql("""SHOW MODELS LIKE 'Model' IN SCHEMA TEMP."test" """, m_df_final) + c_session = cast(Session, self.m_session) + model_sql.ModelSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).show_models( + model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + statement_params=m_statement_params, + ) + + def test_show_versions_1(self) -> None: + m_statement_params = {"test": "1"} + m_df_final = mock_data_frame.MockDataFrame( + collect_result=[ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + Row( + create_on="06/01", + name="V1", + comment="This is a comment", + model_name="MODEL", + is_default_version=False, + ), + ], + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql("""SHOW VERSIONS IN MODEL TEMP."test".MODEL""", m_df_final) + c_session = cast(Session, self.m_session) + model_sql.ModelSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).show_versions( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=m_statement_params, + ) + + def test_show_versions_2(self) -> None: + m_statement_params = {"test": "1"} + m_df_final = mock_data_frame.MockDataFrame( + collect_result=[ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + ], + collect_statement_params=m_statement_params, + ) + 
self.m_session.add_mock_sql("""SHOW VERSIONS LIKE 'v1' IN MODEL TEMP."test".MODEL""", m_df_final) + c_session = cast(Session, self.m_session) + model_sql.ModelSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).show_versions( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=m_statement_params, + ) + + def test_set_comment_for_model(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame(collect_result=[Row("")], collect_statement_params=m_statement_params) + comment = "This is my comment" + self.m_session.add_mock_sql(f"""COMMENT ON MODEL TEMP."test".MODEL IS $${comment}$$""", m_df) + c_session = cast(Session, self.m_session) + model_sql.ModelSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).set_comment( + model_name=sql_identifier.SqlIdentifier("MODEL"), comment=comment, statement_params=m_statement_params + ) + + def test_drop_model(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Model MODEL successfully dropped.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql("""DROP MODEL TEMP."test".MODEL""", m_df) + c_session = cast(Session, self.m_session) + model_sql.ModelSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).drop_model( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=m_statement_params, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_client/sql/model_version.py b/snowflake/ml/model/_client/sql/model_version.py new file mode 100644 index 00000000..7ffc7d22 --- 
/dev/null +++ b/snowflake/ml/model/_client/sql/model_version.py @@ -0,0 +1,213 @@ +import json +import pathlib +import textwrap +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import ParseResult + +from snowflake.ml._internal.utils import identifier, sql_identifier +from snowflake.snowpark import dataframe, functions as F, session, types as spt +from snowflake.snowpark._internal import utils as snowpark_utils + + +def _normalize_url_for_sql(url: str) -> str: + if url.startswith("'") and url.endswith("'"): + url = url[1:-1] + url = url.replace("'", "\\'") + return f"'{url}'" + + +class ModelVersionSQLClient: + def __init__( + self, + session: session.Session, + *, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> None: + self._session = session + self._database_name = database_name + self._schema_name = schema_name + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, ModelVersionSQLClient): + return False + return self._database_name == __value._database_name and self._schema_name == __value._schema_name + + def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str: + return identifier.get_schema_level_object_identifier( + self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier() + ) + + def create_from_stage( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + stage_path: str, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._version_name = version_name + self._session.sql( + f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}" + f" FROM {stage_path}" + ).collect(statement_params=statement_params) + + # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...` + def add_version_from_stage( + self, + *, + model_name: sql_identifier.SqlIdentifier, + 
version_name: sql_identifier.SqlIdentifier, + stage_path: str, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._version_name = version_name + self._session.sql( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}" + f" FROM {stage_path}" + ).collect(statement_params=statement_params) + + def set_default_version( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._session.sql( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} " + f"SET DEFAULT_VERSION = {version_name.identifier()}" + ).collect(statement_params=statement_params) + + def get_default_version( + self, + *, + model_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> str: + # TODO: Replace SHOW with DESC when available. + default_version: str = ( + self._session.sql(f"SHOW VERSIONS IN MODEL {self.fully_qualified_model_name(model_name)}") + .filter('"is_default_version" = TRUE')[['"name"']] + .collect(statement_params=statement_params)[0][0] + ) + return default_version + + def get_file( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + file_path: pathlib.PurePosixPath, + target_path: pathlib.Path, + statement_params: Optional[Dict[str, Any]] = None, + ) -> pathlib.Path: + stage_location = pathlib.PurePosixPath( + self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path + ).as_posix() + stage_location_url = ParseResult( + scheme="snow", netloc="model", path=stage_location, params="", query="", fragment="" + ).geturl() + local_location = target_path.absolute().as_posix() + local_location_url = ParseResult( + scheme="file", netloc="", path=local_location, params="", query="", fragment="" + ).geturl() + + self._session.sql( + f"GET 
{_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}" + ).collect(statement_params=statement_params) + return target_path / file_path.name + + def set_comment( + self, + *, + comment: str, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + comment_sql = ( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} " + f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$" + ) + self._session.sql(comment_sql).collect(statement_params=statement_params) + + def invoke_method( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + method_name: sql_identifier.SqlIdentifier, + input_df: dataframe.DataFrame, + input_args: List[sql_identifier.SqlIdentifier], + returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]], + statement_params: Optional[Dict[str, Any]] = None, + ) -> dataframe.DataFrame: + tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE) + INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier( + self._database_name.identifier(), + self._schema_name.identifier(), + tmp_table_name, + ) + input_df.write.save_as_table( # type: ignore[call-overload] + table_name=INTERMEDIATE_TABLE_NAME, + mode="errorifexists", + table_type="temporary", + statement_params=statement_params, + ) + + INTERMEDIATE_OBJ_NAME = "TMP_RESULT" + + module_version_alias = "MODEL_VERSION_ALIAS" + model_version_alias_sql = ( + f"WITH {module_version_alias} AS " + f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}" + ) + + args_sql_list = [] + for input_arg_value in input_args: + args_sql_list.append(input_arg_value) + + args_sql = ", ".join(args_sql_list) + + sql = textwrap.dedent( + f"""{model_version_alias_sql} + SELECT *, + 
{module_version_alias}!{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME} + FROM {INTERMEDIATE_TABLE_NAME}""" + ) + + output_df = self._session.sql(sql) + + # Prepare the output + output_cols = [] + output_names = [] + + for output_name, output_type, output_col_name in returns: + output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_name].astype(output_type)) + output_names.append(output_col_name) + + output_df = output_df.with_columns( + col_names=output_names, + values=output_cols, + ).drop(INTERMEDIATE_OBJ_NAME) + + if statement_params: + output_df._statement_params = statement_params # type: ignore[assignment] + + return output_df + + def set_metadata( + self, + metadata_dict: Dict[str, Any], + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + json_metadata = json.dumps(metadata_dict) + sql = ( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}" + f" SET METADATA=$${json_metadata}$$" + ) + self._session.sql(sql).collect(statement_params=statement_params) diff --git a/snowflake/ml/model/_client/sql/model_version_test.py b/snowflake/ml/model/_client/sql/model_version_test.py new file mode 100644 index 00000000..f732f714 --- /dev/null +++ b/snowflake/ml/model/_client/sql/model_version_test.py @@ -0,0 +1,156 @@ +import pathlib +from typing import cast +from unittest import mock + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.sql import model_version as model_version_sql +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import DataFrame, Row, Session, functions as F, types as spt +from snowflake.snowpark._internal import utils as snowpark_utils + + +class ModelVersionSQLTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = 
mock_session.MockSession(conn=None, test_case=self) + + def test_create_from_stage(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Model MODEL successfully created.")], collect_statement_params=m_statement_params + ) + stage_path = '@TEMP."test".MODEL/V1' + self.m_session.add_mock_sql(f"""CREATE MODEL TEMP."test".MODEL WITH VERSION V1 FROM {stage_path}""", m_df) + c_session = cast(Session, self.m_session) + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).create_from_stage( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + stage_path=stage_path, + statement_params=m_statement_params, + ) + + def test_add_version_from_stage(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Model MODEL successfully altered.")], collect_statement_params=m_statement_params + ) + stage_path = '@TEMP."test".MODEL/V2' + self.m_session.add_mock_sql(f"""ALTER MODEL TEMP."test".MODEL ADD VERSION V2 FROM {stage_path}""", m_df) + c_session = cast(Session, self.m_session) + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).add_version_from_stage( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + stage_path=stage_path, + statement_params=m_statement_params, + ) + + def test_set_comment(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame(collect_result=[Row("")], collect_statement_params=m_statement_params) + comment = "This is my comment" + self.m_session.add_mock_sql( + f"""ALTER MODEL TEMP."test".MODEL MODIFY VERSION "v1" SET COMMENT=$${comment}$$""", 
m_df + ) + c_session = cast(Session, self.m_session) + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).set_comment( + comment=comment, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=m_statement_params, + ) + + def test_get_file(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame(collect_result=[Row()], collect_statement_params=m_statement_params) + self.m_session.add_mock_sql( + """GET 'snow://model/TEMP."test".MODEL/versions/v1/model.yaml' 'file:///tmp'""", m_df + ) + c_session = cast(Session, self.m_session) + res = model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).get_file( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + file_path=pathlib.PurePosixPath("model.yaml"), + target_path=pathlib.Path("/tmp"), + statement_params=m_statement_params, + ) + self.assertEqual(res, pathlib.Path("/tmp/model.yaml")) + + def test_invoke_method(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame() + self.m_session.add_mock_sql( + """WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + SELECT *, + MODEL_VERSION_ALIAS!PREDICT(COL1, COL2) AS TMP_RESULT + FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123""", + m_df, + ) + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]).add_mock_drop("TMP_RESULT") + c_session = cast(Session, self.m_session) + mock_writer = mock.MagicMock() + m_df.__setattr__("write", mock_writer) + with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( + snowpark_utils, 
"random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" + ) as mock_random_name_for_temp_object: + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).invoke_method( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + method_name=sql_identifier.SqlIdentifier("PREDICT"), + input_df=cast(DataFrame, m_df), + input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], + returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], + statement_params=m_statement_params, + ) + mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) + mock_save_as_table.assert_called_once_with( + table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', + mode="errorifexists", + table_type="temporary", + statement_params=m_statement_params, + ) + + def test_set_metadata(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame(collect_result=[Row("")], collect_statement_params=m_statement_params) + metadata = {"metrics": {"a": 1, "c": "This is my comment"}, "other": 2.0} + self.m_session.add_mock_sql( + """ALTER MODEL TEMP."test".MODEL MODIFY VERSION "v1" + SET METADATA=$${"metrics": {"a": 1, "c": "This is my comment"}, "other": 2.0}$$""", + m_df, + ) + c_session = cast(Session, self.m_session) + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).set_metadata( + metadata, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=m_statement_params, + ) + + +if __name__ == "__main__": + absltest.main() diff --git 
a/snowflake/ml/model/_client/sql/stage.py b/snowflake/ml/model/_client/sql/stage.py new file mode 100644 index 00000000..8b9750a6 --- /dev/null +++ b/snowflake/ml/model/_client/sql/stage.py @@ -0,0 +1,40 @@ +from typing import Any, Dict, Optional + +from snowflake.ml._internal.utils import identifier, sql_identifier +from snowflake.snowpark import session + + +class StageSQLClient: + def __init__( + self, + session: session.Session, + *, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> None: + self._session = session + self._database_name = database_name + self._schema_name = schema_name + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, StageSQLClient): + return False + return self._database_name == __value._database_name and self._schema_name == __value._schema_name + + def fully_qualified_stage_name( + self, + stage_name: sql_identifier.SqlIdentifier, + ) -> str: + return identifier.get_schema_level_object_identifier( + self._database_name.identifier(), self._schema_name.identifier(), stage_name.identifier() + ) + + def create_tmp_stage( + self, + *, + stage_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._session.sql(f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}").collect( + statement_params=statement_params + ) diff --git a/snowflake/ml/model/_client/sql/stage_test.py b/snowflake/ml/model/_client/sql/stage_test.py new file mode 100644 index 00000000..8422fd62 --- /dev/null +++ b/snowflake/ml/model/_client/sql/stage_test.py @@ -0,0 +1,33 @@ +from typing import cast + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.sql import stage as stage_sql +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Row, Session + + +class StageSQLTest(absltest.TestCase): + def setUp(self) -> None: + 
self.m_session = mock_session.MockSession(conn=None, test_case=self) + + def test_create_tmp_stage(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Stage MODEL successfully created.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql("""CREATE TEMPORARY STAGE TEMP."test".MODEL""", m_df) + c_session = cast(Session, self.m_session) + stage_sql.StageSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).create_tmp_stage( + stage_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=m_statement_params, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py index a99a1e51..963a3c0f 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +++ b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py @@ -4,7 +4,6 @@ from string import Template import importlib_resources -import yaml from snowflake import snowpark from snowflake.ml._internal import file_utils @@ -180,7 +179,7 @@ def _construct_and_upload_job_spec(self, base_image: str, kaniko_shell_script_st assert self.artifact_stage_location.startswith("@") normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@")) (db, schema, stage, path) = identifier.parse_schema_level_object_identifier(normed_artifact_stage_path) - content = Template(spec_template).substitute( + content = Template(spec_template).safe_substitute( { "base_image": base_image, "container_name": constants.KANIKO_CONTAINER_NAME, @@ -188,10 +187,10 @@ def _construct_and_upload_job_spec(self, base_image: str, kaniko_shell_script_st # Remove @ in the beginning, append "/" to denote root directory. 
"script_path": "/" + posixpath.normpath(identifier.remove_prefix(kaniko_shell_script_stage_location, "@")), + "mounted_token_path": constants.SPCS_MOUNTED_TOKEN_PATH, } ) - content_dict = yaml.safe_load(content) - yaml.dump(content_dict, spec_file) + spec_file.write(content) spec_file.seek(0) logger.debug(f"Kaniko job spec file: \n\n {spec_file.read()}") diff --git a/snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template b/snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template index f3694df2..ff7c28b4 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +++ b/snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template @@ -1,22 +1,38 @@ spec: container: - - name: $container_name - image: $base_image + - name: "${container_name}" + image: "${base_image}" command: - sh args: - -c - - >- - while [ ! -f "$script_path" ]; do echo "File not found: $script_path"; sleep 1; done; - chmod +x $script_path; - sh $script_path; + - | + wait_for_file() { + file_path="$1" + timeout="$2" + elapsed_time=0 + while [ ! -f "${file_path}" ]; do + if [ "${elapsed_time}" -ge "${timeout}" ]; then + echo "Error: ${file_path} not found within ${timeout} seconds. Exiting." + exit 1 + fi + elapsed_time=$((elapsed_time + 1)) + remaining_time=$((timeout - elapsed_time)) + echo "Awaiting the mounting of ${file_path}. 
Wait time remaining: ${remaining_time} seconds" + sleep 1 + done + } + wait_for_file "${script_path}" 300 + wait_for_file "${mounted_token_path}" 300 + chmod +x "${script_path}" + sh "${script_path}" volumeMounts: - name: vol1 mountPath: /local/user/vol1 - name: stagemount - mountPath: /$stage + mountPath: "/${stage}" volume: - name: vol1 source: local # only local emptyDir volume is supported - name: stagemount - source: "@$stage" + source: "@${stage}" diff --git a/snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template b/snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template index 9d83cc24..70a4e5fe 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +++ b/snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template @@ -11,18 +11,41 @@ cleanup() { kill -- -$$$ # Kill the entire process group. Extra $ to escape, the generated shell script should have two $. } +# SNOW-990976, This is an additional safety check to ensure token file exists, on top of the token file check upon +# launching SPCS job. This additional check could provide value in cases things go wrong with token refresh that result +# in token file to disappear. +wait_till_token_file_exists() { + timeout=60 # 1 minute timeout + elapsed_time=0 + + while [ ! -f "${SESSION_TOKEN_PATH}" ] && [ "$elapsed_time" -lt "$timeout" ]; do + sleep 1 + elapsed_time=$((elapsed_time + 1)) + remaining_time=$((timeout - elapsed_time)) + echo "Waiting for token file to exist. Wait time remaining: ${remaining_time} seconds." + done + + if [ ! -f "${SESSION_TOKEN_PATH}" ]; then + echo "Error: Token file '${SESSION_TOKEN_PATH}' does not show up within the ${timeout} seconds timeout period." 
+ exit 1 + fi +} + generate_registry_cred() { + wait_till_token_file_exists AUTH_TOKEN=$(printf '0auth2accesstoken:%s' "$(cat ${SESSION_TOKEN_PATH})" | base64); echo '{"auths":{"$image_repo":{"auth":"'"$AUTH_TOKEN"'"}}}' | tr -d '\n' > $REGISTRY_CRED_PATH; } on_session_token_change() { + wait_till_token_file_exists # Get the initial checksum of the file CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') # Run the command once before the loop echo "Monitoring session token changes in the background..." ( while true; do + wait_till_token_file_exists # Get the current checksum of the file CURRENT_CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') if [ "${CURRENT_CHECKSUM}" != "${CHECKSUM}" ]; then diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/kaniko_shell_script_fixture.sh b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/kaniko_shell_script_fixture.sh index e212c39f..36a9e449 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/kaniko_shell_script_fixture.sh +++ b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/kaniko_shell_script_fixture.sh @@ -11,18 +11,41 @@ cleanup() { kill -- -$$ # Kill the entire process group. Extra $ to escape, the generated shell script should have two $. } +# SNOW-990976, This is an additional safety check to ensure token file exists, on top of the token file check upon +# launching SPCS job. This additional check could provide value in cases things go wrong with token refresh that result +# in token file to disappear. +wait_till_token_file_exists() { + timeout=60 # 1 minute timeout + elapsed_time=0 + + while [ ! -f "${SESSION_TOKEN_PATH}" ] && [ "$elapsed_time" -lt "$timeout" ]; do + sleep 1 + elapsed_time=$((elapsed_time + 1)) + remaining_time=$((timeout - elapsed_time)) + echo "Waiting for token file to exist. Wait time remaining: ${remaining_time} seconds." + done + + if [ ! 
-f "${SESSION_TOKEN_PATH}" ]; then + echo "Error: Token file '${SESSION_TOKEN_PATH}' does not show up within the ${timeout} seconds timeout period." + exit 1 + fi +} + generate_registry_cred() { + wait_till_token_file_exists AUTH_TOKEN=$(printf '0auth2accesstoken:%s' "$(cat ${SESSION_TOKEN_PATH})" | base64); echo '{"auths":{"mock_image_repo":{"auth":"'"$AUTH_TOKEN"'"}}}' | tr -d '\n' > $REGISTRY_CRED_PATH; } on_session_token_change() { + wait_till_token_file_exists # Get the initial checksum of the file CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') # Run the command once before the loop echo "Monitoring session token changes in the background..." ( while true; do + wait_till_token_file_exists # Get the current checksum of the file CURRENT_CHECKSUM=$(md5sum "${SESSION_TOKEN_PATH}" | awk '{ print $1 }') if [ "${CURRENT_CHECKSUM}" != "${CHECKSUM}" ]; then diff --git a/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel b/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel index 0084687e..2cc37a5c 100644 --- a/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel +++ b/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel @@ -21,8 +21,10 @@ py_library( deps = [ ":deploy_options", ":instance_types", + "//snowflake/ml/_internal:env_utils", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:spcs_attribution_utils", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_deploy_client/image_builds:base_image_builder", "//snowflake/ml/model/_deploy_client/image_builds:client_image_builder", diff --git a/snowflake/ml/model/_deploy_client/snowservice/deploy.py b/snowflake/ml/model/_deploy_client/snowservice/deploy.py index 413f479d..748ce594 100644 --- a/snowflake/ml/model/_deploy_client/snowservice/deploy.py +++ b/snowflake/ml/model/_deploy_client/snowservice/deploy.py @@ -10,14 +10,19 @@ import importlib_resources import yaml +from packaging import requirements from 
typing_extensions import Unpack -from snowflake.ml._internal import file_utils +from snowflake.ml._internal import env_utils, file_utils from snowflake.ml._internal.exceptions import ( error_codes, exceptions as snowml_exceptions, ) -from snowflake.ml._internal.utils import identifier, query_result_checker +from snowflake.ml._internal.utils import ( + identifier, + query_result_checker, + spcs_attribution_utils, +) from snowflake.ml.model import type_hints from snowflake.ml.model._deploy_client import snowservice from snowflake.ml.model._deploy_client.image_builds import ( @@ -161,6 +166,11 @@ def _deploy( # Set conda-forge as backup channel for SPCS deployment if "conda-forge" not in model_meta_deploy.env._conda_dependencies: model_meta_deploy.env._conda_dependencies["conda-forge"] = [] + # Snowflake connector needs pyarrow to work correctly. + env_utils.append_conda_dependency( + model_meta_deploy.env._conda_dependencies, + (env_utils.DEFAULT_CHANNEL_NAME, requirements.Requirement("pyarrow")), + ) if options.use_gpu: # Make mypy happy assert options.num_gpus is not None @@ -585,6 +595,8 @@ def _deploy_workflow(self, image: str) -> str: ) logger.info(f"Service {self._service_name} is ready. 
Creating service function...") + spcs_attribution_utils.record_service_start(self.session, self._service_name) + service_function_sql = client.create_or_replace_service_function( service_func_name=self.service_func_name, service_name=self._service_name, diff --git a/snowflake/ml/model/_deploy_client/utils/constants.py b/snowflake/ml/model/_deploy_client/utils/constants.py index 4762df71..edc441df 100644 --- a/snowflake/ml/model/_deploy_client/utils/constants.py +++ b/snowflake/ml/model/_deploy_client/utils/constants.py @@ -50,3 +50,4 @@ class ResourceStatus(Enum): KANIKO_CONTAINER_NAME = "kaniko" LATEST_IMAGE_TAG = "latest" KANIKO_IMAGE = "kaniko-project/executor:v1.16.0-debug" +SPCS_MOUNTED_TOKEN_PATH = "/snowflake/session/token" diff --git a/snowflake/ml/model/_deploy_client/warehouse/deploy.py b/snowflake/ml/model/_deploy_client/warehouse/deploy.py index eb62cc3c..0f20fc4a 100644 --- a/snowflake/ml/model/_deploy_client/warehouse/deploy.py +++ b/snowflake/ml/model/_deploy_client/warehouse/deploy.py @@ -173,7 +173,7 @@ def _get_model_final_packages( else: required_packages = meta.env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME] - final_packages = env_utils.validate_requirements_in_snowflake_conda_channel( + final_packages = env_utils.validate_requirements_in_information_schema( session, required_packages, python_version=meta.env.python_version ) @@ -182,7 +182,7 @@ def _get_model_final_packages( raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.DEPENDENCY_VERSION_ERROR, original_exception=RuntimeError( - "The model's dependencyies are not available in Snowflake Anaconda Channel. " + "The model's dependencies are not available in Snowflake Anaconda Channel. 
" + relax_version_info_str + "Required packages are:\n" + " ".join(map(lambda x: f'"{x}"', required_packages)) diff --git a/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py b/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py index d0a75fd1..40250df4 100644 --- a/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py +++ b/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py @@ -91,7 +91,7 @@ def add_packages(self, packages_dicts: Dict[str, List[str]]) -> None: def test_get_model_final_packages(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} with model_meta.create_model_metadata( model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG ) as meta: @@ -103,7 +103,7 @@ def test_get_model_final_packages(self) -> None: def test_get_model_final_packages_no_relax(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} with model_meta.create_model_metadata( model_dir_path=tmpdir, name="model1", @@ -118,7 +118,7 @@ def test_get_model_final_packages_no_relax(self) -> None: def test_get_model_final_packages_relax(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} with model_meta.create_model_metadata( model_dir_path=tmpdir, name="model1", @@ -136,7 +136,7 @@ def test_get_model_final_packages_relax(self) -> None: def test_get_model_final_packages_with_pip(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} with model_meta.create_model_metadata( model_dir_path=tmpdir, name="model1", @@ -151,7 +151,7 @@ def test_get_model_final_packages_with_pip(self) -> None: def 
test_get_model_final_packages_with_other_channel(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} with model_meta.create_model_metadata( model_dir_path=tmpdir, name="model1", @@ -166,7 +166,7 @@ def test_get_model_final_packages_with_other_channel(self) -> None: def test_get_model_final_packages_with_non_exist_package(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: - env_utils._SNOWFLAKE_CONDA_PACKAGE_CACHE = {} + env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} d = { **{ basic_dep.name: [importlib_metadata.version(basic_dep.name)] diff --git a/snowflake/ml/model/_model_composer/model_composer.py b/snowflake/ml/model/_model_composer/model_composer.py index c842410f..c04a1a40 100644 --- a/snowflake/ml/model/_model_composer/model_composer.py +++ b/snowflake/ml/model/_model_composer/model_composer.py @@ -3,7 +3,7 @@ import tempfile import zipfile from types import ModuleType -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from absl import logging from packaging import requirements @@ -32,8 +32,15 @@ class ModelComposer: """ MODEL_FILE_REL_PATH = "model.zip" + MODEL_DIR_REL_PATH = "model" - def __init__(self, session: Session, stage_path: str) -> None: + def __init__( + self, + session: Session, + stage_path: str, + *, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: self.session = session self.stage_path = pathlib.PurePosixPath(stage_path) @@ -43,6 +50,8 @@ def __init__(self, session: Session, stage_path: str) -> None: self.packager = model_packager.ModelPackager(local_dir_path=str(self._packager_workspace_path)) self.manifest = model_manifest.ModelManifest(workspace_path=self.workspace_path) + self._statement_params = statement_params + def __del__(self) -> None: self._workspace.cleanup() self._packager_workspace.cleanup() @@ -82,13 +91,11 @@ def save( options = 
model_types.BaseModelSaveOption() if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] - snowml_server_availability = env_utils.validate_requirements_in_snowflake_conda_channel( - session=self.session, - reqs=[requirements.Requirement(f"snowflake-ml-python=={snowml_env.VERSION}")], - python_version=snowml_env.PYTHON_VERSION, + snowml_matched_versions = env_utils.get_matched_package_versions_in_snowflake_conda_channel( + req=requirements.Requirement(f"snowflake-ml-python=={snowml_env.VERSION}") ) - if snowml_server_availability is None and options.get("embed_local_ml_library", False) is False: + if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False: logging.info( f"Local snowflake-ml-python library has version {snowml_env.VERSION}," " which is not available in the Snowflake server, embedding local ML library automatically." @@ -111,6 +118,13 @@ def save( assert self.packager.meta is not None + if not options.get("_legacy_save", False): + # Keep both loose files and zipped file. + # TODO(SNOW-726678): Remove once import a directory is possible. 
+ file_utils.copytree( + str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH) + ) + file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path)) self.manifest.save( @@ -120,7 +134,12 @@ def save( options=options, ) - file_utils.upload_directory_to_stage(self.session, local_path=self.workspace_path, stage_path=self.stage_path) + file_utils.upload_directory_to_stage( + self.session, + local_path=self.workspace_path, + stage_path=self.stage_path, + statement_params=self._statement_params, + ) def load( self, @@ -129,7 +148,10 @@ def load( options: Optional[model_types.ModelLoadOption] = None, ) -> None: file_utils.download_directory_from_stage( - self.session, stage_path=self.stage_path, local_path=self.workspace_path + self.session, + stage_path=self.stage_path, + local_path=self.workspace_path, + statement_params=self._statement_params, ) # TODO (Server-side Model Rollout): Remove this section. diff --git a/snowflake/ml/model/_model_composer/model_composer_test.py b/snowflake/ml/model/_model_composer/model_composer_test.py index 5cb3cdb5..6a46977a 100644 --- a/snowflake/ml/model/_model_composer/model_composer_test.py +++ b/snowflake/ml/model/_model_composer/model_composer_test.py @@ -1,3 +1,5 @@ +import os +import pathlib from typing import cast from unittest import mock @@ -6,13 +8,13 @@ from absl.testing import absltest from sklearn import linear_model -from snowflake.ml._internal import env_utils +from snowflake.ml._internal import env_utils, file_utils from snowflake.ml.model._model_composer import model_composer from snowflake.ml.modeling.linear_model import ( # type:ignore[attr-defined] LinearRegression, ) from snowflake.ml.test_utils import mock_session -from snowflake.snowpark import FileOperation, Session +from snowflake.snowpark import Session class ModelInterfaceTest(absltest.TestCase): @@ -28,34 +30,46 @@ def test_save_interface(self) -> None: mock_pk.meta = mock.MagicMock() 
mock_pk.meta.signatures = mock.MagicMock() m = model_composer.ModelComposer(session=c_session, stage_path=stage_path) + + with open(os.path.join(m._packager_workspace_path, "model.yaml"), "w", encoding="utf-8") as f: + f.write("") m.packager = mock_pk with mock.patch.object(m.packager, "save") as mock_save: with mock.patch.object(m.manifest, "save") as mock_manifest_save: - with mock.patch.object(FileOperation, "put", return_value=None) as mock_put_stream: - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ) as mock_upload_directory_to_stage: + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): m.save( name="model1", model=LinearRegression(), ) mock_save.assert_called_once() mock_manifest_save.assert_called_once() + mock_upload_directory_to_stage.assert_called_once_with( + c_session, local_path=mock.ANY, stage_path=pathlib.PurePosixPath(stage_path), statement_params=None + ) m = model_composer.ModelComposer(session=c_session, stage_path=stage_path) m.packager = mock_pk + with open(os.path.join(m._packager_workspace_path, "model.yaml"), "w", encoding="utf-8") as f: + f.write("") with mock.patch.object(m.packager, "save") as mock_save: with mock.patch.object(m.manifest, "save") as mock_manifest_save: - with mock.patch.object(FileOperation, "put", return_value=None) as mock_put_stream: - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ) as mock_upload_directory_to_stage: + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): m.save( name="model1", model=linear_model.LinearRegression(), sample_input=d, ) - mock_put_stream.assert_called_once_with(mock.ANY, stage_path, 
auto_compress=False, overwrite=False) + mock_save.assert_called_once() + mock_manifest_save.assert_called_once() + mock_upload_directory_to_stage.assert_called_once_with( + c_session, local_path=mock.ANY, stage_path=pathlib.PurePosixPath(stage_path), statement_params=None + ) if __name__ == "__main__": diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py index 98b24565..3d4d4f70 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py @@ -1,5 +1,6 @@ +import collections import pathlib -from typing import List, Optional +from typing import List, Optional, cast import yaml @@ -48,7 +49,6 @@ def save( ] self.function_generator = function_generator.FunctionGenerator(model_file_rel_path=model_file_rel_path) self.methods: List[model_method.ModelMethod] = [] - _seen_method_names: List[str] = [] for target_method in model_meta.signatures.keys(): method = model_method.ModelMethod( model_meta=model_meta, @@ -57,17 +57,18 @@ def save( function_generator=self.function_generator, options=model_method.get_model_method_options_from_options(options, target_method), ) - if method.method_name in _seen_method_names: - raise ValueError( - f"Found duplicate method named resolved as {method.method_name} in the model. " - "This might because you have methods with same letters but different cases. " - "In this case, set case_sensitive as True for those methods to distinguish them" - ) - else: - _seen_method_names.append(method.method_name) self.methods.append(method) + method_name_counter = collections.Counter([method.method_name for method in self.methods]) + dup_method_names = [k for k, v in method_name_counter.items() if v > 1] + if dup_method_names: + raise ValueError( + f"Found duplicate method named resolved as {', '.join(dup_method_names)} in the model. 
" + "This might because you have methods with same letters but different cases. " + "In this case, set case_sensitive as True for those methods to distinguish them." + ) + manifest_dict = model_manifest_schema.ModelManifestDict( manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION, runtimes={runtime.name: runtime.save(self.workspace_path) for runtime in self.runtimes}, @@ -84,3 +85,17 @@ def save( with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f: yaml.safe_dump(manifest_dict, f) + + def load(self) -> model_manifest_schema.ModelManifestDict: + with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("r", encoding="utf-8") as f: + raw_input = yaml.safe_load(f) + if not isinstance(raw_input, dict): + raise ValueError(f"Read ill-formatted model MANIFEST, should be a dict, received {type(raw_input)}") + + original_loaded_manifest_version = raw_input.get("manifest_version", None) + if not original_loaded_manifest_version: + raise ValueError("Unable to get the version of the MANIFEST file.") + + res = cast(model_manifest_schema.ModelManifestDict, raw_input) + + return res diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py index efc5126a..2df33b9b 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py @@ -1,6 +1,6 @@ # This files contains schema definition of what will be written into MANIFEST.yml -from typing import Dict, List, Literal, TypedDict +from typing import Any, Dict, List, Literal, TypedDict from typing_extensions import NotRequired, Required @@ -42,4 +42,4 @@ class ModelManifestDict(TypedDict): manifest_version: Required[str] runtimes: Required[Dict[str, ModelRuntimeDict]] methods: Required[List[ModelMethodDict]] - user_data: NotRequired[Dict[str, str]] + user_data: 
NotRequired[Dict[str, Any]] diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py index cf1109c6..50bfacd1 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py @@ -38,9 +38,7 @@ def test_model_manifest_1(self) -> None: model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mm.save(self.m_session, meta, pathlib.PurePosixPath("model.zip")) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: loaded_manifest = yaml.safe_load(f) @@ -62,7 +60,7 @@ def test_model_manifest_1(self) -> None: "runtime": "python_runtime", "type": "FUNCTION", "handler": "functions.predict.infer", - "inputs": [{"name": "tmp_input", "type": "OBJECT"}], + "inputs": [{"name": "INPUT", "type": "FLOAT"}], "outputs": [{"type": "OBJECT"}], } ], @@ -89,9 +87,7 @@ def test_model_manifest_2(self) -> None: signatures={"__call__": _DUMMY_SIG["predict"]}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mm.save( self.m_session, meta, @@ -120,7 +116,7 @@ def test_model_manifest_2(self) -> None: "runtime": "python_runtime", "type": "FUNCTION", "handler": "functions.__call__.infer", - "inputs": [{"name": "tmp_input", "type": "OBJECT"}], + "inputs": [{"name": "INPUT", "type": "FLOAT"}], "outputs": [{"type": "OBJECT"}], } ], @@ -147,9 +143,7 @@ def test_model_manifest_mix(self) -> 
None: signatures={"predict": _DUMMY_SIG["predict"], "__call__": _DUMMY_SIG["predict"]}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=None - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=None): mm.save( self.m_session, meta, @@ -177,11 +171,11 @@ def test_model_manifest_mix(self) -> None: }, "methods": [ { - "name": '"predict"', + "name": "predict", "runtime": "python_runtime", "type": "FUNCTION", "handler": "functions.predict.infer", - "inputs": [{"name": "tmp_input", "type": "OBJECT"}], + "inputs": [{"name": "input", "type": "FLOAT"}], "outputs": [{"type": "OBJECT"}], }, { @@ -189,7 +183,7 @@ def test_model_manifest_mix(self) -> None: "runtime": "python_runtime", "type": "FUNCTION", "handler": "functions.__call__.infer", - "inputs": [{"name": "tmp_input", "type": "OBJECT"}], + "inputs": [{"name": "INPUT", "type": "FLOAT"}], "outputs": [{"type": "OBJECT"}], }, ], @@ -226,9 +220,7 @@ def test_model_manifest_bad(self) -> None: signatures={"predict": _DUMMY_SIG["predict"], "PREDICT": _DUMMY_SIG["predict"]}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): with self.assertRaisesRegex( ValueError, "Found duplicate method named resolved as PREDICT in the model." 
): @@ -238,6 +230,73 @@ def test_model_manifest_bad(self) -> None: pathlib.PurePosixPath("model.zip"), ) + def test_load(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "MANIFEST.yml"), "w", encoding="utf-8") as f: + yaml.safe_dump({}, f) + + mm = model_manifest.ModelManifest(pathlib.Path(tmpdir)) + + with self.assertRaisesRegex(ValueError, "Unable to get the version of the MANIFEST file."): + mm.load() + + raw_input = { + "manifest_version": "1.0", + "runtimes": { + "python_runtime": { + "language": "PYTHON", + "version": "3.8", + "imports": ["model.zip", "runtimes/python_runtime/snowflake-ml-python.zip"], + "dependencies": {"conda": "runtimes/python_runtime/env/conda.yml"}, + } + }, + "methods": [ + { + "name": "predict", + "runtime": "python_runtime", + "type": "FUNCTION", + "handler": "functions.predict.infer", + "inputs": [{"name": "input", "type": "FLOAT"}], + "outputs": [{"type": "OBJECT"}], + }, + { + "name": "__CALL__", + "runtime": "python_runtime", + "type": "FUNCTION", + "handler": "functions.__call__.infer", + "inputs": [{"name": "INPUT", "type": "FLOAT"}], + "outputs": [{"type": "OBJECT"}], + }, + ], + } + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "MANIFEST.yml"), "w", encoding="utf-8") as f: + yaml.safe_dump(raw_input, f) + + mm = model_manifest.ModelManifest(pathlib.Path(tmpdir)) + + self.assertDictEqual(raw_input, mm.load()) + + raw_input["user_data"] = {} + + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "MANIFEST.yml"), "w", encoding="utf-8") as f: + yaml.safe_dump(raw_input, f) + + mm = model_manifest.ModelManifest(pathlib.Path(tmpdir)) + + self.assertDictEqual(raw_input, mm.load()) + + raw_input["user_data"] = {"description": "Hello"} + + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "MANIFEST.yml"), "w", encoding="utf-8") as f: + yaml.safe_dump(raw_input, f) + + mm = 
model_manifest.ModelManifest(pathlib.Path(tmpdir)) + + self.assertDictEqual(raw_input, mm.load()) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel index 23d180f5..27924817 100644 --- a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel @@ -38,6 +38,7 @@ py_library( deps = [ ":function_generator", "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", "//snowflake/ml/model/_packager/model_meta", diff --git a/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_1.py_fixture b/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_1.py_fixture index d51dc6c8..c14c9997 100644 --- a/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_1.py_fixture +++ b/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_1.py_fixture @@ -73,6 +73,7 @@ dtype_map = {feature.name: feature.as_dtype() for feature in features} # Actual function @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE) def infer(df: pd.DataFrame) -> dict: - input_df = pd.json_normalize(df[0]).astype(dtype=dtype_map) + df.columns = input_cols + input_df = df.astype(dtype=dtype_map) predictions_df = runner(input_df[input_cols]) return predictions_df.to_dict("records") diff --git a/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_2.py_fixture b/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_2.py_fixture index fa356184..c1bd1d0e 100644 --- a/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_2.py_fixture +++ b/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_2.py_fixture @@ -73,6 
+73,7 @@ dtype_map = {feature.name: feature.as_dtype() for feature in features} # Actual function @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE) def infer(df: pd.DataFrame) -> dict: - input_df = pd.json_normalize(df[0]).astype(dtype=dtype_map) + df.columns = input_cols + input_df = df.astype(dtype=dtype_map) predictions_df = runner(input_df[input_cols]) return predictions_df.to_dict("records") diff --git a/snowflake/ml/model/_model_composer/model_method/infer_function.py_template b/snowflake/ml/model/_model_composer/model_method/infer_function.py_template index bd041620..8556d934 100644 --- a/snowflake/ml/model/_model_composer/model_method/infer_function.py_template +++ b/snowflake/ml/model/_model_composer/model_method/infer_function.py_template @@ -73,6 +73,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}} # Actual function @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE) def {function_name}(df: pd.DataFrame) -> dict: - input_df = pd.json_normalize(df[0]).astype(dtype=dtype_map) + df.columns = input_cols + input_df = df.astype(dtype=dtype_map) predictions_df = runner(input_df[input_cols]) return predictions_df.to_dict("records") diff --git a/snowflake/ml/model/_model_composer/model_method/model_method.py b/snowflake/ml/model/_model_composer/model_method/model_method.py index 94026777..6d5a9b16 100644 --- a/snowflake/ml/model/_model_composer/model_method/model_method.py +++ b/snowflake/ml/model/_model_composer/model_method/model_method.py @@ -1,13 +1,15 @@ +import collections import pathlib from typing import Optional, TypedDict from typing_extensions import NotRequired from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import type_hints +from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._model_composer.model_method import function_generator from 
snowflake.ml.model._packager.model_meta import model_meta as model_meta_api +from snowflake.snowpark._internal import type_utils class ModelMethodOptions(TypedDict): @@ -69,6 +71,22 @@ def __init__( if self.target_method not in self.model_meta.signatures.keys(): raise ValueError(f"Target method {self.target_method} is not available in the signatures of the model.") + @staticmethod + def _get_method_arg_from_feature( + feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False + ) -> model_manifest_schema.ModelMethodSignatureFieldWithName: + assert isinstance(feature, model_signature.FeatureSpec), "FeatureGroupSpec is not supported." + try: + feature_name = sql_identifier.SqlIdentifier(feature.name, case_sensitive=case_sensitive) + except ValueError as e: + raise ValueError( + f"Your feature {feature.name} cannot be resolved as valid SQL identifier. " + "Try specify `case_sensitive` as True." + ) from e + return model_manifest_schema.ModelMethodSignatureFieldWithName( + name=feature_name.resolved(), type=type_utils.convert_sp_to_sf_type(feature.as_snowpark_type()) + ) + def save( self, workspace_path: pathlib.Path, options: Optional[function_generator.FunctionGenerateOptions] = None ) -> model_manifest_schema.ModelMethodDict: @@ -78,13 +96,26 @@ def save( self.target_method, options=options, ) + input_list = [ + ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False)) + for ft in self.model_meta.signatures[self.target_method].inputs + ] + input_name_counter = collections.Counter([input_info["name"] for input_info in input_list]) + dup_input_names = [k for k, v in input_name_counter.items() if v > 1] + if dup_input_names: + raise ValueError( + f"Found duplicate input feature named resolved as {', '.join(dup_input_names)} in the method" + f" {self.target_method} This might because you have methods with same letters but different cases. 
" + "In this case, set case_sensitive as True for those methods to distinguish them." + ) + return model_manifest_schema.ModelFunctionMethodDict( - name=self.method_name.identifier(), + name=self.method_name.resolved(), runtime=self.runtime_name, type="FUNCTION", handler=".".join( [ModelMethod.FUNCTIONS_DIR_REL_PATH, self.target_method, self.function_generator.FUNCTION_NAME] ), - inputs=[model_manifest_schema.ModelMethodSignatureFieldWithName(name="tmp_input", type="OBJECT")], + inputs=input_list, outputs=[model_manifest_schema.ModelMethodSignatureField(type="OBJECT")], ) diff --git a/snowflake/ml/model/_model_composer/model_method/model_method_test.py b/snowflake/ml/model/_model_composer/model_method/model_method_test.py index 6ce76d98..0594641d 100644 --- a/snowflake/ml/model/_model_composer/model_method/model_method_test.py +++ b/snowflake/ml/model/_model_composer/model_method/model_method_test.py @@ -16,6 +16,7 @@ "predict": model_signature.ModelSignature( inputs=[ model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input"), + model_signature.FeatureSpec(dtype=model_signature.DataType.STRING, name="name"), ], outputs=[model_signature.FeatureSpec(name="output", dtype=model_signature.DataType.FLOAT)], ) @@ -59,7 +60,7 @@ def test_model_method(self) -> None: "runtime": "python_runtime", "type": "FUNCTION", "handler": "functions.predict.infer", - "inputs": [{"name": "tmp_input", "type": "OBJECT"}], + "inputs": [{"name": "INPUT", "type": "FLOAT"}, {"name": "NAME", "type": "STRING"}], "outputs": [{"type": "OBJECT"}], }, ) @@ -98,7 +99,7 @@ def test_model_method(self) -> None: "runtime": "python_runtime", "type": "FUNCTION", "handler": "functions.__call__.infer", - "inputs": [{"name": "tmp_input", "type": "OBJECT"}], + "inputs": [{"name": "INPUT", "type": "FLOAT"}, {"name": "NAME", "type": "STRING"}], "outputs": [{"type": "OBJECT"}], }, ) @@ -159,11 +160,11 @@ def test_model_method(self) -> None: self.assertDictEqual( method_dict, { - "name": 
'"predict"', + "name": "predict", "runtime": "python_runtime", "type": "FUNCTION", "handler": "functions.predict.infer", - "inputs": [{"name": "tmp_input", "type": "OBJECT"}], + "inputs": [{"name": "input", "type": "FLOAT"}, {"name": "name", "type": "STRING"}], "outputs": [{"type": "OBJECT"}], }, ) diff --git a/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py b/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py index 1e07ec68..aa90d3f3 100644 --- a/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +++ b/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py @@ -44,7 +44,7 @@ def __init__( if self.runtime_env._snowpark_ml_version.local: self.embed_local_ml_library = True else: - snowml_server_availability = env_utils.validate_requirements_in_snowflake_conda_channel( + snowml_server_availability = env_utils.validate_requirements_in_information_schema( session=session, reqs=[requirements.Requirement(snowml_pkg_spec)], python_version=snowml_env.PYTHON_VERSION, diff --git a/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py b/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py index 5ede2068..1a10e220 100644 --- a/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py +++ b/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py @@ -56,9 +56,7 @@ def test_model_runtime(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -85,9 +83,7 @@ def test_model_runtime_local_snowml(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object( - env_utils, 
"validate_requirements_in_snowflake_conda_channel", return_value=None - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=None): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -122,9 +118,7 @@ def test_model_runtime_dup_basic_dep(self) -> None: dep_target.append("pandas") dep_target.sort() - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -150,9 +144,7 @@ def test_model_runtime_dup_basic_dep_other_channel(self) -> None: dep_target.append("conda-forge::pandas") dep_target.sort() - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -177,7 +169,7 @@ def test_model_runtime_dup_basic_dep_pip(self) -> None: dep_target.remove(f"pandas=={importlib_metadata.version('pandas')}") dep_target.sort() - with mock.patch.object(env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""]): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -202,9 +194,7 @@ def test_model_runtime_additional_conda_dep(self) -> None: dep_target.append("pytorch") dep_target.sort() - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, 
"validate_requirements_in_information_schema", return_value=[""]): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -228,9 +218,7 @@ def test_model_runtime_additional_pip_dep(self) -> None: dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] dep_target.sort() - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -256,9 +244,7 @@ def test_model_runtime_additional_dep_both(self) -> None: dep_target.append("pytorch") dep_target.sort() - with mock.patch.object( - env_utils, "validate_requirements_in_snowflake_conda_channel", return_value=[""] - ): + with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) diff --git a/snowflake/ml/model/_packager/model_handlers/BUILD.bazel b/snowflake/ml/model/_packager/model_handlers/BUILD.bazel index 585d7cba..eb19bd50 100644 --- a/snowflake/ml/model/_packager/model_handlers/BUILD.bazel +++ b/snowflake/ml/model/_packager/model_handlers/BUILD.bazel @@ -64,7 +64,6 @@ py_library( srcs = ["snowmlmodel.py"], deps = [ ":_base", - ":_utils", "//snowflake/ml/_internal:type_utils", "//snowflake/ml/model:custom_model", "//snowflake/ml/model:model_signature", diff --git a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py index 7dcabd30..47459cfc 100644 --- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py @@ -59,7 +59,7 @@ def get_requirements_from_task(task: str, 
spcs_only: bool = False) -> List[model return ( [model_env.ModelDependency(requirement="tokenizers>=0.13.3", pip_name="tokenizers")] if spcs_only - else [model_env.ModelDependency(requirement="tokenizers", pip_name="tokenizers")] + else [model_env.ModelDependency(requirement="tokenizers<=0.13.2", pip_name="tokenizers")] ) return [] @@ -170,6 +170,7 @@ def save_model( " `snowflake.ml.model.models.huggingface_pipeline.HuggingFacePipelineModel` object. " "Please make sure you are providing correct model signatures.", UserWarning, + stacklevel=2, ) else: handlers_utils.validate_target_methods(model, target_methods) @@ -179,6 +180,7 @@ def save_model( + "Model signature will automatically be inferred from pipeline task. " + "Or, you could specify model signature manually.", UserWarning, + stacklevel=2, ) if inferred_pipe_sig is None: raise NotImplementedError(f"Cannot auto infer the signature of pipeline for task {task}") diff --git a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py index cbafe76f..567d2f46 100644 --- a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +++ b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py @@ -1,4 +1,5 @@ import os +import warnings from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final import cloudpickle @@ -9,7 +10,7 @@ from snowflake.ml._internal import type_utils from snowflake.ml.model import custom_model, model_signature, type_hints as model_types from snowflake.ml.model._packager.model_env import model_env -from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils +from snowflake.ml.model._packager.model_handlers import _base from snowflake.ml.model._packager.model_handlers_migrator import base_migrator from snowflake.ml.model._packager.model_meta import ( model_blob_meta, @@ -78,34 +79,15 @@ def save_model( # Pipeline is inherited from BaseEstimator, so no need to add one more check if not 
is_sub_model: - if (not model_meta.signatures) and sample_input is None: - assert hasattr(model, "model_signatures") - model_meta.signatures = getattr(model, "model_signatures", {}) - else: - target_methods = handlers_utils.get_target_methods( - model=model, - target_methods=kwargs.pop("target_methods", None), - default_target_methods=cls.DEFAULT_TARGET_METHODS, - ) - - def get_prediction( - target_method_name: str, sample_input: model_types.SupportedLocalDataType - ) -> model_types.SupportedLocalDataType: - if not isinstance(sample_input, (pd.DataFrame,)): - sample_input = model_signature._convert_local_data_to_df(sample_input) - - target_method = getattr(model, target_method_name, None) - assert callable(target_method) - predictions_df = target_method(sample_input) - return predictions_df - - model_meta = handlers_utils.validate_signature( - model=model, - model_meta=model_meta, - target_methods=target_methods, - sample_input=sample_input, - get_prediction_fn=get_prediction, + if sample_input is not None or model_meta.signatures: + warnings.warn( + "Inferring model signature from sample input or providing model signature for Snowpark ML " + + "Modeling model is not required. Model signature will automatically be inferred during fitting. ", + UserWarning, + stacklevel=2, ) + assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected." 
+ model_meta.signatures = getattr(model, "model_signatures", {}) model_blob_path = os.path.join(model_blobs_dir_path, name) os.makedirs(model_blob_path, exist_ok=True) diff --git a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py index 7e3b1c32..59c319a2 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/snowmlmodel_test.py @@ -31,21 +31,27 @@ def test_snowml_all_input(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: s = {"predict": model_signature.infer_signature(df[INPUT_COLUMNS], regr.predict(df)[[OUTPUT_COLUMNS]])} - with self.assertRaises(ValueError): + with self.assertWarnsRegex(UserWarning, "Model signature will automatically be inferred during fitting"): model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", model=regr, - signatures={**s, "another_predict": s["predict"]}, + signatures=s, + metadata={"author": "halu", "version": "1"}, + ) + with self.assertWarnsRegex(UserWarning, "Model signature will automatically be inferred during fitting"): + model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( + name="model1_no_sig", + model=regr, + sample_input=df[INPUT_COLUMNS], metadata={"author": "halu", "version": "1"}, ) + with tempfile.TemporaryDirectory() as tmpdir: model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", model=regr, - signatures=s, metadata={"author": "halu", "version": "1"}, ) - with warnings.catch_warnings(): warnings.simplefilter("error") @@ -64,30 +70,6 @@ def test_snowml_all_input(self) -> None: assert callable(predict_method) np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) - model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( - name="model1_no_sig", - model=regr, - sample_input=df[INPUT_COLUMNS], - metadata={"author": "halu", 
"version": "1"}, - ) - - pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) - pk.load() - assert pk.model - assert pk.meta - assert isinstance(pk.model, LinearRegression) - np.testing.assert_allclose(predictions, desired=pk.model.predict(df[:1])[[OUTPUT_COLUMNS]]) - s = regr.model_signatures - self.assertEqual(s["predict"], pk.meta.signatures["predict"]) - - pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) - pk.load(as_custom_model=True) - assert pk.model - assert pk.meta - predict_method = getattr(pk.model, "predict", None) - assert callable(predict_method) - np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) - def test_snowml_signature_partial_input(self) -> None: iris = datasets.load_iris() @@ -103,19 +85,9 @@ def test_snowml_signature_partial_input(self) -> None: predictions = regr.predict(df[:1])[[OUTPUT_COLUMNS]] with tempfile.TemporaryDirectory() as tmpdir: - s = {"predict": model_signature.infer_signature(df[INPUT_COLUMNS], regr.predict(df)[[OUTPUT_COLUMNS]])} - with self.assertRaises(ValueError): - model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( - name="model1", - model=regr, - signatures={**s, "another_predict": s["predict"]}, - metadata={"author": "halu", "version": "1"}, - ) - model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", model=regr, - signatures=s, metadata={"author": "halu", "version": "1"}, ) @@ -137,31 +109,6 @@ def test_snowml_signature_partial_input(self) -> None: assert callable(predict_method) np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) - model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( - name="model1_no_sig", - model=regr, - sample_input=df, - metadata={"author": "halu", "version": "1"}, - ) - - pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) - pk.load() - assert pk.model - assert pk.meta - assert isinstance(pk.model, 
LinearRegression) - np.testing.assert_allclose(predictions, pk.model.predict(df[:1])[[OUTPUT_COLUMNS]]) - s = regr.model_signatures - # Compare the Model Signature without indexing - self.assertItemsEqual(s["predict"].to_dict(), pk.meta.signatures["predict"].to_dict()) - - pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) - pk.load(as_custom_model=True) - assert pk.model - assert pk.meta - predict_method = getattr(pk.model, "predict", None) - assert callable(predict_method) - np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) - def test_snowml_signature_drop_input_cols(self) -> None: iris = datasets.load_iris() @@ -179,19 +126,9 @@ def test_snowml_signature_drop_input_cols(self) -> None: predictions = regr.predict(df[:1])[[OUTPUT_COLUMNS]] with tempfile.TemporaryDirectory() as tmpdir: - s = {"predict": model_signature.infer_signature(df[INPUT_COLUMNS], regr.predict(df)[[OUTPUT_COLUMNS]])} - with self.assertRaises(ValueError): - model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( - name="model1", - model=regr, - signatures={**s, "another_predict": s["predict"]}, - metadata={"author": "halu", "version": "1"}, - ) - model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( name="model1", model=regr, - signatures=s, metadata={"author": "halu", "version": "1"}, ) @@ -213,31 +150,6 @@ def test_snowml_signature_drop_input_cols(self) -> None: assert callable(predict_method) np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) - model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")).save( - name="model1_no_sig", - model=regr, - sample_input=df, - metadata={"author": "halu", "version": "1"}, - ) - - pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) - pk.load() - assert pk.model - assert pk.meta - assert isinstance(pk.model, LinearRegression) - np.testing.assert_allclose(predictions, pk.model.predict(df[:1])[[OUTPUT_COLUMNS]]) - s = 
regr.model_signatures - # Compare the Model Signature without indexing - self.assertItemsEqual(s["predict"].to_dict(), pk.meta.signatures["predict"].to_dict()) - - pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1_no_sig")) - pk.load(as_custom_model=True) - assert pk.model - assert pk.meta - predict_method = getattr(pk.model, "predict", None) - assert callable(predict_method) - np.testing.assert_allclose(predictions, predict_method(df[:1])[[OUTPUT_COLUMNS]]) - if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_meta/model_meta.py b/snowflake/ml/model/_packager/model_meta/model_meta.py index 73785d2e..ae6a32e1 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta.py @@ -72,20 +72,22 @@ def create_model_metadata( """ model_dir_path = os.path.normpath(model_dir_path) embed_local_ml_library = kwargs.pop("embed_local_ml_library", False) - # Use the last one which is loaded first, that is mean, it is loaded from site-packages. - # We could make sure that user does not overwrite our library with their code follow the same naming. - snowml_path, snowml_start_path = file_utils.get_package_path(_SNOWFLAKE_ML_PKG_NAME, strategy="last") - if os.path.isdir(snowml_start_path): - path_to_copy = snowml_path - # If the package is zip-imported, then the path will be `../path_to_zip.zip/snowflake/ml` - # It is not a valid path in fact and we need to get the path to the zip file to verify it. 
- elif os.path.isfile(snowml_start_path): - extract_root = tempfile.mkdtemp() - with zipfile.ZipFile(os.path.abspath(snowml_start_path), mode="r", compression=zipfile.ZIP_DEFLATED) as zf: - zf.extractall(path=extract_root) - path_to_copy = os.path.join(extract_root, *(_SNOWFLAKE_ML_PKG_NAME.split("."))) - else: - raise ValueError("`snowflake.ml` is imported via a way that embedding local ML library is not supported.") + legacy_save = kwargs.pop("_legacy_save", False) + if embed_local_ml_library: + # Use the last one which is loaded first, that is mean, it is loaded from site-packages. + # We could make sure that user does not overwrite our library with their code follow the same naming. + snowml_path, snowml_start_path = file_utils.get_package_path(_SNOWFLAKE_ML_PKG_NAME, strategy="last") + if os.path.isdir(snowml_start_path): + path_to_copy = snowml_path + # If the package is zip-imported, then the path will be `../path_to_zip.zip/snowflake/ml` + # It is not a valid path in fact and we need to get the path to the zip file to verify it. 
+ elif os.path.isfile(snowml_start_path): + extract_root = tempfile.mkdtemp() + with zipfile.ZipFile(os.path.abspath(snowml_start_path), mode="r", compression=zipfile.ZIP_DEFLATED) as zf: + zf.extractall(path=extract_root) + path_to_copy = os.path.join(extract_root, *(_SNOWFLAKE_ML_PKG_NAME.split("."))) + else: + raise ValueError("`snowflake.ml` is imported via a way that embedding local ML library is not supported.") env = _create_env_for_model_metadata( conda_dependencies=conda_dependencies, @@ -106,10 +108,10 @@ def create_model_metadata( ) code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR) - if embed_local_ml_library or code_paths: + if (embed_local_ml_library and legacy_save) or code_paths: os.makedirs(code_dir_path, exist_ok=True) - if embed_local_ml_library: + if embed_local_ml_library and legacy_save: snowml_path_in_code = os.path.join(code_dir_path, _SNOWFLAKE_PKG_NAME) os.makedirs(snowml_path_in_code, exist_ok=True) file_utils.copy_file_or_tree(path_to_copy, snowml_path_in_code) diff --git a/snowflake/ml/model/_packager/model_packager_test.py b/snowflake/ml/model/_packager/model_packager_test.py index 70387fde..a4416f33 100644 --- a/snowflake/ml/model/_packager/model_packager_test.py +++ b/snowflake/ml/model/_packager/model_packager_test.py @@ -143,7 +143,7 @@ def test_zipimport_snowml(self) -> None: model=lm, sample_input=d, metadata={"author": "halu", "version": "1"}, - options={"embed_local_ml_library": True}, + options={"embed_local_ml_library": True, "_legacy_save": True}, ) self.assertTrue( os.path.exists( diff --git a/snowflake/ml/model/_signatures/snowpark_handler.py b/snowflake/ml/model/_signatures/snowpark_handler.py index 4ada0262..c7ee6392 100644 --- a/snowflake/ml/model/_signatures/snowpark_handler.py +++ b/snowflake/ml/model/_signatures/snowpark_handler.py @@ -51,7 +51,7 @@ def infer_signature( data: snowflake.snowpark.DataFrame, role: Literal["input", "output"] ) -> Sequence[core.BaseFeatureSpec]: return 
pandas_handler.PandasDataFrameHandler.infer_signature( - SnowparkDataFrameHandler.convert_to_df(data), role=role + SnowparkDataFrameHandler.convert_to_df(data.limit(n=1)), role=role ) @staticmethod diff --git a/snowflake/ml/model/_signatures/snowpark_test.py b/snowflake/ml/model/_signatures/snowpark_test.py index c43e56ca..1a9502e1 100644 --- a/snowflake/ml/model/_signatures/snowpark_test.py +++ b/snowflake/ml/model/_signatures/snowpark_test.py @@ -65,15 +65,64 @@ def test_validate_data_with_features(self) -> None: core.FeatureSpec("b", core.DataType.INT64), ] df = self._session.create_dataframe([{'"a"': 1}, {'"b"': 2}]) - with self.assertWarnsRegex(RuntimeWarning, "Nullable column [^\\s]* provided"): - model_signature._validate_snowpark_data(df, fts) + self.assertEqual( + model_signature._validate_snowpark_data(df, fts), model_signature.SnowparkIdentifierRule.INFERRED + ) fts = [ - core.FeatureSpec("a", core.DataType.INT16), - core.FeatureSpec("b", core.DataType.UINT32), + core.FeatureSpec("a", core.DataType.UINT8), + core.FeatureSpec("b", core.DataType.INT64), + ] + df = self._session.create_dataframe([{"a": 1}, {"b": 2}]) + self.assertEqual( + model_signature._validate_snowpark_data(df, fts), model_signature.SnowparkIdentifierRule.NORMALIZED + ) + + fts = [ + core.FeatureSpec("a", core.DataType.UINT8), + core.FeatureSpec("b", core.DataType.INT64), + ] + df = self._session.create_dataframe([{"A": 1}, {"B": 2}]) + self.assertEqual( + model_signature._validate_snowpark_data(df, fts), model_signature.SnowparkIdentifierRule.NORMALIZED + ) + + fts = [ + core.FeatureSpec('"a"', core.DataType.UINT8), + core.FeatureSpec('"b"', core.DataType.INT64), ] df = self._session.create_dataframe([{'"a"': 1}, {'"b"': 2}]) - with self.assertWarnsRegex(RuntimeWarning, "Nullable column [^\\s]* provided"): + self.assertEqual( + model_signature._validate_snowpark_data(df, fts), model_signature.SnowparkIdentifierRule.NORMALIZED + ) + + fts = [ + core.FeatureSpec('"a"', 
core.DataType.UINT8), + core.FeatureSpec('"b"', core.DataType.INT64), + ] + df = self._session.create_dataframe([{'"""a"""': 1}, {'"""b"""': 2}]) + self.assertEqual( + model_signature._validate_snowpark_data(df, fts), model_signature.SnowparkIdentifierRule.INFERRED + ) + + fts = [ + core.FeatureSpec('"a"', core.DataType.UINT8), + core.FeatureSpec('"b"', core.DataType.INT64), + ] + df = self._session.create_dataframe([{'"A"': 1}, {'"b"': 2}]) + with exception_utils.assert_snowml_exceptions( + self, expected_original_error_type=ValueError, expected_regex="feature [^\\s]* does not exist in data." + ): + model_signature._validate_snowpark_data(df, fts) + + fts = [ + core.FeatureSpec('"a"', core.DataType.UINT8), + core.FeatureSpec('"b"', core.DataType.INT64), + ] + df = self._session.create_dataframe([{"A": 1}, {'"b"': 2}]) + with exception_utils.assert_snowml_exceptions( + self, expected_original_error_type=ValueError, expected_regex="feature [^\\s]* does not exist in data." + ): model_signature._validate_snowpark_data(df, fts) fts = [ @@ -106,8 +155,7 @@ def test_validate_data_with_features(self) -> None: df = self._session.create_dataframe( [[decimal.Decimal(1), decimal.Decimal(1)], [decimal.Decimal(1), decimal.Decimal(1)]], schema ) - with self.assertWarnsRegex(RuntimeWarning, "Nullable column [^\\s]* provided"): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.INT16), @@ -144,24 +192,21 @@ def test_validate_data_with_features(self) -> None: core.FeatureSpec("b", core.DataType.FLOAT), ] df = self._session.create_dataframe([{'"a"': 1}, {'"b"': 2}]) - with self.assertWarnsRegex(RuntimeWarning, "Nullable column [^\\s]* provided"): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.UINT8), core.FeatureSpec("b", core.DataType.FLOAT), ] df = self._session.create_dataframe([{'"a"': 
1}, {'"b"': 2.0}]) - with self.assertWarnsRegex(RuntimeWarning, "Nullable column [^\\s]* provided"): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.UINT8), core.FeatureSpec("b", core.DataType.FLOAT), ] df = self._session.create_dataframe([{'"a"': 1}, {'"b"': 98765432109876543210987654321098765432}]) - with self.assertWarnsRegex(RuntimeWarning, "Nullable column [^\\s]* provided"): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.INT16), @@ -177,8 +222,7 @@ def test_validate_data_with_features(self) -> None: ], schema, ) - with self.assertWarnsRegex(RuntimeWarning, "Nullable column [^\\s]* provided"): - model_signature._validate_snowpark_data(df, fts) + model_signature._validate_snowpark_data(df, fts) fts = [ core.FeatureSpec("a", core.DataType.INT16), diff --git a/snowflake/ml/model/model_signature.py b/snowflake/ml/model/model_signature.py index c8f89683..7abfa9f3 100644 --- a/snowflake/ml/model/model_signature.py +++ b/snowflake/ml/model/model_signature.py @@ -1,8 +1,10 @@ +import enum import warnings -from typing import Any, List, Literal, Optional, Sequence, Tuple, Type +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Type import numpy as np import pandas as pd +from typing_extensions import Never import snowflake.snowpark import snowflake.snowpark.functions as F @@ -11,7 +13,7 @@ error_codes, exceptions as snowml_exceptions, ) -from snowflake.ml._internal.utils import formatting, identifier +from snowflake.ml._internal.utils import formatting, identifier, sql_identifier from snowflake.ml.model import type_hints as model_types from snowflake.ml.model._signatures import ( base_handler, @@ -310,8 +312,35 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS ) -def _validate_snowpark_data(data: 
snowflake.snowpark.DataFrame, features: Sequence[core.BaseFeatureSpec]) -> None: - """Validate Snowpark DataFrame as input +def assert_never(arg: Never) -> Never: + raise AssertionError("Expected code to be unreachable") + + +class SnowparkIdentifierRule(enum.Enum): + INFERRED = "inferred" + NORMALIZED = "normalized" + + def get_identifier_from_feature(self, ft_name: str) -> str: + if self == SnowparkIdentifierRule.INFERRED: + return identifier.get_inferred_name(ft_name) + elif self == SnowparkIdentifierRule.NORMALIZED: + return identifier.resolve_identifier(ft_name) + else: + assert_never(self) + + def get_sql_identifier_from_feature(self, ft_name: str) -> sql_identifier.SqlIdentifier: + if self == SnowparkIdentifierRule.INFERRED: + return sql_identifier.SqlIdentifier(ft_name, case_sensitive=True) + elif self == SnowparkIdentifierRule.NORMALIZED: + return sql_identifier.SqlIdentifier(ft_name, case_sensitive=False) + else: + assert_never(self) + + +def _validate_snowpark_data( + data: snowflake.snowpark.DataFrame, features: Sequence[core.BaseFeatureSpec] +) -> SnowparkIdentifierRule: + """Validate Snowpark DataFrame as input. It will try to map both normalized name or inferred name. Args: data: A snowpark dataframe to be validated. @@ -321,60 +350,86 @@ def _validate_snowpark_data(data: snowflake.snowpark.DataFrame, features: Sequen SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported. SnowflakeMLException: ValueError: Raised when confronting invalid feature. SnowflakeMLException: ValueError: Raised when a feature cannot be found. - """ + Returns: + Identifier rule to use. 
+ - inferred: signature `a` - Snowpark DF `"a"`, use `get_inferred_name` + - normalized: signature `a` - Snowpark DF `A`, use `resolve_identifier` + """ + errors: Dict[SnowparkIdentifierRule, List[Exception]] = { + SnowparkIdentifierRule.INFERRED: [], + SnowparkIdentifierRule.NORMALIZED: [], + } schema = data.schema - for feature in features: - ft_name = feature.name - found = False - for field in schema.fields: - name = identifier.get_unescaped_names(field.name) - if name == ft_name: - found = True - if field.nullable: - warnings.warn( - f"Warn in feature {ft_name}: Nullable column {field.name} provided," - + " inference might fail if there is null value.", - category=RuntimeWarning, - stacklevel=1, - ) - if isinstance(feature, core.FeatureGroupSpec): - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.NOT_IMPLEMENTED, - original_exception=NotImplementedError("FeatureGroupSpec is not supported."), - ) - assert isinstance(feature, core.FeatureSpec) # mypy - ft_type = feature._dtype - field_data_type = field.datatype - if isinstance(field_data_type, spt.ArrayType): - if feature._shape is None: + for identifier_rule in errors.keys(): + for feature in features: + try: + ft_name = identifier_rule.get_identifier_from_feature(feature.name) + except ValueError as e: + errors[identifier_rule].append(e) + continue + found = False + for field in schema.fields: + if field.name == ft_name: + found = True + if isinstance(feature, core.FeatureGroupSpec): raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_DATA, - original_exception=ValueError( - f"Data Validation Error in feature {ft_name}: " - + f"Feature is a array feature, while {field.name} is not." 
- ), + error_code=error_codes.NOT_IMPLEMENTED, + original_exception=NotImplementedError("FeatureGroupSpec is not supported."), ) - warnings.warn( - f"Warn in feature {ft_name}: Feature is a array feature, type validation cannot happen.", - category=RuntimeWarning, - stacklevel=1, - ) - else: - if feature._shape: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_DATA, - original_exception=ValueError( - f"Data Validation Error in feature {ft_name}: " - + f"Feature is a scalar feature, while {field.name} is not." - ), + assert isinstance(feature, core.FeatureSpec) # mypy + ft_type = feature._dtype + field_data_type = field.datatype + if isinstance(field_data_type, spt.ArrayType): + if feature._shape is None: + errors[identifier_rule].append( + ValueError( + f"Data Validation Error in feature {feature.name}: " + + f"Feature is an array feature, while {field.name} is not." + ), + ) + warnings.warn( + (f"Feature {feature.name} type cannot be validated: feature is an array feature."), + category=RuntimeWarning, + stacklevel=2, ) - _validate_snowpark_type_feature(data, field, ft_type, ft_name) - if not found: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_DATA, - original_exception=ValueError(f"Data Validation Error: feature {ft_name} does not exist in data."), - ) + else: + if feature._shape: + errors[identifier_rule].append( + ValueError( + f"Data Validation Error in feature {feature.name}: " + + f"Feature is a scalar feature, while {field.name} is not." 
+ ), + ) + try: + _validate_snowpark_type_feature(data, field, ft_type, feature.name) + except snowml_exceptions.SnowflakeMLException as e: + errors[identifier_rule].append(e.original_exception) + break + if not found: + errors[identifier_rule].append( + ValueError(f"Data Validation Error: feature {feature.name} does not exist in data."), + ) + if all(len(error_list) != 0 for error_list in errors.values()): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_DATA, + original_exception=ValueError( + f""" +Data Validation Error when validating your Snowpark DataFrame. +If using the normalized names from model signatures, there are the following errors: +{errors[SnowparkIdentifierRule.NORMALIZED]} + +If using the inferred names from model signatures, there are the following errors: +{errors[SnowparkIdentifierRule.INFERRED]} +""" + ), + ) + else: + return ( + SnowparkIdentifierRule.INFERRED + if len(errors[SnowparkIdentifierRule.INFERRED]) == 0 + else SnowparkIdentifierRule.NORMALIZED + ) def _validate_snowpark_type_feature( diff --git a/snowflake/ml/model/type_hints.py b/snowflake/ml/model/type_hints.py index 5e757e81..31a89bbf 100644 --- a/snowflake/ml/model/type_hints.py +++ b/snowflake/ml/model/type_hints.py @@ -201,6 +201,7 @@ class BaseModelSaveOption(TypedDict): """ embed_local_ml_library: NotRequired[bool] + _legacy_save: NotRequired[bool] method_options: NotRequired[Dict[str, ModelMethodSaveOptions]] diff --git a/snowflake/ml/modeling/_internal/BUILD.bazel b/snowflake/ml/modeling/_internal/BUILD.bazel index 7e663da9..c4d34c58 100644 --- a/snowflake/ml/modeling/_internal/BUILD.bazel +++ b/snowflake/ml/modeling/_internal/BUILD.bazel @@ -47,10 +47,76 @@ py_test( ], ) +py_library( + name = "model_specifications", + srcs = ["model_specifications.py"], + deps = [ + "//snowflake/ml/_internal/exceptions", + ], +) + py_test( - name = "snowpark_handlers_test", - srcs = ["snowpark_handlers_test.py"], + name = "model_specifications_test", + 
srcs = ["model_specifications_test.py"], deps = [ - ":estimator_utils", + ":model_specifications", + ], +) + +py_library( + name = "model_trainer", + srcs = ["model_trainer.py"], + deps = [], +) + +py_library( + name = "pandas_trainer", + srcs = ["pandas_trainer.py"], + deps = [ + ":model_trainer", + ], +) + +py_library( + name = "snowpark_trainer", + srcs = ["snowpark_trainer.py"], + deps = [ + ":model_specifications", + ":model_trainer", + "//snowflake/ml/_internal:env_utils", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/_internal/exceptions:modeling_error_messages", + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", + "//snowflake/ml/_internal/utils:temp_file_utils", + ], +) + +py_library( + name = "distributed_hpo_trainer", + srcs = ["distributed_hpo_trainer.py"], + deps = [ + ":model_specifications", + ":snowpark_trainer", + "//snowflake/ml/_internal:env_utils", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/_internal/exceptions:modeling_error_messages", + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", + "//snowflake/ml/_internal/utils:temp_file_utils", + ], +) + +py_library( + name = "model_trainer_builder", + srcs = ["model_trainer_builder.py"], + deps = [ + ":distributed_hpo_trainer", + ":model_trainer", + ":pandas_trainer", + ":snowpark_trainer", ], ) diff --git a/snowflake/ml/modeling/_internal/distributed_hpo_trainer.py b/snowflake/ml/modeling/_internal/distributed_hpo_trainer.py new file mode 100644 index 00000000..9dc57ee2 --- /dev/null +++ b/snowflake/ml/modeling/_internal/distributed_hpo_trainer.py @@ -0,0 +1,554 @@ +import importlib +import inspect +import io +import os +import posixpath +import sys +from typing import Any, Dict, List, Optional, Tuple, Union + +import 
cloudpickle as cp +import numpy as np +from scipy.stats import rankdata +from sklearn import model_selection + +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils +from snowflake.ml._internal.utils.temp_file_utils import ( + cleanup_temp_files, + get_temp_file_path, +) +from snowflake.ml.modeling._internal.model_specifications import ( + ModelSpecificationsBuilder, +) +from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer +from snowflake.snowpark import DataFrame, Session, functions as F +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) +from snowflake.snowpark.functions import col, sproc, udtf +from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType + +cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path)) +cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name)) + +_PROJECT = "ModelDevelopment" +DEFAULT_UDTF_NJOBS = 3 + + +class DistributedHPOTrainer(SnowparkModelTrainer): + """ + A class for performing distributed hyperparameter optimization (HPO) using Snowpark. + + This class inherits from SnowparkModelTrainer and extends its functionality + to support distributed HPO for machine learning models. It enables optimization + of hyperparameters by distributing the tasks across the warehouse using Snowpark. + """ + + def __init__( + self, + estimator: object, + dataset: DataFrame, + session: Session, + input_cols: List[str], + label_cols: Optional[List[str]], + sample_weight_col: Optional[str], + autogenerated: bool = False, + subproject: str = "", + ) -> None: + """ + Initializes the DistributedHPOTrainer with a model, a Snowpark DataFrame, feature, and label column names, etc. + + Args: + estimator: SKLearn compatible estimator or transformer object. + dataset: The dataset used for training the model. 
+ session: Snowflake session object to be used for training. + input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training. + label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn. + sample_weight_col: The column name representing the weight of training examples. + autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not. + subproject: subproject name to be used in telemetry. + """ + super().__init__( + estimator=estimator, + dataset=dataset, + session=session, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + autogenerated=autogenerated, + subproject=subproject, + ) + + # TODO(snandamuri): Copied this code as it is from the snowpark_handler. + # Update it to improve the readability. + def fit_search_snowpark( + self, + param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler], + dataset: DataFrame, + session: Session, + estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV], + dependencies: List[str], + udf_imports: List[str], + input_cols: List[str], + label_cols: Optional[List[str]], + sample_weight_col: Optional[str], + ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]: + from itertools import product + + import cachetools + from sklearn.base import clone, is_classifier + from sklearn.calibration import check_cv + + # Create one stage for data and for estimators. + temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE) + temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};" + session.sql(temp_stage_creation_query).collect() + + # Stage data. 
+ dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset) + remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet" + dataset.write.copy_into_location( # type:ignore[call-overload] + remote_file_path, file_format_type="parquet", header=True, overwrite=True + ) + imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()] + + # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again + original_refit = estimator.refit + + # Create a temp file and dump the estimator to that file. + estimator_file_name = get_temp_file_path() + params_to_evaluate = [] + for param_to_eval in list(param_grid): + for k, v in param_to_eval.items(): + param_to_eval[k] = [v] + params_to_evaluate.append([param_to_eval]) + + with open(estimator_file_name, mode="w+b") as local_estimator_file_obj: + # Set GridSearchCV refit as False and fit it again after retrieving the best param + estimator.refit = False + cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj) + stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name)) + sproc_statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=self._subproject, + function_name=telemetry.get_statement_params_full_func_name( + inspect.currentframe(), self.__class__.__name__ + ), + api_calls=[sproc], + custom_tags=dict([("autogen", True)]) if self._autogenerated else None, + ) + udtf_statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=self._subproject, + function_name=telemetry.get_statement_params_full_func_name( + inspect.currentframe(), self.__class__.__name__ + ), + api_calls=[udtf], + custom_tags=dict([("autogen", True)]) if self._autogenerated else None, + ) + + # Put locally serialized estimator on stage. 
+ put_result = session.file.put( + estimator_file_name, + temp_stage_name, + auto_compress=False, + overwrite=True, + ) + estimator_location = put_result[0].target + imports.append(f"@{temp_stage_name}/{estimator_location}") + + search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) + random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION) + + required_deps = dependencies + [ + "snowflake-snowpark-python<2", + "fastparquet<2023.11", + "pyarrow<14", + "cachetools<5", + ] + + @sproc( # type: ignore[misc] + is_permanent=False, + name=search_sproc_name, + packages=required_deps, # type: ignore[arg-type] + replace=True, + session=session, + anonymous=True, + imports=imports, # type: ignore[arg-type] + statement_params=sproc_statement_params, + ) + def _distributed_search( + session: Session, + imports: List[str], + stage_estimator_file_name: str, + input_cols: List[str], + label_cols: Optional[List[str]], + ) -> str: + import os + import time + from typing import Iterator + + import cloudpickle as cp + import pandas as pd + import pyarrow.parquet as pq + from sklearn.metrics import check_scoring + from sklearn.metrics._scorer import _check_multimetric_scoring + + for import_name in udf_imports: + importlib.import_module(import_name) + + data_files = [ + filename + for filename in os.listdir(sys._xoptions["snowflake_import_directory"]) + if filename.startswith(temp_stage_name) + ] + partial_df = [ + pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas() + for file_name in data_files + ] + df = pd.concat(partial_df, ignore_index=True) + df.columns = [identifier.get_inferred_name(col) for col in df.columns] + + X = df[input_cols] + y = df[label_cols].squeeze() if label_cols else None + + local_estimator_file_name = get_temp_file_path() + session.file.get(stage_estimator_file_name, local_estimator_file_name) + + local_estimator_file_path = os.path.join( + local_estimator_file_name, 
os.listdir(local_estimator_file_name)[0] + ) + with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj: + estimator = cp.load(local_estimator_file_obj)["estimator"] + + cv_orig = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator)) + indices = [test for _, test in cv_orig.split(X, y)] + local_indices_file_name = get_temp_file_path() + with open(local_indices_file_name, mode="w+b") as local_indices_file_obj: + cp.dump(indices, local_indices_file_obj) + + # Put locally serialized indices on stage. + put_result = session.file.put( + local_indices_file_name, + temp_stage_name, + auto_compress=False, + overwrite=True, + ) + indices_location = put_result[0].target + imports.append(f"@{temp_stage_name}/{indices_location}") + indices_len = len(indices) + + assert estimator is not None + + @cachetools.cached(cache={}) + def _load_data_into_udf() -> Tuple[ + Dict[str, pd.DataFrame], + Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV], + pd.DataFrame, + int, + List[Dict[str, Any]], + ]: + import pyarrow.parquet as pq + + data_files = [ + filename + for filename in os.listdir(sys._xoptions["snowflake_import_directory"]) + if filename.startswith(temp_stage_name) + ] + partial_df = [ + pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas() + for file_name in data_files + ] + df = pd.concat(partial_df, ignore_index=True) + df.columns = [identifier.get_inferred_name(col) for col in df.columns] + + # load estimator + local_estimator_file_path = os.path.join( + sys._xoptions["snowflake_import_directory"], f"{estimator_location}" + ) + with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj: + estimator_objects = cp.load(local_estimator_file_obj) + estimator = estimator_objects["estimator"] + params_to_evaluate = estimator_objects["param_grid"] + + # load indices + local_indices_file_path = os.path.join( + sys._xoptions["snowflake_import_directory"], 
f"{indices_location}" + ) + with open(local_indices_file_path, mode="rb") as local_indices_file_obj: + indices = cp.load(local_indices_file_obj) + + argspec = inspect.getfullargspec(estimator.fit) + args = {"X": df[input_cols]} + + if label_cols: + label_arg_name = "Y" if "Y" in argspec.args else "y" + args[label_arg_name] = df[label_cols].squeeze() + + if sample_weight_col is not None and "sample_weight" in argspec.args: + args["sample_weight"] = df[sample_weight_col].squeeze() + return args, estimator, indices, len(df), params_to_evaluate + + class SearchCV: + def __init__(self) -> None: + args, estimator, indices, data_length, params_to_evaluate = _load_data_into_udf() + self.args = args + self.estimator = estimator + self.indices = indices + self.data_length = data_length + self.params_to_evaluate = params_to_evaluate + + def process(self, params_idx: int, idx: int) -> Iterator[Tuple[str]]: + if hasattr(estimator, "param_grid"): + self.estimator.param_grid = self.params_to_evaluate[params_idx] + else: + self.estimator.param_distributions = self.params_to_evaluate[params_idx] + full_indices = np.array([i for i in range(self.data_length)]) + test_indice = self.indices[idx] + train_indice = np.setdiff1d(full_indices, test_indice) + self.estimator.cv = [(train_indice, test_indice)] + self.estimator.fit(**self.args) + binary_cv_results = None + with io.BytesIO() as f: + cp.dump(self.estimator.cv_results_, f) + f.seek(0) + binary_cv_results = f.getvalue().hex() + yield (binary_cv_results,) + + def end_partition(self) -> None: + ... 
+ + session.udtf.register( + SearchCV, + output_schema=StructType([StructField("CV_RESULTS", StringType())]), + input_types=[IntegerType(), IntegerType()], + name=random_udtf_name, + packages=required_deps, # type: ignore[arg-type] + replace=True, + is_permanent=False, + imports=imports, # type: ignore[arg-type] + statement_params=udtf_statement_params, + ) + + HP_TUNING = F.table_function(random_udtf_name) + + idx_length = int(indices_len) + params_length = len(param_grid) + idxs = [i for i in range(idx_length)] + param_indices, training_indices = [], [] + for param_idx, cv_idx in product([param_index for param_index in range(params_length)], idxs): + param_indices.append(param_idx) + training_indices.append(cv_idx) + + pd_df = pd.DataFrame( + { + "PARAMS": param_indices, + "TRAIN_IND": training_indices, + "PARAM_INDEX": [i for i in range(idx_length * params_length)], + } + ) + df = session.create_dataframe(pd_df) + results = df.select( + F.cast(df["PARAM_INDEX"], IntegerType()).as_("PARAM_INDEX"), + (HP_TUNING(df["PARAMS"], df["TRAIN_IND"]).over(partition_by=df["PARAM_INDEX"])), + ) + + # cv_result maintains the original order + multimetric = False + cv_results_ = dict() + scorers = set() + for i, val in enumerate(results.select("CV_RESULTS").sort(col("PARAM_INDEX")).collect()): + # retrieved string had one more double quote in the front and end of the string. + # use [1:-1] to remove the extra double quotes + hex_str = bytes.fromhex(val[0]) + with io.BytesIO(hex_str) as f_reload: + each_cv_result = cp.load(f_reload) + for k, v in each_cv_result.items(): + cur_cv = i % idx_length + key = k + if "split0_test_" in k: + # For multi-metric evaluation, the scores for all the scorers are available in the + # cv_results_ dict at the keys ending with that scorer’s name ('_') + # instead of '_score'. 
+ scorers.add(k[len("split0_test_") :]) + key = k.replace("split0_test", f"split{cur_cv}_test") + elif k.startswith("param"): + if cur_cv != 0: + key = False + if key: + if key not in cv_results_: + cv_results_[key] = v + else: + cv_results_[key] = np.concatenate([cv_results_[key], v]) + + multimetric = len(scorers) > 1 + # Use numpy to re-calculate all the information in cv_results_ again + # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape, + # and average them by the idx_length; + # idx_length is the number of cv folds; params_length is the number of parameter combinations + scores = [ + np.reshape( + np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]), + (idx_length, -1), + ) + for score in scorers + ] + + fit_score_test_matrix = np.stack( + [ + np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)), + np.reshape(cv_results_["mean_score_time"], (idx_length, -1)), + ] + + scores + ) + + mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1) + std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1) + cv_results_["std_fit_time"] = std_fit_score_test_matrix[0] + cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0] + cv_results_["std_score_time"] = std_fit_score_test_matrix[1] + cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1] + for idx, score in enumerate(scorers): + cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2] + cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2] + # re-compute the ranking again with mean_test_. + cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min") + # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared. + # If all scores are `nan`, `rankdata` will also produce an array of `nan` values. + # In that case, default to first index. 
+ best_param_index = ( + np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0] + if not np.isnan(cv_results_[f"rank_test_{score}"]).all() + else 0 + ) + + estimator.cv_results_ = cv_results_ + estimator.multimetric_ = multimetric + + # Reconstruct the sklearn estimator. + refit_metric = "score" + if callable(estimator.scoring): + scorers = estimator.scoring + elif estimator.scoring is None or isinstance(estimator.scoring, str): + scorers = check_scoring(estimator.estimator, estimator.scoring) + else: + scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring) + estimator._check_refit_for_multimetric(scorers) + refit_metric = original_refit + + estimator.scorer_ = scorers + + # check refit_metric now for a callabe scorer that is multimetric + if callable(estimator.scoring) and estimator.multimetric_: + refit_metric = original_refit + + # For multi-metric evaluation, store the best_index_, best_params_ and + # best_score_ iff refit is one of the scorer names + # In single metric evaluation, refit_metric is "score" + if original_refit or not estimator.multimetric_: + estimator.best_index_ = estimator._select_best_index(original_refit, refit_metric, cv_results_) + if not callable(original_refit): + # With a non-custom callable, we can select the best score + # based on the best index + estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_] + estimator.best_params_ = cv_results_["params"][best_param_index] + + if original_refit: + estimator.best_estimator_ = clone(estimator.estimator).set_params( + **clone(estimator.best_params_, safe=False) + ) + + # Let the sproc use all cores to refit. 
+ estimator.n_jobs = -1 if not estimator.n_jobs else estimator.n_jobs + + # process the input as args + argspec = inspect.getfullargspec(estimator.fit) + args = {"X": X} + if label_cols: + label_arg_name = "Y" if "Y" in argspec.args else "y" + args[label_arg_name] = y + if sample_weight_col is not None and "sample_weight" in argspec.args: + args["sample_weight"] = df[sample_weight_col].squeeze() + estimator.refit = original_refit + refit_start_time = time.time() + estimator.best_estimator_.fit(**args) + refit_end_time = time.time() + estimator.refit_time_ = refit_end_time - refit_start_time + + if hasattr(estimator.best_estimator_, "feature_names_in_"): + estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_ + + local_result_file_name = get_temp_file_path() + + with open(local_result_file_name, mode="w+b") as local_result_file_obj: + cp.dump(estimator, local_result_file_obj) + + session.file.put( + local_result_file_name, + temp_stage_name, + auto_compress=False, + overwrite=True, + ) + + # Note: you can add something like + "|" + str(df) to the return string + # to pass debug information to the caller. + return str(os.path.basename(local_result_file_name)) + + sproc_export_file_name = _distributed_search( + session, + imports, + stage_estimator_file_name, + input_cols, + label_cols, + ) + + local_estimator_path = get_temp_file_path() + session.file.get( + posixpath.join(temp_stage_name, sproc_export_file_name), + local_estimator_path, + ) + + with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj: + fit_estimator = cp.load(result_file_obj) + + cleanup_temp_files([local_estimator_path]) + + return fit_estimator + + def train(self) -> object: + """ + Runs hyper parameter optimization by distributing the tasks across warehouse. 
+ + Returns: + Trained model + """ + model_spec = ModelSpecificationsBuilder.build(model=self.estimator) + assert isinstance(self.estimator, model_selection.GridSearchCV) or isinstance( + self.estimator, model_selection.RandomizedSearchCV + ) + if hasattr(self.estimator.estimator, "n_jobs") and self.estimator.estimator.n_jobs in [ + None, + -1, + ]: + self.estimator.estimator.n_jobs = DEFAULT_UDTF_NJOBS + + if isinstance(self.estimator, model_selection.GridSearchCV): + param_grid = model_selection.ParameterGrid(self.estimator.param_grid) + elif isinstance(self.estimator, model_selection.RandomizedSearchCV): + param_grid = model_selection.ParameterSampler( + self.estimator.param_distributions, + n_iter=self.estimator.n_iter, + random_state=self.estimator.random_state, + ) + return self.fit_search_snowpark( + param_grid=param_grid, + dataset=self.dataset, + session=self.session, + estimator=self.estimator, + dependencies=model_spec.pkgDependencies, + udf_imports=["sklearn"], + input_cols=self.input_cols, + label_cols=self.label_cols, + sample_weight_col=self.sample_weight_col, + ) diff --git a/snowflake/ml/modeling/_internal/estimator_protocols.py b/snowflake/ml/modeling/_internal/estimator_protocols.py index 155c10e3..c507b2a7 100644 --- a/snowflake/ml/modeling/_internal/estimator_protocols.py +++ b/snowflake/ml/modeling/_internal/estimator_protocols.py @@ -1,35 +1,12 @@ -from typing import List, Optional, Protocol, Union +from typing import List, Optional, Protocol import pandas as pd -from sklearn import model_selection from snowflake.snowpark import DataFrame, Session # TODO: Add more specific entities to type hint estimators instead of using `object`. 
class FitPredictHandlers(Protocol): - def fit_snowpark( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> object: - raise NotImplementedError - - def fit_pandas( - self, - dataset: pd.DataFrame, - estimator: object, - input_cols: List[str], - label_cols: Optional[List[str]], - sample_weight_col: Optional[str], - ) -> object: - raise NotImplementedError - def batch_inference( self, dataset: DataFrame, @@ -70,28 +47,6 @@ def score_snowpark( # TODO: Add more specific entities to type hint estimators instead of using `object`. class CVHandlers(Protocol): - def fit_snowpark( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> object: - raise NotImplementedError - - def fit_pandas( - self, - dataset: pd.DataFrame, - estimator: object, - input_cols: List[str], - label_cols: Optional[List[str]], - sample_weight_col: Optional[str], - ) -> object: - raise NotImplementedError - def batch_inference( self, dataset: DataFrame, @@ -128,17 +83,3 @@ def score_snowpark( sample_weight_col: Optional[str], ) -> float: raise NotImplementedError - - def fit_search_snowpark( - self, - param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler], - dataset: DataFrame, - session: Session, - estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV], - dependencies: List[str], - udf_imports: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]: - raise NotImplementedError diff --git a/snowflake/ml/modeling/_internal/model_specifications.py b/snowflake/ml/modeling/_internal/model_specifications.py new file mode 100644 index 00000000..e6f375c5 --- 
/dev/null +++ b/snowflake/ml/modeling/_internal/model_specifications.py @@ -0,0 +1,146 @@ +import inspect +from typing import List + +import cloudpickle as cp +import numpy as np + +from snowflake.ml._internal.exceptions import error_codes, exceptions + + +class ModelSpecifications: + """ + A dataclass to define model based specifications like required imports, and package dependencies for Sproc/Udfs. + """ + + def __init__(self, imports: List[str], pkgDependencies: List[str]) -> None: + self.imports = imports + self.pkgDependencies = pkgDependencies + + +class SKLearnModelSpecifications(ModelSpecifications): + def __init__(self) -> None: + import sklearn + + imports: List[str] = ["sklearn"] + # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda. + pkgDependencies = [ + f"numpy=={np.__version__}", + f"scikit-learn=={sklearn.__version__}", + f"cloudpickle=={cp.__version__}", + ] + + # A change from previous implementation. + # When reusing the Sprocs for all the fit() call in the session, the static dpendencies list should include + # all the possible dependencies required during the lifetime. + + # Include XGBoost in the dependencies if it is installed. + try: + import xgboost + except ModuleNotFoundError: + pass + else: + pkgDependencies.append(f"xgboost=={xgboost.__version__}") + + # Include lightgbm in the dependencies if it is installed. 
+ try: + import lightgbm + except ModuleNotFoundError: + pass + else: + pkgDependencies.append(f"lightgbm=={lightgbm.__version__}") + + super().__init__(imports=imports, pkgDependencies=pkgDependencies) + + +class XGBoostModelSpecifications(ModelSpecifications): + def __init__(self) -> None: + import xgboost + + imports: List[str] = ["xgboost"] + pkgDependencies: List[str] = [ + f"numpy=={np.__version__}", + f"xgboost=={xgboost.__version__}", + f"cloudpickle=={cp.__version__}", + ] + super().__init__(imports=imports, pkgDependencies=pkgDependencies) + + +class LightGBMModelSpecifications(ModelSpecifications): + def __init__(self) -> None: + import lightgbm + + imports: List[str] = ["lightgbm"] + pkgDependencies: List[str] = [ + f"numpy=={np.__version__}", + f"lightgbm=={lightgbm.__version__}", + f"cloudpickle=={cp.__version__}", + ] + super().__init__(imports=imports, pkgDependencies=pkgDependencies) + + +class SklearnModelSelectionModelSpecifications(ModelSpecifications): + def __init__(self) -> None: + import sklearn + import xgboost + + imports: List[str] = ["sklearn", "xgboost"] + pkgDependencies: List[str] = [ + f"numpy=={np.__version__}", + f"scikit-learn=={sklearn.__version__}", + f"cloudpickle=={cp.__version__}", + f"xgboost=={xgboost.__version__}", + ] + + # Only include lightgbm in the dependencies if it is installed. + try: + import lightgbm + except ModuleNotFoundError: + pass + else: + imports.append("lightgbm") + pkgDependencies.append(f"lightgbm=={lightgbm.__version__}") + + super().__init__(imports=imports, pkgDependencies=pkgDependencies) + + +class ModelSpecificationsBuilder: + """ + A factory class to build ModelSpecifications object for different types of models. + """ + + @classmethod + def build(cls, model: object) -> ModelSpecifications: + """ + A static factory method that builds ModelSpecifications object based on the module name of native model object. + + Args: + model: Native model object to be trained. 
+ + Returns: + Appropriate ModelSpecification object + + Raises: + SnowflakeMLException: Raises an exception the module of given model can't be determined. + TypeError: Raises the exception for unsupported modules. + """ + module = inspect.getmodule(model) + if module is None: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_TYPE, + original_exception=ValueError("Unable to infer model type of the given native model object."), + ) + root_module_name = module.__name__.split(".")[0] + if root_module_name == "sklearn": + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + + if isinstance(model, GridSearchCV) or isinstance(model, RandomizedSearchCV): + return SklearnModelSelectionModelSpecifications() + return SKLearnModelSpecifications() + elif root_module_name == "xgboost": + return XGBoostModelSpecifications() + elif root_module_name == "lightgbm": + return LightGBMModelSpecifications() + else: + raise TypeError( + f"Unexpected module type: {root_module_name}." "Supported module types: sklearn, xgboost, lightgbm." 
+ ) diff --git a/snowflake/ml/modeling/_internal/snowpark_handlers_test.py b/snowflake/ml/modeling/_internal/model_specifications_test.py similarity index 70% rename from snowflake/ml/modeling/_internal/snowpark_handlers_test.py rename to snowflake/ml/modeling/_internal/model_specifications_test.py index e17803b7..26671eb2 100644 --- a/snowflake/ml/modeling/_internal/snowpark_handlers_test.py +++ b/snowflake/ml/modeling/_internal/model_specifications_test.py @@ -2,12 +2,13 @@ from unittest import mock from absl.testing import absltest, parameterized +from lightgbm import LGBMRegressor +from sklearn.linear_model import LinearRegression +from sklearn.model_selection import GridSearchCV +from xgboost import XGBRegressor -from snowflake.ml.modeling._internal.snowpark_handlers import ( - LightGBMWrapperProvider, - SklearnModelSelectionWrapperProvider, - SklearnWrapperProvider, - XGBoostWrapperProvider, +from snowflake.ml.modeling._internal.model_specifications import ( + ModelSpecificationsBuilder, ) @@ -23,7 +24,8 @@ def import_mock(name: str, *args: Any, **kwargs: Any) -> Any: return orig_import(name, *args, **kwargs) with mock.patch("builtins.__import__", side_effect=import_mock): - provider = SklearnModelSelectionWrapperProvider() + model = GridSearchCV(estimator=XGBRegressor(), param_grid={"max_depth": [10, 100]}) + provider = ModelSpecificationsBuilder.build(model=model) self.assertEqual(provider.imports, ["sklearn", "xgboost", "lightgbm"]) @@ -36,16 +38,17 @@ def import_mock(name: str, *args: Any, **kwargs: Any) -> Any: return orig_import(name, *args, **kwargs) with mock.patch("builtins.__import__", side_effect=import_mock): - provider = SklearnModelSelectionWrapperProvider() + model = GridSearchCV(estimator=XGBRegressor(), param_grid={"max_depth": [10, 100]}) + provider = ModelSpecificationsBuilder.build(model=model) self.assertEqual(provider.imports, ["sklearn", "xgboost"]) def test_xgboost_wrapper_provider(self) -> None: - provider = XGBoostWrapperProvider() + 
provider = ModelSpecificationsBuilder.build(model=XGBRegressor()) self.assertEqual(provider.imports, ["xgboost"]) def test_sklearn_wrapper_provider(self) -> None: - provider = SklearnWrapperProvider() + provider = ModelSpecificationsBuilder.build(model=LinearRegression()) self.assertEqual(provider.imports, ["sklearn"]) def test_lightgbm_wrapper_provider(self) -> None: @@ -59,7 +62,7 @@ def import_mock(name: str, *args: Any, **kwargs: Any) -> Any: return orig_import(name, *args, **kwargs) with mock.patch("builtins.__import__", side_effect=import_mock): - provider = LightGBMWrapperProvider() + provider = ModelSpecificationsBuilder.build(model=LGBMRegressor()) self.assertEqual(provider.imports, ["lightgbm"]) diff --git a/snowflake/ml/modeling/_internal/model_trainer.py b/snowflake/ml/modeling/_internal/model_trainer.py new file mode 100644 index 00000000..0c99a011 --- /dev/null +++ b/snowflake/ml/modeling/_internal/model_trainer.py @@ -0,0 +1,13 @@ +from typing import Protocol + + +class ModelTrainer(Protocol): + """ + Interface for model trainer implementations. + + There are multiple flavors of training like training with pandas datasets, training with + Snowpark datasets using sprocs, and out of core training with Snowpark datasets etc. 
+ """ + + def train(self) -> object: + raise NotImplementedError diff --git a/snowflake/ml/modeling/_internal/model_trainer_builder.py b/snowflake/ml/modeling/_internal/model_trainer_builder.py new file mode 100644 index 00000000..4c4d7aca --- /dev/null +++ b/snowflake/ml/modeling/_internal/model_trainer_builder.py @@ -0,0 +1,78 @@ +from typing import List, Optional, Union + +import pandas as pd +from sklearn import model_selection + +from snowflake.ml.modeling._internal.distributed_hpo_trainer import ( + DistributedHPOTrainer, +) +from snowflake.ml.modeling._internal.estimator_utils import is_single_node +from snowflake.ml.modeling._internal.model_trainer import ModelTrainer +from snowflake.ml.modeling._internal.pandas_trainer import PandasModelTrainer +from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer +from snowflake.snowpark import DataFrame, Session + +_PROJECT = "ModelDevelopment" + + +class ModelTrainerBuilder: + """ + A builder class to create instances of ModelTrainer for different models and training conditions. + + This class provides methods to build instances of ModelTrainer tailored to specific machine learning + models and training configurations like dataset's location etc. It abstracts the creation process, + allowing the user to obtain a configured ModelTrainer for a particular model architecture or configuration. 
+ """ + + _ENABLE_DISTRIBUTED = True + + @classmethod + def _check_if_distributed_hpo_enabled(cls, session: Session) -> bool: + return not is_single_node(session) and ModelTrainerBuilder._ENABLE_DISTRIBUTED is True + + @classmethod + def build( + cls, + estimator: object, + dataset: Union[DataFrame, pd.DataFrame], + input_cols: Optional[List[str]] = None, + label_cols: Optional[List[str]] = None, + sample_weight_col: Optional[str] = None, + autogenerated: bool = False, + subproject: str = "", + ) -> ModelTrainer: + """ + Builder method that creates an approproiate ModelTrainer instance based on the given params. + """ + assert input_cols is not None # Make MyPy happpy + if isinstance(dataset, pd.DataFrame): + return PandasModelTrainer( + estimator=estimator, + dataset=dataset, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + ) + elif isinstance(dataset, DataFrame): + trainer_klass = SnowparkModelTrainer + assert dataset._session is not None # Make MyPy happpy + if isinstance(estimator, model_selection.GridSearchCV) or isinstance( + estimator, model_selection.RandomizedSearchCV + ): + if ModelTrainerBuilder._check_if_distributed_hpo_enabled(session=dataset._session): + trainer_klass = DistributedHPOTrainer + return trainer_klass( + estimator=estimator, + dataset=dataset, + session=dataset._session, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + autogenerated=autogenerated, + subproject=subproject, + ) + else: + raise TypeError( + f"Unexpected dataset type: {type(dataset)}." + "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." 
+ ) diff --git a/snowflake/ml/modeling/_internal/pandas_trainer.py b/snowflake/ml/modeling/_internal/pandas_trainer.py new file mode 100644 index 00000000..6a2d726e --- /dev/null +++ b/snowflake/ml/modeling/_internal/pandas_trainer.py @@ -0,0 +1,54 @@ +import inspect +from typing import List, Optional + +import pandas as pd + + +class PandasModelTrainer: + """ + A class for training machine learning models using Pandas datasets. + """ + + def __init__( + self, + estimator: object, + dataset: pd.DataFrame, + input_cols: List[str], + label_cols: Optional[List[str]], + sample_weight_col: Optional[str], + ) -> None: + """ + Initializes the PandasModelTrainer with a model, a Pandas DataFrame, feature, and label column names. + + Args: + estimator: SKLearn compatible estimator or transformer object. + dataset: The dataset used for training the model. + input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training. + label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn. + sample_weight_col: The column name representing the weight of training examples. + """ + self.estimator = estimator + self.dataset = dataset + self.input_cols = input_cols + self.label_cols = label_cols + self.sample_weight_col = sample_weight_col + + def train(self) -> object: + """ + Trains the model using specified features and target columns from the dataset. 
+ + Returns: + Trained model + """ + assert hasattr(self.estimator, "fit") # Keep mypy happy + argspec = inspect.getfullargspec(self.estimator.fit) + args = {"X": self.dataset[self.input_cols]} + + if self.label_cols: + label_arg_name = "Y" if "Y" in argspec.args else "y" + args[label_arg_name] = self.dataset[self.label_cols].squeeze() + + if self.sample_weight_col is not None and "sample_weight" in argspec.args: + args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze() + + return self.estimator.fit(**args) diff --git a/snowflake/ml/modeling/_internal/snowpark_handlers.py b/snowflake/ml/modeling/_internal/snowpark_handlers.py index 79cbe5b1..3adb0c54 100644 --- a/snowflake/ml/modeling/_internal/snowpark_handlers.py +++ b/snowflake/ml/modeling/_internal/snowpark_handlers.py @@ -1,51 +1,29 @@ import importlib import inspect -import io import os import posixpath -import sys -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional from uuid import uuid4 import cloudpickle as cp -import numpy as np import pandas as pd -import sklearn -from scipy.stats import rankdata -from sklearn import model_selection from snowflake.ml._internal import telemetry from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions, - modeling_error_messages, -) +from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator from snowflake.ml._internal.utils.temp_file_utils import ( cleanup_temp_files, get_temp_file_path, ) -from snowflake.snowpark import ( - DataFrame, - Session, - exceptions as snowpark_exceptions, - functions as F, -) +from snowflake.snowpark import DataFrame, Session from snowflake.snowpark._internal.utils import ( TempObjectType, random_name_for_temp_object, ) 
-from snowflake.snowpark.functions import col, pandas_udf, sproc, udtf -from snowflake.snowpark.stored_procedure import StoredProcedure -from snowflake.snowpark.types import ( - IntegerType, - PandasSeries, - StringType, - StructField, - StructType, -) +from snowflake.snowpark.functions import pandas_udf, sproc +from snowflake.snowpark.types import PandasSeries cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path)) cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name)) @@ -53,144 +31,6 @@ _PROJECT = "ModelDevelopment" -class WrapperProvider: - def __init__(self) -> None: - self.imports: List[str] = [] - self.dependencies: List[str] = [] - - def get_fit_wrapper_function( - self, - ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]: - imports = self.imports # In order for the sproc to not resolve this reference in snowflake.ml - - def fit_wrapper_function( - session: Session, - sql_queries: List[str], - stage_transform_file_name: str, - stage_result_file_name: str, - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - statement_params: Dict[str, str], - ) -> str: - import inspect - import os - - import cloudpickle as cp - import pandas as pd - - for import_name in imports: - importlib.import_module(import_name) - - # Execute snowpark queries and obtain the results as pandas dataframe - # NB: this implies that the result data must fit into memory. 
- for query in sql_queries[:-1]: - _ = session.sql(query).collect(statement_params=statement_params) - sp_df = session.sql(sql_queries[-1]) - df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params) - df.columns = sp_df.columns - - local_transform_file_name = get_temp_file_path() - - session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params) - - local_transform_file_path = os.path.join( - local_transform_file_name, os.listdir(local_transform_file_name)[0] - ) - with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: - estimator = cp.load(local_transform_file_obj) - - argspec = inspect.getfullargspec(estimator.fit) - args = {"X": df[input_cols]} - if label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" - args[label_arg_name] = df[label_cols].squeeze() - - if sample_weight_col is not None and "sample_weight" in argspec.args: - args["sample_weight"] = df[sample_weight_col].squeeze() - - estimator.fit(**args) - - local_result_file_name = get_temp_file_path() - - with open(local_result_file_name, mode="w+b") as local_result_file_obj: - cp.dump(estimator, local_result_file_obj) - - session.file.put( - local_result_file_name, - stage_result_file_name, - auto_compress=False, - overwrite=True, - statement_params=statement_params, - ) - - # Note: you can add something like + "|" + str(df) to the return string - # to pass debug information to the caller. - return str(os.path.basename(local_result_file_name)) - - return fit_wrapper_function - - -class SklearnWrapperProvider(WrapperProvider): - def __init__(self) -> None: - import sklearn - - self.imports: List[str] = ["sklearn"] - - # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda. 
- self.dependencies: List[str] = [ - f"numpy=={np.__version__}", - f"scikit-learn=={sklearn.__version__}", - f"cloudpickle=={cp.__version__}", - ] - - -class XGBoostWrapperProvider(WrapperProvider): - def __init__(self) -> None: - import xgboost - - self.imports: List[str] = ["xgboost"] - self.dependencies = [ - f"numpy=={np.__version__}", - f"xgboost=={xgboost.__version__}", - f"cloudpickle=={cp.__version__}", - ] - - -class LightGBMWrapperProvider(WrapperProvider): - def __init__(self) -> None: - import lightgbm - - self.imports: List[str] = ["lightgbm"] - self.dependencies = [ - f"numpy=={np.__version__}", - f"lightgbm=={lightgbm.__version__}", - f"cloudpickle=={cp.__version__}", - ] - - -class SklearnModelSelectionWrapperProvider(WrapperProvider): - def __init__(self) -> None: - import xgboost - - self.imports: List[str] = ["sklearn", "xgboost"] - self.dependencies = [ - f"numpy=={np.__version__}", - f"scikit-learn=={sklearn.__version__}", - f"cloudpickle=={cp.__version__}", - f"xgboost=={xgboost.__version__}", - ] - - # Only include lightgbm in the dependencies if it is installed. - try: - import lightgbm - except ModuleNotFoundError: - pass - else: - self.imports.append("lightgbm") - self.dependencies.append(f"lightgbm=={lightgbm.__version__}") - - def _get_rand_id() -> str: """ Generate random id to be used in sproc and stage names. 
@@ -202,171 +42,11 @@ def _get_rand_id() -> str: class SnowparkHandlers: - def __init__( - self, class_name: str, subproject: str, wrapper_provider: WrapperProvider, autogenerated: Optional[bool] = False - ) -> None: + def __init__(self, class_name: str, subproject: str, autogenerated: Optional[bool] = False) -> None: self._class_name = class_name self._subproject = subproject - self._wrapper_provider = wrapper_provider self._autogenerated = autogenerated - def _get_fit_wrapper_sproc( - self, dependencies: List[str], session: Session, statement_params: Dict[str, str] - ) -> StoredProcedure: - # If the sproc already exists, don't register. - if not hasattr(session, "_FIT_WRAPPER_SPROCS"): - session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] - - fit_sproc_key = self._wrapper_provider.__class__.__name__ - if fit_sproc_key in session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined] - fit_sproc: StoredProcedure = session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined] - return fit_sproc - - fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) - - fit_wrapper_sproc = session.sproc.register( - func=self._wrapper_provider.get_fit_wrapper_function(), - is_permanent=False, - name=fit_sproc_name, - packages=dependencies, # type: ignore[arg-type] - replace=True, - session=session, - statement_params=statement_params, - ) - - session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined] - - return fit_wrapper_sproc - - def fit_pandas( - self, - dataset: pd.DataFrame, - estimator: object, - input_cols: List[str], - label_cols: Optional[List[str]], - sample_weight_col: Optional[str], - ) -> object: - assert hasattr(estimator, "fit") # Keep mypy happy - argspec = inspect.getfullargspec(estimator.fit) - args = {"X": dataset[input_cols]} - - if label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" - args[label_arg_name] = dataset[label_cols].squeeze() - - if 
sample_weight_col is not None and "sample_weight" in argspec.args: - args["sample_weight"] = dataset[sample_weight_col].squeeze() - - return estimator.fit(**args) - - def fit_snowpark( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> Any: - dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(dataset) - - # If we are already in a stored procedure, no need to kick off another one. - if SNOWML_SPROC_ENV in os.environ: - statement_params = telemetry.get_function_usage_statement_params( - project=_PROJECT, - subproject=self._subproject, - function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), - api_calls=[Session.call], - custom_tags=dict([("autogen", True)]) if self._autogenerated else None, - ) - pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params) - pd_df.columns = dataset.columns - return self.fit_pandas(pd_df, estimator, input_cols, label_cols, sample_weight_col) - - # Extract query that generated the dataframe. We will need to pass it to the fit procedure. - queries = dataset.queries["queries"] - - # Create a temp file and dump the transform to that file. - local_transform_file_name = get_temp_file_path() - with open(local_transform_file_name, mode="w+b") as local_transform_file: - cp.dump(estimator, local_transform_file) - - # Create temp stage to run fit. 
- transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE) - stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};" - SqlResultValidator(session=session, query=stage_creation_query).has_dimensions( - expected_rows=1, expected_cols=1 - ).validate() - - # Use posixpath to construct stage paths - stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name)) - stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name)) - local_result_file_name = get_temp_file_path() - - statement_params = telemetry.get_function_usage_statement_params( - project=_PROJECT, - subproject=self._subproject, - function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), - api_calls=[sproc], - custom_tags=dict([("autogen", True)]) if self._autogenerated else None, - ) - # Put locally serialized transform on stage. - session.file.put( - local_transform_file_name, - stage_transform_file_name, - auto_compress=False, - overwrite=True, - statement_params=statement_params, - ) - - # Call fit sproc - statement_params = telemetry.get_function_usage_statement_params( - project=_PROJECT, - subproject=self._subproject, - function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), - api_calls=[Session.call], - custom_tags=dict([("autogen", True)]) if self._autogenerated else None, - ) - - fit_wrapper_sproc = self._get_fit_wrapper_sproc(dependencies, session, statement_params) - - try: - sproc_export_file_name: str = fit_wrapper_sproc( - session, - queries, - stage_transform_file_name, - stage_result_file_name, - input_cols, - label_cols, - sample_weight_col, - statement_params, - ) - except snowpark_exceptions.SnowparkClientException as e: - if "fit() missing 1 required positional argument: 'y'" in str(e): - raise exceptions.SnowflakeMLException( - error_code=error_codes.NOT_FOUND, - 
original_exception=RuntimeError(modeling_error_messages.ATTRIBUTE_NOT_SET.format("label_cols")), - ) from e - raise e - - if "|" in sproc_export_file_name: - fields = sproc_export_file_name.strip().split("|") - sproc_export_file_name = fields[0] - - session.file.get( - posixpath.join(stage_result_file_name, sproc_export_file_name), - local_result_file_name, - statement_params=statement_params, - ) - - with open(os.path.join(local_result_file_name, sproc_export_file_name), mode="r+b") as result_file_obj: - fit_estimator = cp.load(result_file_obj) - - cleanup_temp_files([local_transform_file_name, local_result_file_name]) - - return fit_estimator - def batch_inference( self, dataset: DataFrame, @@ -690,437 +370,3 @@ def score_wrapper_sproc( cleanup_temp_files([local_score_file_name]) return score - - def fit_search_snowpark( - self, - param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler], - dataset: DataFrame, - session: Session, - estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV], - dependencies: List[str], - udf_imports: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]: - from itertools import product - - import cachetools - from sklearn.base import clone, is_classifier - from sklearn.calibration import check_cv - - # Create one stage for data and for estimators. - temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE) - temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};" - session.sql(temp_stage_creation_query).collect() - - # Stage data. 
- dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset) - remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet" - dataset.write.copy_into_location( # type:ignore[call-overload] - remote_file_path, file_format_type="parquet", header=True, overwrite=True - ) - imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()] - - # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again - original_refit = estimator.refit - - # Create a temp file and dump the estimator to that file. - estimator_file_name = get_temp_file_path() - params_to_evaluate = [] - for param_to_eval in list(param_grid): - for k, v in param_to_eval.items(): - param_to_eval[k] = [v] - params_to_evaluate.append([param_to_eval]) - - with open(estimator_file_name, mode="w+b") as local_estimator_file_obj: - # Set GridSearchCV refit as False and fit it again after retrieving the best param - estimator.refit = False - cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj) - stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name)) - sproc_statement_params = telemetry.get_function_usage_statement_params( - project=_PROJECT, - subproject=self._subproject, - function_name=telemetry.get_statement_params_full_func_name( - inspect.currentframe(), self.__class__.__name__ - ), - api_calls=[sproc], - custom_tags=dict([("autogen", True)]) if self._autogenerated else None, - ) - udtf_statement_params = telemetry.get_function_usage_statement_params( - project=_PROJECT, - subproject=self._subproject, - function_name=telemetry.get_statement_params_full_func_name( - inspect.currentframe(), self.__class__.__name__ - ), - api_calls=[udtf], - custom_tags=dict([("autogen", True)]) if self._autogenerated else None, - ) - - # Put locally serialized estimator on stage. 
- put_result = session.file.put( - estimator_file_name, - temp_stage_name, - auto_compress=False, - overwrite=True, - ) - estimator_location = put_result[0].target - imports.append(f"@{temp_stage_name}/{estimator_location}") - - search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) - random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION) - - required_deps = dependencies + [ - "snowflake-snowpark-python<2", - "fastparquet<2023.11", - "pyarrow<14", - "cachetools<5", - ] - - @sproc( # type: ignore[misc] - is_permanent=False, - name=search_sproc_name, - packages=required_deps, # type: ignore[arg-type] - replace=True, - session=session, - anonymous=True, - imports=imports, # type: ignore[arg-type] - statement_params=sproc_statement_params, - ) - def _distributed_search( - session: Session, - imports: List[str], - stage_estimator_file_name: str, - input_cols: List[str], - label_cols: List[str], - ) -> str: - import os - import time - from typing import Iterator - - import cloudpickle as cp - import pandas as pd - import pyarrow.parquet as pq - from sklearn.metrics import check_scoring - from sklearn.metrics._scorer import _check_multimetric_scoring - - for import_name in udf_imports: - importlib.import_module(import_name) - - data_files = [ - filename - for filename in os.listdir(sys._xoptions["snowflake_import_directory"]) - if filename.startswith(temp_stage_name) - ] - partial_df = [ - pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas() - for file_name in data_files - ] - df = pd.concat(partial_df, ignore_index=True) - df.columns = [identifier.get_inferred_name(col) for col in df.columns] - - X = df[input_cols] - y = df[label_cols].squeeze() - - local_estimator_file_name = get_temp_file_path() - session.file.get(stage_estimator_file_name, local_estimator_file_name) - - local_estimator_file_path = os.path.join( - local_estimator_file_name, os.listdir(local_estimator_file_name)[0] - ) - with 
open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj: - estimator = cp.load(local_estimator_file_obj)["estimator"] - - cv_orig = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator)) - indices = [test for _, test in cv_orig.split(X, y)] - local_indices_file_name = get_temp_file_path() - with open(local_indices_file_name, mode="w+b") as local_indices_file_obj: - cp.dump(indices, local_indices_file_obj) - - # Put locally serialized indices on stage. - put_result = session.file.put( - local_indices_file_name, - temp_stage_name, - auto_compress=False, - overwrite=True, - ) - indices_location = put_result[0].target - imports.append(f"@{temp_stage_name}/{indices_location}") - indices_len = len(indices) - - assert estimator is not None - - @cachetools.cached(cache={}) - def _load_data_into_udf() -> Tuple[ - Dict[str, pd.DataFrame], - Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV], - pd.DataFrame, - int, - List[Dict[str, Any]], - ]: - import pyarrow.parquet as pq - - data_files = [ - filename - for filename in os.listdir(sys._xoptions["snowflake_import_directory"]) - if filename.startswith(temp_stage_name) - ] - partial_df = [ - pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas() - for file_name in data_files - ] - df = pd.concat(partial_df, ignore_index=True) - df.columns = [identifier.get_inferred_name(col) for col in df.columns] - - # load estimator - local_estimator_file_path = os.path.join( - sys._xoptions["snowflake_import_directory"], f"{estimator_location}" - ) - with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj: - estimator_objects = cp.load(local_estimator_file_obj) - estimator = estimator_objects["estimator"] - params_to_evaluate = estimator_objects["param_grid"] - - # load indices - local_indices_file_path = os.path.join( - sys._xoptions["snowflake_import_directory"], f"{indices_location}" - ) - with open(local_indices_file_path, 
mode="rb") as local_indices_file_obj: - indices = cp.load(local_indices_file_obj) - - argspec = inspect.getfullargspec(estimator.fit) - args = {"X": df[input_cols]} - - if label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" - args[label_arg_name] = df[label_cols].squeeze() - - if sample_weight_col is not None and "sample_weight" in argspec.args: - args["sample_weight"] = df[sample_weight_col].squeeze() - return args, estimator, indices, len(df), params_to_evaluate - - class SearchCV: - def __init__(self) -> None: - args, estimator, indices, data_length, params_to_evaluate = _load_data_into_udf() - self.args = args - self.estimator = estimator - self.indices = indices - self.data_length = data_length - self.params_to_evaluate = params_to_evaluate - - def process(self, params_idx: int, idx: int) -> Iterator[Tuple[str]]: - if hasattr(estimator, "param_grid"): - self.estimator.param_grid = self.params_to_evaluate[params_idx] - else: - self.estimator.param_distributions = self.params_to_evaluate[params_idx] - full_indices = np.array([i for i in range(self.data_length)]) - test_indice = self.indices[idx] - train_indice = np.setdiff1d(full_indices, test_indice) - self.estimator.cv = [(train_indice, test_indice)] - self.estimator.fit(**self.args) - binary_cv_results = None - with io.BytesIO() as f: - cp.dump(self.estimator.cv_results_, f) - f.seek(0) - binary_cv_results = f.getvalue().hex() - yield (binary_cv_results,) - - def end_partition(self) -> None: - ... 
- - session.udtf.register( - SearchCV, - output_schema=StructType([StructField("CV_RESULTS", StringType())]), - input_types=[IntegerType(), IntegerType()], - name=random_udtf_name, - packages=required_deps, # type: ignore[arg-type] - replace=True, - is_permanent=False, - imports=imports, # type: ignore[arg-type] - statement_params=udtf_statement_params, - ) - - HP_TUNING = F.table_function(random_udtf_name) - - idx_length = int(indices_len) - params_length = len(param_grid) - idxs = [i for i in range(idx_length)] - param_indices, training_indices = [], [] - for param_idx, cv_idx in product([param_index for param_index in range(params_length)], idxs): - param_indices.append(param_idx) - training_indices.append(cv_idx) - - pd_df = pd.DataFrame( - { - "PARAMS": param_indices, - "TRAIN_IND": training_indices, - "PARAM_INDEX": [i for i in range(idx_length * params_length)], - } - ) - df = session.create_dataframe(pd_df) - results = df.select( - F.cast(df["PARAM_INDEX"], IntegerType()).as_("PARAM_INDEX"), - (HP_TUNING(df["PARAMS"], df["TRAIN_IND"]).over(partition_by=df["PARAM_INDEX"])), - ) - - # cv_result maintains the original order - multimetric = False - cv_results_ = dict() - scorers = set() - for i, val in enumerate(results.select("CV_RESULTS").sort(col("PARAM_INDEX")).collect()): - # retrieved string had one more double quote in the front and end of the string. - # use [1:-1] to remove the extra double quotes - hex_str = bytes.fromhex(val[0]) - with io.BytesIO(hex_str) as f_reload: - each_cv_result = cp.load(f_reload) - for k, v in each_cv_result.items(): - cur_cv = i % idx_length - key = k - if "split0_test_" in k: - # For multi-metric evaluation, the scores for all the scorers are available in the - # cv_results_ dict at the keys ending with that scorer’s name ('_') - # instead of '_score'. 
- scorers.add(k[len("split0_test_") :]) - key = k.replace("split0_test", f"split{cur_cv}_test") - elif k.startswith("param"): - if cur_cv != 0: - key = False - if key: - if key not in cv_results_: - cv_results_[key] = v - else: - cv_results_[key] = np.concatenate([cv_results_[key], v]) - - multimetric = len(scorers) > 1 - # Use numpy to re-calculate all the information in cv_results_ again - # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape, - # and average them by the idx_length; - # idx_length is the number of cv folds; params_length is the number of parameter combinations - scores = [ - np.reshape( - np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]), - (idx_length, -1), - ) - for score in scorers - ] - - fit_score_test_matrix = np.stack( - [ - np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)), - np.reshape(cv_results_["mean_score_time"], (idx_length, -1)), - ] - + scores - ) - - mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1) - std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1) - cv_results_["std_fit_time"] = std_fit_score_test_matrix[0] - cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0] - cv_results_["std_score_time"] = std_fit_score_test_matrix[1] - cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1] - for idx, score in enumerate(scorers): - cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2] - cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2] - # re-compute the ranking again with mean_test_. - cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min") - # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared. - # If all scores are `nan`, `rankdata` will also produce an array of `nan` values. - # In that case, default to first index. 
- best_param_index = ( - np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0] - if not np.isnan(cv_results_[f"rank_test_{score}"]).all() - else 0 - ) - - estimator.cv_results_ = cv_results_ - estimator.multimetric_ = multimetric - - # Reconstruct the sklearn estimator. - refit_metric = "score" - if callable(estimator.scoring): - scorers = estimator.scoring - elif estimator.scoring is None or isinstance(estimator.scoring, str): - scorers = check_scoring(estimator.estimator, estimator.scoring) - else: - scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring) - estimator._check_refit_for_multimetric(scorers) - refit_metric = original_refit - - estimator.scorer_ = scorers - - # check refit_metric now for a callabe scorer that is multimetric - if callable(estimator.scoring) and estimator.multimetric_: - refit_metric = original_refit - - # For multi-metric evaluation, store the best_index_, best_params_ and - # best_score_ iff refit is one of the scorer names - # In single metric evaluation, refit_metric is "score" - if original_refit or not estimator.multimetric_: - estimator.best_index_ = estimator._select_best_index(original_refit, refit_metric, cv_results_) - if not callable(original_refit): - # With a non-custom callable, we can select the best score - # based on the best index - estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_] - estimator.best_params_ = cv_results_["params"][best_param_index] - - if original_refit: - estimator.best_estimator_ = clone(estimator.estimator).set_params( - **clone(estimator.best_params_, safe=False) - ) - - # Let the sproc use all cores to refit. 
- estimator.n_jobs = -1 if not estimator.n_jobs else estimator.n_jobs - - # process the input as args - argspec = inspect.getfullargspec(estimator.fit) - args = {"X": X} - if label_cols: - label_arg_name = "Y" if "Y" in argspec.args else "y" - args[label_arg_name] = y - if sample_weight_col is not None and "sample_weight" in argspec.args: - args["sample_weight"] = df[sample_weight_col].squeeze() - estimator.refit = original_refit - refit_start_time = time.time() - estimator.best_estimator_.fit(**args) - refit_end_time = time.time() - estimator.refit_time_ = refit_end_time - refit_start_time - - if hasattr(estimator.best_estimator_, "feature_names_in_"): - estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_ - - local_result_file_name = get_temp_file_path() - - with open(local_result_file_name, mode="w+b") as local_result_file_obj: - cp.dump(estimator, local_result_file_obj) - - session.file.put( - local_result_file_name, - temp_stage_name, - auto_compress=False, - overwrite=True, - ) - - # Note: you can add something like + "|" + str(df) to the return string - # to pass debug information to the caller. 
- return str(os.path.basename(local_result_file_name)) - - sproc_export_file_name = _distributed_search( - session, - imports, - stage_estimator_file_name, - input_cols, - label_cols, - ) - - local_estimator_path = get_temp_file_path() - session.file.get( - posixpath.join(temp_stage_name, sproc_export_file_name), - local_estimator_path, - ) - - with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj: - fit_estimator = cp.load(result_file_obj) - - cleanup_temp_files([local_estimator_path]) - - return fit_estimator diff --git a/snowflake/ml/modeling/_internal/snowpark_trainer.py b/snowflake/ml/modeling/_internal/snowpark_trainer.py new file mode 100644 index 00000000..3d7aaf39 --- /dev/null +++ b/snowflake/ml/modeling/_internal/snowpark_trainer.py @@ -0,0 +1,331 @@ +import importlib +import inspect +import os +import posixpath +from typing import Any, Callable, Dict, List, Optional, Tuple + +import cloudpickle as cp + +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.exceptions import ( + error_codes, + exceptions, + modeling_error_messages, +) +from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils +from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator +from snowflake.ml._internal.utils.temp_file_utils import ( + cleanup_temp_files, + get_temp_file_path, +) +from snowflake.ml.modeling._internal.model_specifications import ( + ModelSpecifications, + ModelSpecificationsBuilder, +) +from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) +from snowflake.snowpark.functions import sproc +from snowflake.snowpark.stored_procedure import StoredProcedure + +cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path)) +cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name)) + +_PROJECT = 
"ModelDevelopment" + + +class SnowparkModelTrainer: + """ + A class for training models on Snowflake data using the Sproc. + + TODO (snandamuri): Introduce the concept of executor that would take the training function + and execute it on the target environments like, local, Snowflake warehouse, or SPCS, etc. + """ + + def __init__( + self, + estimator: object, + dataset: DataFrame, + session: Session, + input_cols: List[str], + label_cols: Optional[List[str]], + sample_weight_col: Optional[str], + autogenerated: bool = False, + subproject: str = "", + ) -> None: + """ + Initializes the SnowparkModelTrainer with a model, a Snowpark DataFrame, feature, and label column names. + + Args: + estimator: SKLearn compatible estimator or transformer object. + dataset: The dataset used for training the model. + session: Snowflake session object to be used for training. + input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training. + label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn. + sample_weight_col: The column name representing the weight of training examples. + autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not. + subproject: subproject name to be used in telemetry. + """ + self.estimator = estimator + self.dataset = dataset + self.session = session + self.input_cols = input_cols + self.label_cols = label_cols + self.sample_weight_col = sample_weight_col + self._autogenerated = autogenerated + self._subproject = subproject + self._class_name = estimator.__class__.__name__ + + def _create_temp_stage(self) -> str: + """ + Creates temporary stage. + + Returns: + Temp stage name. + """ + # Create temp stage to upload pickled model file. 
+ transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE) + stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};" + SqlResultValidator(session=self.session, query=stage_creation_query).has_dimensions( + expected_rows=1, expected_cols=1 + ).validate() + return transform_stage_name + + def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]: + """ + Util method to pickle and upload the model to a temp Snowflake stage. + + Args: + stage_name: Stage name to save model. + + Returns: + a tuple containing stage file paths for pickled input model for training and location to store trained + models(response from training sproc). + """ + # Create a temp file and dump the transform to that file. + local_transform_file_name = get_temp_file_path() + with open(local_transform_file_name, mode="w+b") as local_transform_file: + cp.dump(self.estimator, local_transform_file) + + # Use posixpath to construct stage paths + stage_transform_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name)) + stage_result_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name)) + + statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=self._subproject, + function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), + api_calls=[sproc], + custom_tags=dict([("autogen", True)]) if self._autogenerated else None, + ) + # Put locally serialized transform on stage. 
+ self.session.file.put( + local_transform_file_name, + stage_transform_file_name, + auto_compress=False, + overwrite=True, + statement_params=statement_params, + ) + + cleanup_temp_files([local_transform_file_name]) + return (stage_transform_file_name, stage_result_file_name) + + def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object: + """ + Downloads the serialized model from a stage location and unpickles it. + + Args: + dir_path: Stage directory path where results are stored. + file_name: File name within the directory where results are stored. + statement_params: Statement params to be attached to the SQL queries issued from this method. + + Returns: + Deserialized model object. + """ + local_result_file_name = get_temp_file_path() + self.session.file.get( + posixpath.join(dir_path, file_name), + local_result_file_name, + statement_params=statement_params, + ) + + with open(os.path.join(local_result_file_name, file_name), mode="r+b") as result_file_obj: + fit_estimator = cp.load(result_file_obj) + + cleanup_temp_files([local_result_file_name]) + return fit_estimator + + def _build_fit_wrapper_sproc( + self, + model_spec: ModelSpecifications, + ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]: + """ + Constructs and returns a python stored procedure function to be used for training model. + + Args: + model_spec: ModelSpecifications object that contains model specific information + like required imports, package dependencies, etc. + + Returns: + A callable that can be registered as a stored procedure. 
+ """ + imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml + + def fit_wrapper_function( + session: Session, + sql_queries: List[str], + stage_transform_file_name: str, + stage_result_file_name: str, + input_cols: List[str], + label_cols: List[str], + sample_weight_col: Optional[str], + statement_params: Dict[str, str], + ) -> str: + import inspect + import os + + import cloudpickle as cp + import pandas as pd + + for import_name in imports: + importlib.import_module(import_name) + + # Execute snowpark queries and obtain the results as pandas dataframe + # NB: this implies that the result data must fit into memory. + for query in sql_queries[:-1]: + _ = session.sql(query).collect(statement_params=statement_params) + sp_df = session.sql(sql_queries[-1]) + df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params) + df.columns = sp_df.columns + + local_transform_file_name = get_temp_file_path() + + session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params) + + local_transform_file_path = os.path.join( + local_transform_file_name, os.listdir(local_transform_file_name)[0] + ) + with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: + estimator = cp.load(local_transform_file_obj) + + argspec = inspect.getfullargspec(estimator.fit) + args = {"X": df[input_cols]} + if label_cols: + label_arg_name = "Y" if "Y" in argspec.args else "y" + args[label_arg_name] = df[label_cols].squeeze() + + if sample_weight_col is not None and "sample_weight" in argspec.args: + args["sample_weight"] = df[sample_weight_col].squeeze() + + estimator.fit(**args) + + local_result_file_name = get_temp_file_path() + + with open(local_result_file_name, mode="w+b") as local_result_file_obj: + cp.dump(estimator, local_result_file_obj) + + session.file.put( + local_result_file_name, + stage_result_file_name, + auto_compress=False, + overwrite=True, + 
statement_params=statement_params, + ) + + # Note: you can add something like + "|" + str(df) to the return string + # to pass debug information to the caller. + return str(os.path.basename(local_result_file_name)) + + return fit_wrapper_function + + def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure: + # If the sproc already exists, don't register. + if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"): + self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] + + model_spec = ModelSpecificationsBuilder.build(model=self.estimator) + fit_sproc_key = model_spec.__class__.__name__ + if fit_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined] + fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined] + return fit_sproc + + fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) + + fit_wrapper_sproc = self.session.sproc.register( + func=self._build_fit_wrapper_sproc(model_spec=model_spec), + is_permanent=False, + name=fit_sproc_name, + packages=["snowflake-snowpark-python"] + model_spec.pkgDependencies, # type: ignore[arg-type] + replace=True, + session=self.session, + statement_params=statement_params, + ) + + self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined] + + return fit_wrapper_sproc + + def train(self) -> object: + """ + Trains the model by pushing down the compute into Snowflake using stored procedures. + + Returns: + Trained model + + Raises: + e: Raises an exception if any of the Snowflake operations fails for any reason. + SnowflakeMLException: Known exceptions are caught and rethrown with a more detailed error message. + """ + dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset) + + # TODO(snandamuri) : Handle the already in a stored procedure case in the in builder. + + # Extract query that generated the dataframe. 
We will need to pass it to the fit procedure. + queries = dataset.queries["queries"] + + transform_stage_name = self._create_temp_stage() + (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage( + stage_name=transform_stage_name + ) + + # Call fit sproc + statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=self._subproject, + function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), + api_calls=[Session.call], + custom_tags=dict([("autogen", True)]) if self._autogenerated else None, + ) + + fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params) + + try: + sproc_export_file_name: str = fit_wrapper_sproc( + self.session, + queries, + stage_transform_file_name, + stage_result_file_name, + self.input_cols, + self.label_cols, + self.sample_weight_col, + statement_params, + ) + except snowpark_exceptions.SnowparkClientException as e: + if "fit() missing 1 required positional argument: 'y'" in str(e): + raise exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=RuntimeError(modeling_error_messages.ATTRIBUTE_NOT_SET.format("label_cols")), + ) from e + raise e + + if "|" in sproc_export_file_name: + fields = sproc_export_file_name.strip().split("|") + sproc_export_file_name = fields[0] + + return self._fetch_model_from_stage( + dir_path=stage_result_file_name, + file_name=sproc_export_file_name, + statement_params=statement_params, + ) diff --git a/snowflake/ml/modeling/framework/base.py b/snowflake/ml/modeling/framework/base.py index 4c03ef99..186ac831 100644 --- a/snowflake/ml/modeling/framework/base.py +++ b/snowflake/ml/modeling/framework/base.py @@ -133,7 +133,7 @@ def set_label_cols(self, label_cols: Optional[Union[str, Iterable[str]]]) -> "Ba def get_passthrough_cols(self) -> List[str]: """ - Getter method for passthrough_cols attribute. + Passthrough columns getter. 
Returns: Passthrough column(s). @@ -142,7 +142,7 @@ def get_passthrough_cols(self) -> List[str]: def set_passthrough_cols(self, passthrough_cols: Optional[Union[str, Iterable[str]]]) -> "Base": """ - Setter method passthrough_cols attribute. + Passthrough columns setter. Args: passthrough_cols: Column(s) that should not be used or modified by the estimator/transformer. diff --git a/snowflake/ml/modeling/impute/simple_imputer.py b/snowflake/ml/modeling/impute/simple_imputer.py index 701af133..d9cfc818 100644 --- a/snowflake/ml/modeling/impute/simple_imputer.py +++ b/snowflake/ml/modeling/impute/simple_imputer.py @@ -278,6 +278,7 @@ def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer": state = STRATEGY_TO_STATE_DICT[self.strategy] assert state is not None dataset_copy = copy.copy(dataset) + dataset_copy = dataset_copy.select(self.input_cols) if not pd.isna(self.missing_values): # Replace `self.missing_values` with null to avoid including it when computing states. dataset_copy = dataset_copy.na.replace(self.missing_values, None) @@ -308,7 +309,6 @@ def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer": return self @telemetry.send_api_usage_telemetry(project=base.PROJECT, subproject=_SUBPROJECT) - @telemetry.add_stmt_params_to_df(project=base.PROJECT, subproject=_SUBPROJECT) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Transform the input dataset by imputing the computed statistics in the input columns. 
diff --git a/snowflake/ml/modeling/metrics/classification.py b/snowflake/ml/modeling/metrics/classification.py index 17a1c0d0..23bf0960 100644 --- a/snowflake/ml/modeling/metrics/classification.py +++ b/snowflake/ml/modeling/metrics/classification.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import cloudpickle -import numpy import numpy as np import numpy.typing as npt from sklearn import exceptions, metrics @@ -43,12 +42,17 @@ def accuracy_score( corresponding set of labels in the y true columns. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - normalize: If ``False``, return the number of correctly classified samples. + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + normalize: boolean, default=True + If ``False``, return the number of correctly classified samples. Otherwise, return the fraction of correctly classified samples. - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. Returns: If ``normalize == True``, return the fraction of correctly @@ -102,14 +106,19 @@ def confusion_matrix( :math:`C_{1,1}` and false positives is :math:`C_{0,1}`. Args: - df: Input dataframe. - y_true_col_name: Column name representing actual values. - y_pred_col_name: Column name representing predicted values. - labels: List of labels to index the matrix. This may be used to + df: snowpark.DataFrame + Input dataframe. + y_true_col_name: string or list of strings + Column name representing actual values. + y_pred_col_name: string or list of strings + Column name representing predicted values. 
+ labels: list of labels, default=None + List of labels to index the matrix. This may be used to reorder or select a subset of labels. If ``None`` is given, those that appear at least once in the y true or y pred column are used in sorted order. - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. normalize: {'true', 'pred', 'all'}, default=None Normalizes confusion matrix over the true (rows), predicted (columns) conditions or all the population. If None, confusion matrix will not be @@ -124,7 +133,9 @@ def confusion_matrix( Raises: ValueError: The given ``labels`` is empty. + ValueError: No label specified in the given ``labels`` is in the y true column. + ValueError: ``normalize`` is not one of {'true', 'pred', 'all', None}. """ assert df._session is not None @@ -323,17 +334,22 @@ def f1_score( parameter. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - labels: The set of labels to include when ``average != 'binary'``, and + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + labels: list of labels, default=None + The set of labels to include when ``average != 'binary'``, and their order if ``average is None``. Labels present in the data can be excluded, for example to calculate a multiclass average ignoring a majority negative class, while labels not present in the data will result in 0 components in a macro average. For multilabel targets, labels are column indices. By default, all labels in the y true and y pred columns are used in sorted order. 
- pos_label: The class to report if ``average='binary'`` and the data is + pos_label: string or integer, default=1 + The class to report if ``average='binary'`` and the data is binary. If the data are multiclass or multilabel, this will be ignored; setting ``labels=[pos_label]`` and ``average != 'binary'`` will report scores for that label only. @@ -359,7 +375,8 @@ def f1_score( Calculate metrics for each instance, and find their average (only meaningful for multilabel classification where this differs from func`accuracy_score`). - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. zero_division: "warn", 0 or 1, default="warn" Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. If set to "warn", this acts as 0, @@ -408,18 +425,24 @@ def fbeta_score( only recall). Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - beta: Determines the weight of recall in the combined score. - labels: The set of labels to include when ``average != 'binary'``, and + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + beta: float + Determines the weight of recall in the combined score. + labels: list of labels, default=None + The set of labels to include when ``average != 'binary'``, and their order if ``average is None``. Labels present in the data can be excluded, for example to calculate a multiclass average ignoring a majority negative class, while labels not present in the data will result in 0 components in a macro average. For multilabel targets, labels are column indices. 
By default, all labels in the y true and y pred columns are used in sorted order. - pos_label: The class to report if ``average='binary'`` and the data is + pos_label: string or integer, default=1 + The class to report if ``average='binary'`` and the data is binary. If the data are multiclass or multilabel, this will be ignored; setting ``labels=[pos_label]`` and ``average != 'binary'`` will report scores for that label only. @@ -445,7 +468,8 @@ def fbeta_score( Calculate metrics for each instance, and find their average (only meaningful for multilabel classification where this differs from func`accuracy_score`). - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. zero_division: "warn", 0 or 1, default="warn" Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. If set to "warn", this acts as 0, @@ -498,9 +522,12 @@ def log_loss( L_{\log}(y, p) = -(y \log (p) + (1 - y) \log (1 - p)) Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted probabilities, + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted probabilities, as returned by a classifier's predict_proba method. If ``y_pred.shape = (n_samples,)`` the probabilities provided are assumed to be that of the positive class. The labels in ``y_pred`` @@ -509,10 +536,13 @@ def log_loss( Log loss is undefined for p=0 or p=1, so probabilities are clipped to `max(eps, min(1 - eps, p))`. The default will depend on the data type of `y_pred` and is set to `np.finfo(y_pred.dtype).eps`. - normalize: If true, return the mean loss per sample. 
+ normalize: boolean, default=True + If true, return the mean loss per sample. Otherwise, return the sum of the per-sample losses. - sample_weight_col_name: Column name representing sample weights. - labels: If not provided, labels will be inferred from y_true. If ``labels`` + sample_weight_col_name: string, default=None + Column name representing sample weights. + labels: list of labels, default=None + If not provided, labels will be inferred from y_true. If ``labels`` is ``None`` and ``y_pred`` has shape (n_samples,) the labels are assumed to be binary and are inferred from ``y_true``. @@ -697,18 +727,24 @@ def precision_recall_fscore_support( is one of ``'micro'``, ``'macro'``, ``'weighted'`` or ``'samples'``. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - beta: The strength of recall versus precision in the F-score. - labels: The set of labels to include when ``average != 'binary'``, and + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + beta: float, default=1.0 + The strength of recall versus precision in the F-score. + labels: list of labels, default=None + The set of labels to include when ``average != 'binary'``, and their order if ``average is None``. Labels present in the data can be excluded, for example to calculate a multiclass average ignoring a majority negative class, while labels not present in the data will result in 0 components in a macro average. For multilabel targets, labels are column indices. By default, all labels in the y true and y pred columns are used in sorted order. 
- pos_label: The class to report if ``average='binary'`` and the data is + pos_label: string or integer, default=1 + The class to report if ``average='binary'`` and the data is binary. If the data are multiclass or multilabel, this will be ignored; setting ``labels=[pos_label]`` and ``average != 'binary'`` will report scores for that label only. @@ -733,9 +769,11 @@ def precision_recall_fscore_support( Calculate metrics for each instance, and find their average (only meaningful for multilabel classification where this differs from :func:`accuracy_score`). - warn_for: This determines which warnings will be made in the case that this + warn_for: tuple or set containing "precision", "recall", or "f-score" + This determines which warnings will be made in the case that this function is being used to return only one of its metrics. - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. zero_division: "warn", 0 or 1, default="warn" Sets the value to return when there is a zero division: * recall - when there are no positive labels @@ -980,6 +1018,78 @@ def end_partition( return multilabel_confusion_matrix_computer +def _binary_precision_score( + *, + df: snowpark.DataFrame, + y_true_col_names: Union[str, List[str]], + y_pred_col_names: Union[str, List[str]], + pos_label: Union[str, int] = 1, + sample_weight_col_name: Optional[str] = None, + zero_division: Union[str, int] = "warn", +) -> Union[float, npt.NDArray[np.float_]]: + + statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT) + + if isinstance(y_true_col_names, str): + y_true_col_names = [y_true_col_names] + if isinstance(y_pred_col_names, str): + y_pred_col_names = [y_pred_col_names] + + if len(y_pred_col_names) != len(y_true_col_names): + raise ValueError( + "precision_score: `y_true_col_names` and `y_pred_column_names` must be lists of the same length " + "or both strings." 
+ ) + + # Confirm that the data is binary. + labels_set = set() + columns = y_true_col_names + y_pred_col_names + column_labels_lists = df.select(*[F.array_unique_agg(col) for col in columns]).collect( + statement_params=statement_params + )[0] + for column_labels_list in column_labels_lists: + for column_label in json.loads(column_labels_list): + labels_set.add(column_label) + labels = sorted(list(labels_set)) + _ = _check_binary_labels(labels, pos_label=pos_label) + + sample_weight_column = df[sample_weight_col_name] if sample_weight_col_name else None + + scores = [] + for y_true, y_pred in zip(y_true_col_names, y_pred_col_names): + tp_col = F.iff((F.col(y_true) == pos_label) & (F.col(y_pred) == pos_label), 1, 0) + fp_col = F.iff((F.col(y_true) != pos_label) & (F.col(y_pred) == pos_label), 1, 0) + tp = metrics_utils.weighted_sum( + df=df, + sample_score_column=tp_col, + sample_weight_column=sample_weight_column, + statement_params=statement_params, + ) + fp = metrics_utils.weighted_sum( + df=df, + sample_score_column=fp_col, + sample_weight_column=sample_weight_column, + statement_params=statement_params, + ) + + try: + score = tp / (tp + fp) + except ZeroDivisionError: + if zero_division == "warn": + msg = "precision_score: division by zero: score value will be 0." + warnings.warn(msg, exceptions.UndefinedMetricWarning, stacklevel=2) + score = 0.0 + else: + score = float(zero_division) + + scores.append(score) + + if len(scores) == 1: + return scores[0] + + return np.array(scores) + + @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT) def precision_score( *, @@ -1003,17 +1113,22 @@ def precision_score( The best value is 1 and the worst value is 0. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - labels: The set of labels to include when ``average != 'binary'``, and + df: snowpark.DataFrame + Input dataframe. 
+ y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + labels: list of labels, default=None + The set of labels to include when ``average != 'binary'``, and their order if ``average is None``. Labels present in the data can be excluded, for example to calculate a multiclass average ignoring a majority negative class, while labels not present in the data will result in 0 components in a macro average. For multilabel targets, labels are column indices. By default, all labels in the y true and y pred columns are used in sorted order. - pos_label: The class to report if ``average='binary'`` and the data is + pos_label: string or integer, default=1 + The class to report if ``average='binary'`` and the data is binary. If the data are multiclass or multilabel, this will be ignored; setting ``labels=[pos_label]`` and ``average != 'binary'`` will report scores for that label only. @@ -1038,7 +1153,8 @@ def precision_score( Calculate metrics for each instance, and find their average (only meaningful for multilabel classification where this differs from func`accuracy_score`). - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. zero_division: "warn", 0 or 1, default="warn" Sets the value to return when there is a zero division. If set to "warn", this acts as 0, but warnings are also raised. @@ -1048,6 +1164,16 @@ def precision_score( Precision of the positive class in binary classification or weighted average of the precision of each class for the multiclass task. 
""" + if average == "binary": + return _binary_precision_score( + df=df, + y_true_col_names=y_true_col_names, + y_pred_col_names=y_pred_col_names, + pos_label=pos_label, + sample_weight_col_name=sample_weight_col_name, + zero_division=zero_division, + ) + p, _, _, _ = precision_recall_fscore_support( df=df, y_true_col_names=y_true_col_names, @@ -1084,17 +1210,22 @@ def recall_score( The best value is 1 and the worst value is 0. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - labels: The set of labels to include when ``average != 'binary'``, and + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + labels: list of labels, default=None + The set of labels to include when ``average != 'binary'``, and their order if ``average is None``. Labels present in the data can be excluded, for example to calculate a multiclass average ignoring a majority negative class, while labels not present in the data will result in 0 components in a macro average. For multilabel targets, labels are column indices. By default, all labels in the y true and y pred columns are used in sorted order. - pos_label: The class to report if ``average='binary'`` and the data is + pos_label: string or integer, default=1 + The class to report if ``average='binary'`` and the data is binary. If the data are multiclass or multilabel, this will be ignored; setting ``labels=[pos_label]`` and ``average != 'binary'`` will report scores for that label only. @@ -1121,7 +1252,8 @@ def recall_score( Calculate metrics for each instance, and find their average (only meaningful for multilabel classification where this differs from func`accuracy_score`). - sample_weight_col_name: Column name representing sample weights. 
+ sample_weight_col_name: string, default=None + Column name representing sample weights. zero_division: "warn", 0 or 1, default="warn" Sets the value to return when there is a zero division. If set to "warn", this acts as 0, but warnings are also raised. @@ -1190,10 +1322,13 @@ def _check_binary_labels( """ if len(labels) <= 2: if len(labels) == 2 and pos_label not in labels: - raise ValueError(f"pos_label={pos_label} is not a valid label. It should be one of {labels}") + raise ValueError(f"pos_label={pos_label} is not a valid label. It must be one of {labels}") labels = [pos_label] else: - raise ValueError("Please choose another average setting.") + raise ValueError( + "Cannot compute precision score with binary average: there are more than two labels present." + "Please choose another average setting." + ) return labels diff --git a/snowflake/ml/modeling/metrics/correlation.py b/snowflake/ml/modeling/metrics/correlation.py index 70a32290..e0daa611 100644 --- a/snowflake/ml/modeling/metrics/correlation.py +++ b/snowflake/ml/modeling/metrics/correlation.py @@ -36,8 +36,10 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] = as a post-processing step. Args: - df (snowpark.DataFrame): Snowpark Dataframe for which correlation matrix has to be computed. - columns (Optional[Collection[str]]): List of column names for which the correlation matrix has to be computed. + df: snowpark.DataFrame + Snowpark Dataframe for which correlation matrix has to be computed. + columns: List of strings + List of column names for which the correlation matrix has to be computed. If None, correlation matrix is computed for all numeric columns in the snowpark dataframe. 
Returns: diff --git a/snowflake/ml/modeling/metrics/covariance.py b/snowflake/ml/modeling/metrics/covariance.py index 71162a67..1cd7a32a 100644 --- a/snowflake/ml/modeling/metrics/covariance.py +++ b/snowflake/ml/modeling/metrics/covariance.py @@ -36,11 +36,14 @@ def covariance(*, df: DataFrame, columns: Optional[Collection[str]] = None, ddof as a post-processing step. Args: - df (DataFrame): Snowpark Dataframe for which covariance matrix has to be computed. - columns (Optional[Collection[str]]): List of column names for which the covariance matrix has to be computed. + df: snowpark.DataFrame + Snowpark Dataframe for which covariance matrix has to be computed. + columns: list of strings, default=None + List of column names for which the covariance matrix has to be computed. If None, covariance matrix is computed for all numeric columns in the snowpark dataframe. - ddof (int): default 1. Delta degrees of freedom. - The divisor used in calculations is N - ddof, where N represents the number of rows. + ddof: int, default=1 + Delta degrees of freedom. The divisor used in calculations is N - ddof, where N represents the + number of rows. Returns: Covariance matrix in pandas.DataFrame format. diff --git a/snowflake/ml/modeling/metrics/ranking.py b/snowflake/ml/modeling/metrics/ranking.py index 4abb3e4c..19a2e1c1 100644 --- a/snowflake/ml/modeling/metrics/ranking.py +++ b/snowflake/ml/modeling/metrics/ranking.py @@ -49,18 +49,23 @@ def precision_recall_curve( which corresponds to a classifier that always predicts the positive class. Args: - df: Input dataframe. - y_true_col_name: Column name representing true binary labels. + df: snowpark.DataFrame + Input dataframe. + y_true_col_name: string + Column name representing true binary labels. If labels are not either {-1, 1} or {0, 1}, then pos_label should be explicitly given. - probas_pred_col_name: Column name representing target scores. + probas_pred_col_name: string + Column name representing target scores. 
Can either be probability estimates of the positive class, or non-thresholded measure of decisions (as returned by `decision_function` on some classifiers). - pos_label: The label of the positive class. + pos_label: string or int, default=None + The label of the positive class. When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1}, ``pos_label`` is set to 1, otherwise an error will be raised. - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. Returns: Tuple containing following items @@ -142,12 +147,15 @@ def roc_auc_score( multilabel classification, but some restrictions apply. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing true labels or binary label indicators. + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing true labels or binary label indicators. The binary and multiclass cases expect labels with shape (n_samples,) while the multilabel case expects binary label indicators with shape (n_samples, n_classes). - y_score_col_names: Column name(s) representing target scores. + y_score_col_names: string or list of strings + Column name(s) representing target scores. * In the binary case, it corresponds to an array of shape `(n_samples,)`. Both probability estimates and non-thresholded decision values can be provided. The probability estimates correspond @@ -186,7 +194,8 @@ class scores must correspond to the order of ``labels``, ``'samples'`` Calculate metrics for each instance, and find their average. Will be ignored when ``y_true`` is binary. - sample_weight_col_name: Column name representing sample weights. + sample_weight_col_name: string, default=None + Column name representing sample weights. max_fpr: float > 0 and <= 1, default=None If not ``None``, the standardized partial AUC [2]_ over the range [0, max_fpr] is returned. 
For the multiclass case, ``max_fpr``, @@ -208,7 +217,8 @@ class scores must correspond to the order of ``labels``, possible pairwise combinations of classes [5]_. Insensitive to class imbalance when ``average == 'macro'``. - labels: Only used for multiclass targets. List of labels that index the + labels: list of labels, default=None + Only used for multiclass targets. List of labels that index the classes in ``y_score``. If ``None``, the numerical or lexicographical order of the labels in ``y_true`` is used. @@ -282,19 +292,25 @@ def roc_curve( Note: this implementation is restricted to the binary classification task. Args: - df: Input dataframe. - y_true_col_name: Column name representing true binary labels. + df: snowpark.DataFrame + Input dataframe. + y_true_col_name: string + Column name representing true binary labels. If labels are not either {-1, 1} or {0, 1}, then pos_label should be explicitly given. - y_score_col_name: Column name representing target scores, can either + y_score_col_name: string + Column name representing target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by "decision_function" on some classifiers). - pos_label: The label of the positive class. + pos_label: string, default=None + The label of the positive class. When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1}, ``pos_label`` is set to 1, otherwise an error will be raised. - sample_weight_col_name: Column name representing sample weights. - drop_intermediate: Whether to drop some suboptimal thresholds which would + sample_weight_col_name: string, default=None + Column name representing sample weights. + drop_intermediate: boolean, default=True + Whether to drop some suboptimal thresholds which would not appear on a plotted ROC curve. This is useful in order to create lighter ROC curves. 
diff --git a/snowflake/ml/modeling/metrics/regression.py b/snowflake/ml/modeling/metrics/regression.py index c71459c4..1fcfb79e 100644 --- a/snowflake/ml/modeling/metrics/regression.py +++ b/snowflake/ml/modeling/metrics/regression.py @@ -40,10 +40,14 @@ def d2_absolute_error_score( gets a :math:`D^2` score of 0.0. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - sample_weight_col_name: Column name representing sample weights. + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + sample_weight_col_name: string, default=None + Column name representing sample weights. multioutput: {'raw_values', 'uniform_average'} or array-like of shape \ (n_outputs,), default='uniform_average' Defines aggregating of multiple output values. @@ -128,11 +132,16 @@ def d2_pinball_score( gets a :math:`D^2` score of 0.0. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - sample_weight_col_name: Column name representing sample weights. - alpha: Slope of the pinball deviance. It determines the quantile level + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + sample_weight_col_name: string, default=None + Column name representing sample weights. + alpha: float, default=0.5 + Slope of the pinball deviance. It determines the quantile level alpha for which the pinball deviance and also D2 are optimal. The default `alpha=0.5` is equivalent to `d2_absolute_error_score`. 
multioutput: {'raw_values', 'uniform_average'} or array-like of shape \ @@ -233,10 +242,14 @@ def explained_variance_score( the :func:`R^2 score ` should be preferred. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - sample_weight_col_name: Column name representing sample weights. + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + sample_weight_col_name: string, default=None + Column name representing sample weights. multioutput: {'raw_values', 'uniform_average', 'variance_weighted'} or \ array-like of shape (n_outputs,), default='uniform_average' Defines aggregating of multiple output values. @@ -248,7 +261,8 @@ def explained_variance_score( 'variance_weighted': Scores of all outputs are averaged, weighted by the variances of each individual output. - force_finite: Flag indicating if ``NaN`` and ``-Inf`` scores resulting + force_finite: boolean, default=True + Flag indicating if ``NaN`` and ``-Inf`` scores resulting from constant data should be replaced with real numbers (``1.0`` if prediction is perfect, ``0.0`` otherwise). Default is ``True``, a convenient setting for hyperparameters' search procedures (e.g. grid @@ -323,10 +337,14 @@ def mean_absolute_error( Mean absolute error regression loss. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - sample_weight_col_name: Column name representing sample weights. + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. 
+ sample_weight_col_name: string, default=None + Column name representing sample weights. multioutput: {'raw_values', 'uniform_average'} or array-like of shape \ (n_outputs,), default='uniform_average' Defines aggregating of multiple output values. @@ -398,10 +416,14 @@ def mean_absolute_percentage_error( regression metrics). Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - sample_weight_col_name: Column name representing sample weights. + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + sample_weight_col_name: string, default=None + Column name representing sample weights. multioutput: {'raw_values', 'uniform_average'} or array-like of shape \ (n_outputs,), default='uniform_average' Defines aggregating of multiple output values. @@ -472,10 +494,14 @@ def mean_squared_error( Mean squared error regression loss. Args: - df: Input dataframe. - y_true_col_names: Column name(s) representing actual values. - y_pred_col_names: Column name(s) representing predicted values. - sample_weight_col_name: Column name representing sample weights. + df: snowpark.DataFrame + Input dataframe. + y_true_col_names: string or list of strings + Column name(s) representing actual values. + y_pred_col_names: string or list of strings + Column name(s) representing predicted values. + sample_weight_col_name: string, default=None + Column name representing sample weights. multioutput: {'raw_values', 'uniform_average'} or array-like of shape \ (n_outputs,), default='uniform_average' Defines aggregating of multiple output values. @@ -484,7 +510,8 @@ def mean_squared_error( Returns a full set of errors in case of multioutput input. 'uniform_average': Errors of all outputs are averaged with uniform weight. 
- squared: If True returns MSE value, if False returns RMSE value. + squared: boolean, default=True + If True returns MSE value, if False returns RMSE value. Returns: loss: float or ndarray of floats @@ -538,12 +565,13 @@ def r2_score(*, df: snowpark.DataFrame, y_true_col_name: str, y_pred_col_name: s non-constant, a constant model that always predicts the average y disregarding the input features would get a :math:`R^2` score of 0.0. - TODO(pdorairaj): Implement other params from sklearn - sample_weight, multi_output, force_finite. - Args: - df: Input dataframe. - y_true_col_name: Column name representing actual values. - y_pred_col_name: Column name representing predicted values. + df: snowpark.DataFrame + Input dataframe. + y_true_col_name: string + Column name representing actual values. + y_pred_col_name: string + Column name representing predicted values. Returns: R squared metric. diff --git a/snowflake/ml/modeling/model_selection/BUILD.bazel b/snowflake/ml/modeling/model_selection/BUILD.bazel index cca563ec..23e28418 100644 --- a/snowflake/ml/modeling/model_selection/BUILD.bazel +++ b/snowflake/ml/modeling/model_selection/BUILD.bazel @@ -28,6 +28,7 @@ py_library( ":init", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/modeling/_internal:model_trainer_builder", "//snowflake/ml/modeling/_internal:snowpark_handlers", ], ) @@ -39,6 +40,7 @@ py_library( ":init", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/modeling/_internal:model_trainer_builder", "//snowflake/ml/modeling/_internal:snowpark_handlers", ], ) diff --git a/snowflake/ml/modeling/model_selection/grid_search_cv.py b/snowflake/ml/modeling/model_selection/grid_search_cv.py index f6de259a..dd1392db 100644 --- a/snowflake/ml/modeling/model_selection/grid_search_cv.py +++ b/snowflake/ml/modeling/model_selection/grid_search_cv.py @@ -2,13 +2,13 @@ # This code is auto-generated using the 
sklearn_wrapper_template.py_template template. # Do not modify the auto-generated code(except automatic reformatting by precommit hooks). # -from typing import Dict, Iterable, List, Optional, Set, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Union from uuid import uuid4 +import cloudpickle as cp import numpy as np import pandas as pd import sklearn.model_selection -from sklearn.model_selection import ParameterGrid from sklearn.utils.metaestimators import available_if from snowflake.ml._internal import telemetry @@ -25,13 +25,12 @@ from snowflake.ml.modeling._internal.estimator_protocols import CVHandlers from snowflake.ml.modeling._internal.estimator_utils import ( gather_dependencies, - is_single_node, original_estimator_has_callable, transform_snowml_obj_to_sklearn_obj, validate_sklearn_args, ) +from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder from snowflake.ml.modeling._internal.snowpark_handlers import ( - SklearnModelSelectionWrapperProvider, SnowparkHandlers as HandlersImpl, ) from snowflake.ml.modeling.framework.base import BaseTransformer @@ -53,19 +52,54 @@ class GridSearchCV(BaseTransformer): Parameters ---------- - estimator : estimator object + estimator: estimator object This is assumed to implement the scikit-learn estimator interface. Either estimator needs to provide a ``score`` function, or ``scoring`` must be passed. - param_grid : dict or list of dictionaries + param_grid: dict or list of dictionaries Dictionary with parameters names (`str`) as keys and lists of parameter settings to try as values, or a list of such dictionaries, in which case the grids spanned by each dictionary in the list are explored. This enables searching over any sequence of parameter settings. - scoring : str, callable, list, tuple or dict, default=None + input_cols: Optional[Union[str, List[str]]] + A string or list of strings representing column names that contain features. 
+ If this parameter is not specified, all columns in the input DataFrame except + the columns specified by label_cols and sample_weight_col parameters are + considered input columns. + + label_cols: Optional[Union[str, List[str]]] + A string or list of strings representing column names that contain labels. + This is a required param for estimators, as there is no way to infer these + columns. If this parameter is not specified, then object is fitted without + labels (like a transformer). + + output_cols: Optional[Union[str, List[str]]] + A string or list of strings representing column names that will store the + output of predict and transform operations. The length of output_cols must + match the expected number of output columns from the specific estimator or + transformer class used. + If this parameter is not specified, output column names are derived by + adding an OUTPUT_ prefix to the label column names. These inferred output + column names work for estimator's predict() method, but output_cols must + be set explicitly for transformers. + + passthrough_cols: A string or a list of strings indicating column names to be excluded from any + operations (such as train, transform, or inference). These specified column(s) + will remain untouched throughout the process. This option is helpful in scenarios + requiring automatic input_cols inference, but need to avoid using specific + columns, like index columns, during training or inference. + + sample_weight_col: Optional[str] + A string representing the column name containing the examples’ weights. + This argument is only required when working with weighted datasets. + + drop_input_cols: Optional[bool], default=False + If set, the response of predict(), transform() methods will not contain input columns. + + scoring: str, callable, list, tuple or dict, default=None Strategy to evaluate the performance of the cross-validated model on the test set. 
@@ -83,13 +117,13 @@ class GridSearchCV(BaseTransformer): See :ref:`multimetric_grid_search` for an example. - n_jobs : int, default=None + n_jobs: int, default=None Number of jobs to run in parallel. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. - refit : bool, str, or callable, default=True + refit: bool, str, or callable, default=True Refit an estimator using the best found parameters on the whole dataset. @@ -120,7 +154,7 @@ class GridSearchCV(BaseTransformer): to see how to design a custom selection strategy using a callable via `refit`. - cv : int, cross-validation generator or an iterable, default=None + cv: int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: @@ -137,7 +171,7 @@ class GridSearchCV(BaseTransformer): Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. - verbose : int + verbose: int Controls the verbosity: the higher, the more messages. - >1 : the computation time for each fold and parameter candidate is @@ -146,7 +180,7 @@ class GridSearchCV(BaseTransformer): - >3 : the fold and candidate parameter indexes are also displayed together with the starting time of the computation. - pre_dispatch : int, or str, default='2*n_jobs' + pre_dispatch: int, or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. Reducing this number can be useful to avoid an explosion of memory consumption when more jobs get dispatched @@ -163,13 +197,13 @@ class GridSearchCV(BaseTransformer): - A str, giving an expression as a function of n_jobs, as in '2*n_jobs' - error_score : 'raise' or numeric, default=np.nan + error_score: 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. 
If a numeric value is given, FitFailedWarning is raised. This parameter does not affect the refit step, which will always raise the error. - return_train_score : bool, default=False + return_train_score: bool, default=False If ``False``, the ``cv_results_`` attribute will not include training scores. Computing training scores is used to get insights on how different @@ -177,41 +211,6 @@ class GridSearchCV(BaseTransformer): However computing the scores on the training set can be computationally expensive and is not strictly required to select the parameters that yield the best generalization performance. - - input_cols : Optional[Union[str, List[str]]] - A string or list of strings representing column names that contain features. - If this parameter is not specified, all columns in the input DataFrame except - the columns specified by label_cols and sample-weight_col parameters are - considered input columns. - - label_cols : Optional[Union[str, List[str]]] - A string or list of strings representing column names that contain labels. - This is a required param for estimators, as there is no way to infer these - columns. If this parameter is not specified, then object is fitted without - labels(Like a transformer). - - output_cols: Optional[Union[str, List[str]]] - A string or list of strings representing column names that will store the - output of predict and transform operations. The length of output_cols mus - match the expected number of output columns from the specific estimator or - transformer class used. - If this parameter is not specified, output column names are derived by - adding an OUTPUT_ prefix to the label column names. These inferred output - column names work for estimator's predict() method, but output_cols must - be set explicitly for transformers. - - passthrough_cols: A string or a list of strings indicating column names to be excluded from any - operations (such as train, transform, or inference). 
These specified column(s) - will remain untouched throughout the process. This option is helpful in scenarios - requiring automatic input_cols inference, but need to avoid using specific - columns, like index columns, during training or inference. - - sample_weight_col: Optional[str] - A string representing the column name containing the examples’ weights. - This argument is only required when working with weighted datasets. - - drop_input_cols: Optional[bool], default=False - If set, the response of predict(), transform() methods will not contain input columns. """ _ENABLE_DISTRIBUTED = True @@ -236,7 +235,11 @@ def __init__( # type: ignore[no-untyped-def] sample_weight_col: Optional[str] = None, ) -> None: super().__init__() - deps: Set[str] = set(SklearnModelSelectionWrapperProvider().dependencies) + deps: Set[str] = { + f"numpy=={np.__version__}", + f"scikit-learn=={sklearn.__version__}", + f"cloudpickle=={cp.__version__}", + } deps = deps | gather_dependencies(estimator) self._deps = list(deps) estimator = transform_snowml_obj_to_sklearn_obj(estimator) @@ -253,7 +256,7 @@ def __init__( # type: ignore[no-untyped-def] "return_train_score": (return_train_score, False, False), } cleaned_up_init_args = validate_sklearn_args(args=init_args, klass=sklearn.model_selection.GridSearchCV) - self._sklearn_object = sklearn.model_selection.GridSearchCV( + self._sklearn_object: Any = sklearn.model_selection.GridSearchCV( **cleaned_up_init_args, ) self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None @@ -266,7 +269,6 @@ def __init__( # type: ignore[no-untyped-def] self._handlers: CVHandlers = HandlersImpl( class_name=self.__class__.__name__, subproject=_SUBPROJECT, - wrapper_provider=SklearnModelSelectionWrapperProvider(), ) def _get_rand_id(self) -> str: @@ -294,10 +296,6 @@ def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "GridSearchCV": For more details on this function, see [sklearn.model_selection.GridSearchCV.fit] 
(https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV.fit) - - Raises: - TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame. - Args: dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame] Snowpark or Pandas DataFrame. @@ -306,70 +304,37 @@ def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "GridSearchCV": self """ self._infer_input_output_cols(dataset) - if isinstance(dataset, pd.DataFrame): - self._estimator = self._handlers.fit_pandas( - dataset, self._sklearn_object, self.input_cols, self.label_cols, self.sample_weight_col - ) - elif isinstance(dataset, DataFrame): - self._fit_snowpark(dataset) - else: - raise TypeError( - f"Unexpected dataset type: {type(dataset)}." - "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." + if self._sklearn_object.n_jobs is None: + self._sklearn_object.n_jobs = -1 + if isinstance(dataset, DataFrame): + session = dataset._session + assert session is not None # keep mypy happy + # Validate that key package version in user workspace are supported in snowflake conda channel + # If customer doesn't have package in conda channel, replace the ones have the closest versions + self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT ) - self._is_fitted = True - self._get_model_signatures(dataset) - return self - def _fit_snowpark(self, dataset: DataFrame) -> None: - session = dataset._session - assert session is not None # keep mypy happy - # Validate that key package version in user workspace are supported in snowflake conda channel - # If customer doesn't have package in conda channel, replace the ones have the closest versions - self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT 
- ) + # Specify input columns so column pruning will be enforced + selected_cols = self._get_active_columns() + if len(selected_cols) > 0: + dataset = dataset.select(selected_cols) - selected_cols = self._get_active_columns() - if len(selected_cols) > 0: - dataset = dataset.select(selected_cols) + self._snowpark_cols = dataset.select(self.input_cols).columns - assert self._sklearn_object is not None - is_distributed = not is_single_node(session) and self._ENABLE_DISTRIBUTED is True - if is_distributed: - # Set the default value of the `n_jobs` attribute for the estimator. - # If minus one is set, it will not be abided by in the UDTF, so we set that to the default value as well. - if hasattr(self._sklearn_object.estimator, "n_jobs") and self._sklearn_object.estimator.n_jobs in [ - None, - -1, - ]: - self._sklearn_object.estimator.n_jobs = DEFAULT_UDTF_NJOBS - self._sklearn_object = self._handlers.fit_search_snowpark( - param_grid=ParameterGrid(self._sklearn_object.param_grid), - dataset=dataset, - session=session, - estimator=self._sklearn_object, - dependencies=self._get_dependencies(), - udf_imports=["sklearn"], - input_cols=self.input_cols, - label_cols=self.label_cols, - sample_weight_col=self.sample_weight_col, - ) - else: - # Fall back with stored procedure implementation - # set the parallel factor to default to minus one, to fully accelerate the cores in single node - if self._sklearn_object.n_jobs is None: - self._sklearn_object.n_jobs = -1 - - self._sklearn_object = self._handlers.fit_snowpark( - dataset, - session, - self._sklearn_object, - ["snowflake-snowpark-python"] + self._get_dependencies(), - self.input_cols, - self.label_cols, - self.sample_weight_col, - ) + model_trainer = ModelTrainerBuilder.build( + estimator=self._sklearn_object, + dataset=dataset, + input_cols=self.input_cols, + label_cols=self.label_cols, + sample_weight_col=self.sample_weight_col, + autogenerated=False, + subproject=_SUBPROJECT, + ) + self._sklearn_object = 
model_trainer.train() + self._is_fitted = True + self._get_model_signatures(dataset) + return self def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]: if self._drop_input_cols: @@ -523,10 +488,6 @@ def _sklearn_inference( project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: """Call predict on the estimator with the best found parameters For more details on this function, see [sklearn.model_selection.GridSearchCV.predict] @@ -569,10 +530,6 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: """Call transform on the estimator with the best found parameters For more details on this function, see [sklearn.model_selection.GridSearchCV.transform] @@ -636,10 +593,6 @@ def _get_output_column_names(self, output_cols_prefix: str) -> List[str]: project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def predict_proba( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_" ) -> Union[DataFrame, pd.DataFrame]: @@ -677,10 +630,6 @@ def predict_proba( project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def predict_log_proba( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_" ) -> Union[DataFrame, pd.DataFrame]: @@ -719,10 +668,6 @@ def predict_log_proba( project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def decision_function( self, 
dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_" ) -> Union[DataFrame, pd.DataFrame]: @@ -759,6 +704,8 @@ def decision_function( @available_if(original_estimator_has_callable("score")) # type: ignore[misc] def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: """ + If implemented by the original estimator, return the score for the dataset. + Args: dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame] Snowpark or Pandas DataFrame. @@ -811,9 +758,9 @@ def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None # For classifier, the type of predict is the same as the type of label if self._sklearn_object._estimator_type == "classifier": # label columns is the desired type for output - outputs = _infer_signature(dataset[self.label_cols], "output") + outputs = list(_infer_signature(dataset[self.label_cols], "output")) # rename the output columns - outputs = model_signature_utils.rename_features(outputs, self.output_cols) + outputs = list(model_signature_utils.rename_features(outputs, self.output_cols)) self._model_signature_dict["predict"] = ModelSignature( inputs, ([] if self._drop_input_cols else inputs) + outputs ) @@ -850,6 +797,9 @@ def model_signatures(self) -> Dict[str, ModelSignature]: return self._model_signature_dict def to_sklearn(self) -> sklearn.model_selection.GridSearchCV: + """ + Get sklearn.model_selection.GridSearchCV object. 
+ """ assert self._sklearn_object is not None return self._sklearn_object diff --git a/snowflake/ml/modeling/model_selection/randomized_search_cv.py b/snowflake/ml/modeling/model_selection/randomized_search_cv.py index 0843d026..edf912ad 100644 --- a/snowflake/ml/modeling/model_selection/randomized_search_cv.py +++ b/snowflake/ml/modeling/model_selection/randomized_search_cv.py @@ -1,11 +1,11 @@ -from typing import Dict, Iterable, List, Optional, Set, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Union from uuid import uuid4 +import cloudpickle as cp import numpy as np import pandas as pd import sklearn import sklearn.model_selection -from sklearn.model_selection import ParameterSampler from sklearn.utils.metaestimators import available_if from snowflake.ml._internal import telemetry @@ -22,13 +22,12 @@ from snowflake.ml.modeling._internal.estimator_protocols import CVHandlers from snowflake.ml.modeling._internal.estimator_utils import ( gather_dependencies, - is_single_node, original_estimator_has_callable, transform_snowml_obj_to_sklearn_obj, validate_sklearn_args, ) +from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder from snowflake.ml.modeling._internal.snowpark_handlers import ( - SklearnModelSelectionWrapperProvider, SnowparkHandlers as HandlersImpl, ) from snowflake.ml.modeling.framework.base import BaseTransformer @@ -50,13 +49,13 @@ class RandomizedSearchCV(BaseTransformer): Parameters ---------- - estimator : estimator object + estimator: estimator object An object of that type is instantiated for each grid point. This is assumed to implement the scikit-learn estimator interface. Either estimator needs to provide a ``score`` function, or ``scoring`` must be passed. - param_distributions : dict or list of dicts + param_distributions: dict or list of dicts Dictionary with parameters names (`str`) as keys and distributions or lists of parameters to try. 
Distributions must provide a ``rvs`` method for sampling (such as those from scipy.stats.distributions). @@ -64,11 +63,46 @@ class RandomizedSearchCV(BaseTransformer): If a list of dicts is given, first a dict is sampled uniformly, and then a parameter is sampled using that dict as above. - n_iter : int, default=10 + input_cols: Optional[Union[str, List[str]]] + A string or list of strings representing column names that contain features. + If this parameter is not specified, all columns in the input DataFrame except + the columns specified by label_cols and sample_weight_col parameters are + considered input columns. + + label_cols: Optional[Union[str, List[str]]] + A string or list of strings representing column names that contain labels. + This is a required param for estimators, as there is no way to infer these + columns. If this parameter is not specified, then object is fitted without + labels (like a transformer). + + output_cols: Optional[Union[str, List[str]]] + A string or list of strings representing column names that will store the + output of predict and transform operations. The length of output_cols must + match the expected number of output columns from the specific estimator or + transformer class used. + If this parameter is not specified, output column names are derived by + adding an OUTPUT_ prefix to the label column names. These inferred output + column names work for estimator's predict() method, but output_cols must + be set explicitly for transformers. + + passthrough_cols: A string or a list of strings indicating column names to be excluded from any + operations (such as train, transform, or inference). These specified column(s) + will remain untouched throughout the process. This option is helpful in scenarios + requiring automatic input_cols inference, but need to avoid using specific + columns, like index columns, during training or inference. 
+ + sample_weight_col: Optional[str] + A string representing the column name containing the examples’ weights. + This argument is only required when working with weighted datasets. + + drop_input_cols: Optional[bool], default=False + If set, the response of predict(), transform() methods will not contain input columns. + + n_iter: int, default=10 Number of parameter settings that are sampled. n_iter trades off runtime vs quality of the solution. - scoring : str, callable, list, tuple or dict, default=None + scoring: str, callable, list, tuple or dict, default=None Strategy to evaluate the performance of the cross-validated model on the test set. @@ -88,13 +122,13 @@ class RandomizedSearchCV(BaseTransformer): If None, the estimator's score method is used. - n_jobs : int, default=None + n_jobs: int, default=None Number of jobs to run in parallel. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. - refit : bool, str, or callable, default=True + refit: bool, str, or callable, default=True Refit an estimator using the best found parameters on the whole dataset. @@ -121,7 +155,7 @@ class RandomizedSearchCV(BaseTransformer): See ``scoring`` parameter to know more about multiple metric evaluation. - cv : int, cross-validation generator or an iterable, default=None + cv: int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: @@ -138,7 +172,7 @@ class RandomizedSearchCV(BaseTransformer): Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. - verbose : int + verbose: int Controls the verbosity: the higher, the more messages. 
- >1 : the computation time for each fold and parameter candidate is @@ -147,7 +181,7 @@ class RandomizedSearchCV(BaseTransformer): - >3 : the fold and candidate parameter indexes are also displayed together with the starting time of the computation. - pre_dispatch : int, or str, default='2*n_jobs' + pre_dispatch: int, or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. Reducing this number can be useful to avoid an explosion of memory consumption when more jobs get dispatched @@ -164,20 +198,20 @@ class RandomizedSearchCV(BaseTransformer): - A str, giving an expression as a function of n_jobs, as in '2*n_jobs' - random_state : int, RandomState instance or None, default=None + random_state: int, RandomState instance or None, default=None Pseudo random number generator state used for random uniform sampling from lists of possible values instead of scipy.stats distributions. Pass an int for reproducible output across multiple function calls. See :term:`Glossary `. - error_score : 'raise' or numeric, default=np.nan + error_score: 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If a numeric value is given, FitFailedWarning is raised. This parameter does not affect the refit step, which will always raise the error. - return_train_score : bool, default=False + return_train_score: bool, default=False If ``False``, the ``cv_results_`` attribute will not include training scores. Computing training scores is used to get insights on how different @@ -185,41 +219,6 @@ class RandomizedSearchCV(BaseTransformer): However computing the scores on the training set can be computationally expensive and is not strictly required to select the parameters that yield the best generalization performance. - - input_cols : Optional[Union[str, List[str]]] - A string or list of strings representing column names that contain features. 
- If this parameter is not specified, all columns in the input DataFrame except - the columns specified by label_cols and sample-weight_col parameters are - considered input columns. - - label_cols : Optional[Union[str, List[str]]] - A string or list of strings representing column names that contain labels. - This is a required param for estimators, as there is no way to infer these - columns. If this parameter is not specified, then object is fitted without - labels(Like a transformer). - - output_cols: Optional[Union[str, List[str]]] - A string or list of strings representing column names that will store the - output of predict and transform operations. The length of output_cols mus - match the expected number of output columns from the specific estimator or - transformer class used. - If this parameter is not specified, output column names are derived by - adding an OUTPUT_ prefix to the label column names. These inferred output - column names work for estimator's predict() method, but output_cols must - be set explicitly for transformers. - - passthrough_cols: A string or a list of strings indicating column names to be excluded from any - operations (such as train, transform, or inference). These specified column(s) - will remain untouched throughout the process. This option is helpful in scenarios - requiring automatic input_cols inference, but need to avoid using specific - columns, like index columns, during training or inference. - - sample_weight_col: Optional[str] - A string representing the column name containing the examples’ weights. - This argument is only required when working with weighted datasets. - - drop_input_cols: Optional[bool], default=False - If set, the response of predict(), transform() methods will not contain input columns. 
""" _ENABLE_DISTRIBUTED = True @@ -246,7 +245,11 @@ def __init__( # type: ignore[no-untyped-def] sample_weight_col: Optional[str] = None, ) -> None: super().__init__() - deps: Set[str] = set(SklearnModelSelectionWrapperProvider().dependencies) + deps: Set[str] = { + f"numpy=={np.__version__}", + f"scikit-learn=={sklearn.__version__}", + f"cloudpickle=={cp.__version__}", + } deps = deps | gather_dependencies(estimator) self._deps = list(deps) estimator = transform_snowml_obj_to_sklearn_obj(estimator) @@ -265,7 +268,7 @@ def __init__( # type: ignore[no-untyped-def] "return_train_score": (return_train_score, False, False), } cleaned_up_init_args = validate_sklearn_args(args=init_args, klass=sklearn.model_selection.RandomizedSearchCV) - self._sklearn_object = sklearn.model_selection.RandomizedSearchCV( + self._sklearn_object: Any = sklearn.model_selection.RandomizedSearchCV( **cleaned_up_init_args, ) self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None @@ -278,7 +281,6 @@ def __init__( # type: ignore[no-untyped-def] self._handlers: CVHandlers = HandlersImpl( class_name=self.__class__.__name__, subproject=_SUBPROJECT, - wrapper_provider=SklearnModelSelectionWrapperProvider(), ) def _get_rand_id(self) -> str: @@ -306,10 +308,6 @@ def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RandomizedSearchCV": For more details on this function, see [sklearn.model_selection.RandomizedSearchCV.fit] (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV.fit) - - Raises: - TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame. - Args: dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame] Snowpark or Pandas DataFrame. 
@@ -318,74 +316,37 @@ def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RandomizedSearchCV": self """ self._infer_input_output_cols(dataset) - if isinstance(dataset, pd.DataFrame): - self._estimator = self._handlers.fit_pandas( - dataset, self._sklearn_object, self.input_cols, self.label_cols, self.sample_weight_col - ) - elif isinstance(dataset, DataFrame): - self._fit_snowpark(dataset) - else: - raise TypeError( - f"Unexpected dataset type: {type(dataset)}." - "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." + if hasattr(self._sklearn_object, "n_jobs") and self._sklearn_object.n_jobs is None: + self._sklearn_object.n_jobs = -1 + if isinstance(dataset, DataFrame): + session = dataset._session + assert session is not None # keep mypy happy + # Validate that key package version in user workspace are supported in snowflake conda channel + # If customer doesn't have package in conda channel, replace the ones have the closest versions + self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT ) - self._is_fitted = True - self._get_model_signatures(dataset) - return self - def _fit_snowpark(self, dataset: DataFrame) -> None: - session = dataset._session - assert session is not None # keep mypy happy - # Validate that key package version in user workspace are supported in snowflake conda channel - # If customer doesn't have package in conda channel, replace the ones have the closest versions - self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT - ) + # Specify input columns so column pruning will be enforced + selected_cols = self._get_active_columns() + if len(selected_cols) > 0: + dataset = dataset.select(selected_cols) - selected_cols = self._get_active_columns() - if len(selected_cols) > 0: - dataset = 
dataset.select(selected_cols) + self._snowpark_cols = dataset.select(self.input_cols).columns - assert self._sklearn_object is not None - is_distributed = not is_single_node(session) and self._ENABLE_DISTRIBUTED is True - if is_distributed: - # Set the default value of the `n_jobs` attribute for the estimator. - # If minus one is set, it will not be abided by in the UDTF, so we set that to the default value as well. - if hasattr(self._sklearn_object.estimator, "n_jobs") and self._sklearn_object.estimator.n_jobs in [ - None, - -1, - ]: - self._sklearn_object.estimator.n_jobs = DEFAULT_UDTF_NJOBS - self._sklearn_object = self._handlers.fit_search_snowpark( - param_grid=ParameterSampler( - self._sklearn_object.param_distributions, - n_iter=self._sklearn_object.n_iter, - random_state=self._sklearn_object.random_state, - ), - dataset=dataset, - session=session, - estimator=self._sklearn_object, - dependencies=self._get_dependencies(), - udf_imports=["sklearn"], - input_cols=self.input_cols, - label_cols=self.label_cols, - sample_weight_col=self.sample_weight_col, - ) - else: - # Fall back with stored procedure implementation - # set the parallel factor to default to minus one, to fully accelerate the cores in single node - if self._sklearn_object.n_jobs is None: - self._sklearn_object.n_jobs = -1 - - self._sklearn_object = self._handlers.fit_snowpark( - dataset, - session, - self._sklearn_object, - ["snowflake-snowpark-python"] + self._get_dependencies(), - self.input_cols, - self.label_cols, - self.sample_weight_col, - ) + model_trainer = ModelTrainerBuilder.build( + estimator=self._sklearn_object, + dataset=dataset, + input_cols=self.input_cols, + label_cols=self.label_cols, + sample_weight_col=self.sample_weight_col, + autogenerated=False, + subproject=_SUBPROJECT, + ) + self._sklearn_object = model_trainer.train() + self._is_fitted = True + self._get_model_signatures(dataset) + return self def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]: if 
self._drop_input_cols: @@ -539,10 +500,6 @@ def _sklearn_inference( project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: """Call predict on the estimator with the best found parameters For more details on this function, see [sklearn.model_selection.RandomizedSearchCV.predict] @@ -584,10 +541,6 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: """Call transform on the estimator with the best found parameters For more details on this function, see [sklearn.model_selection.RandomizedSearchCV.transform] @@ -651,10 +604,6 @@ def _get_output_column_names(self, output_cols_prefix: str) -> List[str]: project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def predict_proba( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_" ) -> Union[DataFrame, pd.DataFrame]: @@ -692,10 +641,6 @@ def predict_proba( project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def predict_log_proba( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_" ) -> Union[DataFrame, pd.DataFrame]: @@ -734,10 +679,6 @@ def predict_log_proba( project=_PROJECT, subproject=_SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=_PROJECT, - subproject=_SUBPROJECT, - ) def decision_function( self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_" ) -> Union[DataFrame, pd.DataFrame]: @@ -774,6 +715,8 @@ def 
decision_function( @available_if(original_estimator_has_callable("score")) # type: ignore[misc] def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: """ + If implemented by the original estimator, return the score for the dataset. + Args: dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame] Snowpark or Pandas DataFrame. @@ -826,9 +769,9 @@ def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None # For classifier, the type of predict is the same as the type of label if self._sklearn_object._estimator_type == "classifier": # label columns is the desired type for output - outputs = _infer_signature(dataset[self.label_cols], "output") + outputs = list(_infer_signature(dataset[self.label_cols], "output")) # rename the output columns - outputs = model_signature_utils.rename_features(outputs, self.output_cols) + outputs = list(model_signature_utils.rename_features(outputs, self.output_cols)) self._model_signature_dict["predict"] = ModelSignature( inputs, ([] if self._drop_input_cols else inputs) + outputs ) @@ -865,6 +808,9 @@ def model_signatures(self) -> Dict[str, ModelSignature]: return self._model_signature_dict def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV: + """ + Get sklearn.model_selection.RandomizedSearchCV object. 
+ """ assert self._sklearn_object is not None return self._sklearn_object diff --git a/snowflake/ml/modeling/parameters/BUILD.bazel b/snowflake/ml/modeling/parameters/BUILD.bazel index 45de007f..45637277 100644 --- a/snowflake/ml/modeling/parameters/BUILD.bazel +++ b/snowflake/ml/modeling/parameters/BUILD.bazel @@ -8,8 +8,7 @@ py_library( "disable_distributed_hpo.py", ], deps = [ - "//snowflake/ml/modeling/model_selection:grid_search_cv", - "//snowflake/ml/modeling/model_selection:randomized_search_cv", + "//snowflake/ml/modeling/_internal:model_trainer_builder", ], ) @@ -20,8 +19,9 @@ py_test( ], deps = [ ":disable_distributed_hpo", - "//snowflake/ml/modeling/model_selection:grid_search_cv", - "//snowflake/ml/modeling/model_selection:randomized_search_cv", + "//snowflake/ml/modeling/_internal:distributed_hpo_trainer", + "//snowflake/ml/modeling/_internal:model_trainer_builder", + "//snowflake/ml/modeling/_internal:snowpark_trainer", "//snowflake/ml/modeling/xgboost:xgb_classifier", ], ) diff --git a/snowflake/ml/modeling/parameters/disable_distributed_hpo.py b/snowflake/ml/modeling/parameters/disable_distributed_hpo.py index bea0113b..f6ec6576 100644 --- a/snowflake/ml/modeling/parameters/disable_distributed_hpo.py +++ b/snowflake/ml/modeling/parameters/disable_distributed_hpo.py @@ -1,8 +1,4 @@ """Disables the distributed implementation of Grid Search and Randomized Search CV""" -from snowflake.ml.modeling.model_selection.grid_search_cv import GridSearchCV -from snowflake.ml.modeling.model_selection.randomized_search_cv import ( - RandomizedSearchCV, -) +from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder -GridSearchCV._ENABLE_DISTRIBUTED = False -RandomizedSearchCV._ENABLE_DISTRIBUTED = False +ModelTrainerBuilder._ENABLE_DISTRIBUTED = False diff --git a/snowflake/ml/modeling/parameters/disable_distributed_hpo_test.py b/snowflake/ml/modeling/parameters/disable_distributed_hpo_test.py index aee4de00..8574a733 100644 --- 
a/snowflake/ml/modeling/parameters/disable_distributed_hpo_test.py +++ b/snowflake/ml/modeling/parameters/disable_distributed_hpo_test.py @@ -1,143 +1,40 @@ -from typing import List, Optional, Union from unittest import mock -import pandas as pd from absl.testing import absltest -from sklearn import model_selection +from sklearn.model_selection import GridSearchCV from snowflake.ml.modeling.xgboost.xgb_classifier import XGBClassifier -from snowflake.ml.modeling.model_selection.grid_search_cv import GridSearchCV -from snowflake.ml.modeling.model_selection.randomized_search_cv import ( - RandomizedSearchCV, +from snowflake.ml.modeling._internal.distributed_hpo_trainer import ( + DistributedHPOTrainer, ) +from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder +from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer from snowflake.snowpark import DataFrame, Session -class MockHandlers: - def fit_pandas( - self, - dataset: pd.DataFrame, - estimator: object, - input_cols: List[str], - label_cols: Optional[List[str]], - sample_weight_col: Optional[str], - ) -> object: - raise NotImplementedError - - def batch_inference( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - inference_method: str, - input_cols: List[str], - pass_through_columns: List[str], - expected_output_cols_list: List[str], - expected_output_cols_type: str = "", - ) -> DataFrame: - raise NotImplementedError - - def score_pandas( - self, - dataset: pd.DataFrame, - estimator: object, - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> float: - raise NotImplementedError - - def score_snowpark( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - score_sproc_imports: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> float: - raise NotImplementedError - - def 
fit_snowpark( - self, - dataset: DataFrame, - session: Session, - estimator: object, - dependencies: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> object: - response_obj = mock.Mock(spec=model_selection.GridSearchCV) - response_obj.function = "FIT_SNOWPARK" - return response_obj - - def fit_search_snowpark( - self, - param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler], - dataset: DataFrame, - session: Session, - estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV], - dependencies: List[str], - udf_imports: List[str], - input_cols: List[str], - label_cols: List[str], - sample_weight_col: Optional[str], - ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]: - response_obj = mock.Mock(spec=model_selection.GridSearchCV) - response_obj.function = "FIT_SEARCH" - return response_obj - - class DisableDistributedHPOTest(absltest.TestCase): - @mock.patch( - "snowflake.ml.modeling.model_selection.grid_search_cv.pkg_version_utils" - ".get_valid_pkg_versions_supported_in_snowflake_conda_channel" - ) - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") - def test_disable_distributed_hpo(self, is_single_node_mock: mock.Mock, pkg_version_mock: mock.Mock) -> None: + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") + def test_disable_distributed_hpo(self, is_single_node_mock: mock.Mock) -> None: is_single_node_mock.return_value = False - pkg_version_mock.return_value = [] + mock_session = mock.MagicMock(spec=Session) mock_dataframe = mock.MagicMock(spec=DataFrame) mock_dataframe._session = mock_session - estimator = XGBClassifier() - grid_search_cv = GridSearchCV(estimator=estimator, param_grid=dict(fake=[1, 2])) - grid_search_cv._handlers = MockHandlers() - - randomized_search_cv = RandomizedSearchCV(estimator=estimator, param_distributions=dict(fake=[1, 2])) - 
randomized_search_cv._handlers = MockHandlers() + estimator = GridSearchCV(param_grid={"max_leaf_nodes": [10, 100]}, estimator=XGBClassifier()) - grid_search_cv._fit_snowpark(mock_dataframe) - randomized_search_cv._fit_snowpark(mock_dataframe) + trainer = ModelTrainerBuilder.build(estimator=estimator, dataset=mock_dataframe, input_cols=[]) - assert grid_search_cv._sklearn_object is not None - assert randomized_search_cv._sklearn_object is not None - self.assertTrue(grid_search_cv._sklearn_object.function, "FIT_SEARCH") - self.assertEqual(randomized_search_cv._sklearn_object.function, "FIT_SEARCH") + self.assertTrue(isinstance(trainer, DistributedHPOTrainer)) # Disable distributed HPO import snowflake.ml.modeling.parameters.disable_distributed_hpo # noqa: F401 - self.assertFalse(GridSearchCV._ENABLE_DISTRIBUTED) - self.assertFalse(RandomizedSearchCV._ENABLE_DISTRIBUTED) - - grid_search_cv = GridSearchCV(estimator=estimator, param_grid=dict(fake=[1, 2])) - grid_search_cv._handlers = MockHandlers() - randomized_search_cv = RandomizedSearchCV(estimator=estimator, param_distributions=dict(fake=[1, 2])) - randomized_search_cv._handlers = MockHandlers() - - grid_search_cv._fit_snowpark(mock_dataframe) - randomized_search_cv._fit_snowpark(mock_dataframe) + self.assertFalse(ModelTrainerBuilder._ENABLE_DISTRIBUTED) + trainer = ModelTrainerBuilder.build(estimator=estimator, dataset=mock_dataframe, input_cols=[]) - assert grid_search_cv._sklearn_object is not None - assert randomized_search_cv._sklearn_object is not None - self.assertTrue(grid_search_cv._sklearn_object.function, "FIT_SNOWPARK") - self.assertEqual(randomized_search_cv._sklearn_object.function, "FIT_SNOWPARK") + self.assertTrue(isinstance(trainer, SnowparkModelTrainer)) + self.assertFalse(isinstance(trainer, DistributedHPOTrainer)) if __name__ == "__main__": diff --git a/snowflake/ml/modeling/preprocessing/binarizer.py b/snowflake/ml/modeling/preprocessing/binarizer.py index 9fda5ab1..48e8e18e 100644 --- 
a/snowflake/ml/modeling/preprocessing/binarizer.py +++ b/snowflake/ml/modeling/preprocessing/binarizer.py @@ -21,16 +21,25 @@ class Binarizer(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.Binarizer.html). Args: - threshold: Feature values below or equal to this are replaced by 0, above it by 1. Default values is 0.0. - input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be binarized. - output_cols: The name(s) of one or more columns in a DataFrame in which results will be stored. The number of + threshold: float, default=0.0 + Feature values below or equal to this are replaced by 0, above it by 1. Default value is 0.0. + + input_cols: Optional[Union[str, Iterable[str]]], default=None + The name(s) of one or more columns in a DataFrame containing a feature to be binarized. + + output_cols: Optional[Union[str, Iterable[str]]], default=None + The name(s) of one or more columns in a DataFrame in which results will be stored. The number of columns specified must match the number of input columns. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any + + passthrough_cols: Optional[Union[str, Iterable[str]]], default=None + A string or a list of strings indicating column names to be excluded from any operations (such as train, transform, or inference). These specified column(s) will remain untouched throughout the process. This option is helpful in scenarios requiring automatic input_cols inference, but need to avoid using specific columns, like index columns, during training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default.
""" def __init__( @@ -108,10 +117,6 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "Binarizer": project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Binarize the data. Map to 1 if it is strictly greater than the threshold, otherwise 0. diff --git a/snowflake/ml/modeling/preprocessing/k_bins_discretizer.py b/snowflake/ml/modeling/preprocessing/k_bins_discretizer.py index 93b2e507..6df30264 100644 --- a/snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +++ b/snowflake/ml/modeling/preprocessing/k_bins_discretizer.py @@ -177,10 +177,6 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> KBinsDiscreti project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform( self, dataset: Union[snowpark.DataFrame, pd.DataFrame] ) -> Union[snowpark.DataFrame, pd.DataFrame, sparse.csr_matrix]: diff --git a/snowflake/ml/modeling/preprocessing/label_encoder.py b/snowflake/ml/modeling/preprocessing/label_encoder.py index a30ba382..1b82b168 100644 --- a/snowflake/ml/modeling/preprocessing/label_encoder.py +++ b/snowflake/ml/modeling/preprocessing/label_encoder.py @@ -24,15 +24,22 @@ class LabelEncoder(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html). Args: - input_cols: The name of a column in a DataFrame to be encoded. May be a string or a list containing one string. - output_cols: The name of a column in a DataFrame where the results will be stored. May be a string or a list + input_cols: Optional[Union[str, List[str]]] + The name of a column in a DataFrame to be encoded. May be a string or a list containing one string. 
+ + output_cols: Optional[Union[str, List[str]]] + The name of a column in a DataFrame where the results will be stored. May be a string or a list containing one string. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any + + passthrough_cols: Optional[Union[str, List[str]]] + A string or a list of strings indicating column names to be excluded from any operations (such as train, transform, or inference). These specified column(s) will remain untouched throughout the process. This option is helpful in scenarios requiring automatic input_cols inference, but need to avoid using specific columns, like index columns, during training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. """ def __init__( @@ -46,19 +53,24 @@ def __init__( Encode target labels with integers between 0 and n_classes-1. Args: - input_cols: The name of a column in a DataFrame to be encoded. May be a string or a list containing one + input_cols: Optional[Union[str, List[str]]] + The name of a column in a DataFrame to be encoded. May be a string or a list containing one string. - output_cols: The name of a column in a DataFrame where the results will be stored. May be a string or a list + output_cols: Optional[Union[str, List[str]]] + The name of a column in a DataFrame where the results will be stored. May be a string or a list containing one string. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any + passthrough_cols: Optional[Union[str, List[str]]] + A string or a list of strings indicating column names to be excluded from any operations (such as train, transform, or inference). These specified column(s) will remain untouched throughout the process. 
This option is helful in scenarios requiring automatic input_cols inference, but need to avoid using specific columns, like index columns, during in training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. Attributes: - classes_: A np.ndarray that holds the label for each class. + classes_: Optional[type_utils.LiteralNDArrayType] + A np.ndarray that holds the label for each class. Attributes are valid only after fit() has been called. """ @@ -126,10 +138,6 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "LabelEncoder project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Use fit result to transform snowpark dataframe or pandas dataframe. The original dataset with diff --git a/snowflake/ml/modeling/preprocessing/max_abs_scaler.py b/snowflake/ml/modeling/preprocessing/max_abs_scaler.py index 6904ed98..9b8fa69f 100644 --- a/snowflake/ml/modeling/preprocessing/max_abs_scaler.py +++ b/snowflake/ml/modeling/preprocessing/max_abs_scaler.py @@ -27,19 +27,29 @@ class MaxAbsScaler(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html). Args: - input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be scaled. - output_cols: The name(s) of one or more columns in a DataFrame in which results will be stored. The number of + input_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame containing a feature to be scaled. 
+ + output_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame in which results will be stored. The number of columns specified must match the number of input columns. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any - operations (such as train, transform, or inference). These specified column(s) - will remain untouched throughout the process. This option is helpful in scenarios - requiring automatic input_cols inference, but need to avoid using specific - columns, like index columns, during training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. + + passthrough_cols: Optional[Union[str, List[str]]], default=None + A string or a list of strings indicating column names to be excluded from any + operations (such as train, transform, or inference). These specified column(s) + will remain untouched throughout the process. This option is helpful in scenarios + requiring automatic input_cols inference, but need to avoid using specific + columns, like index columns, during training or inference. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. Attributes: - scale_: dict {column_name: value} or None. Per-feature relative scaling factor. - max_abs_: dict {column_name: value} or None. Per-feature maximum absolute value. + scale_: Dict[str, float] + dict {column_name: value} or None. Per-feature relative scaling factor. + + max_abs_: Dict[str, float] + dict {column_name: value} or None. Per-feature maximum absolute value.
""" def __init__( @@ -150,10 +160,6 @@ def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None: project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Scale the data. diff --git a/snowflake/ml/modeling/preprocessing/min_max_scaler.py b/snowflake/ml/modeling/preprocessing/min_max_scaler.py index e677a68f..9e122cb7 100644 --- a/snowflake/ml/modeling/preprocessing/min_max_scaler.py +++ b/snowflake/ml/modeling/preprocessing/min_max_scaler.py @@ -21,25 +21,45 @@ class MinMaxScaler(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html). Args: - feature_range: Desired range of transformed data (default is 0 to 1). - clip: Whether to clip transformed values of held-out data to the specified feature range (default is True). - input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be scaled. Each specified + feature_range: Tuple[float, float], default=(0, 1) + Desired range of transformed data (default is 0 to 1). + + clip: bool, default=False + Whether to clip transformed values of held-out data to the specified feature range (default is False). + + input_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame containing a feature to be scaled. Each specified input column is scaled independently and stored in the corresponding output column. - output_cols: The name(s) of one or more columns in a DataFrame in which results will be stored. The number of + + output_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame in which results will be stored. The number of columns specified must match the number of input columns.
- passthrough_cols: A string or a list of strings indicating column names to be excluded from any - operations (such as train, transform, or inference). These specified column(s) - will remain untouched throughout the process. This option is helpful in scenarios - requiring automatic input_cols inference, but need to avoid using specific - columns, like index columns, during training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. + + passthrough_cols: Optional[Union[str, List[str]]], default=None + A string or a list of strings indicating column names to be excluded from any + operations (such as train, transform, or inference). These specified column(s) + will remain untouched throughout the process. This option is helpful in scenarios + requiring automatic input_cols inference, but need to avoid using specific + columns, like index columns, during training or inference. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. Attributes: - min_: dict {column_name: value} or None. Per-feature adjustment for minimum. - scale_: dict {column_name: value} or None. Per-feature relative scaling factor. - data_min_: dict {column_name: value} or None. Per-feature minimum seen in the data. - data_max_: dict {column_name: value} or None. Per-feature maximum seen in the data. - data_range_: dict {column_name: value} or None. Per-feature range seen in the data as a (min, max) tuple. + min_: Dict[str, float] + dict {column_name: value} or None. Per-feature adjustment for minimum. + + scale_: Dict[str, float] + dict {column_name: value} or None. Per-feature relative scaling factor. + + data_min_: Dict[str, float] + dict {column_name: value} or None. Per-feature minimum seen in the data. + + data_max_: Dict[str, float] + dict {column_name: value} or None. Per-feature maximum seen in the data. + + data_range_: Dict[str, float] + dict {column_name: value} or None. 
Per-feature range seen in the data as a (min, max) tuple. """ def __init__( @@ -170,10 +190,6 @@ def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None: project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Scale features according to feature_range. diff --git a/snowflake/ml/modeling/preprocessing/normalizer.py b/snowflake/ml/modeling/preprocessing/normalizer.py index c1c4e96a..8a14df4b 100644 --- a/snowflake/ml/modeling/preprocessing/normalizer.py +++ b/snowflake/ml/modeling/preprocessing/normalizer.py @@ -34,11 +34,12 @@ class Normalizer(base.BaseTransformer): A string or list of strings representing column names that will store the output of transform operation. The length of `output_cols` must equal the length of `input_cols`. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any - operations (such as train, transform, or inference). These specified column(s) - will remain untouched throughout the process. This option is helpful in scenarios - requiring automatic input_cols inference, but need to avoid using specific - columns, like index columns, during training or inference. + passthrough_cols: Optional[Union[str, List[str]]] + A string or a list of strings indicating column names to be excluded from any + operations (such as train, transform, or inference). These specified column(s) + will remain untouched throughout the process. This option is helpful in scenarios + requiring automatic input_cols inference, but need to avoid using specific + columns, like index columns, during training or inference. drop_input_cols: bool, default=False Remove input columns from output if set `True`. 
@@ -90,10 +91,6 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "Normalizer": project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Scale each non-zero row of the input dataset to the unit norm. diff --git a/snowflake/ml/modeling/preprocessing/one_hot_encoder.py b/snowflake/ml/modeling/preprocessing/one_hot_encoder.py index 623e2b37..010f252e 100644 --- a/snowflake/ml/modeling/preprocessing/one_hot_encoder.py +++ b/snowflake/ml/modeling/preprocessing/one_hot_encoder.py @@ -101,7 +101,7 @@ class OneHotEncoder(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html). Args: - categories: 'auto' or dict {column_name: ndarray([category])}, default='auto' + categories: 'auto' or dict {column_name: np.ndarray([category])}, default='auto' Categories (unique values) per feature: - 'auto': Determine categories automatically from the training data. - dict: ``categories[column_name]`` holds the categories expected in @@ -109,6 +109,7 @@ class OneHotEncoder(base.BaseTransformer): and numeric values within a single feature, and should be sorted in case of numeric values. The used categories can be found in the ``categories_`` attribute. + drop: {‘first’, ‘if_binary’} or an array-like of shape (n_features,), default=None Specifies a methodology to use to drop one of the categories per feature. This is useful in situations where perfectly collinear @@ -128,15 +129,18 @@ class OneHotEncoder(base.BaseTransformer): When `max_categories` or `min_frequency` is configured to group infrequent categories, the dropping behavior is handled after the grouping. + sparse: bool, default=False Will return a column with sparse representation if set True else will return a separate column for each category. 
+ handle_unknown: {'error', 'ignore'}, default='error' Specifies the way unknown categories are handled during :meth:`transform`. - 'error': Raise an error if an unknown category is present during transform. - 'ignore': When an unknown category is encountered during transform, the resulting one-hot encoded columns for this feature will be all zeros. + min_frequency: int or float, default=None Specifies the minimum frequency below which a category will be considered infrequent. @@ -144,22 +148,29 @@ class OneHotEncoder(base.BaseTransformer): infrequent. - If `float`, categories with a smaller cardinality than `min_frequency * n_samples` will be considered infrequent. + max_categories: int, default=None Specifies an upper limit to the number of output features for each input feature when considering infrequent categories. If there are infrequent categories, `max_categories` includes the category representing the infrequent categories along with the frequent categories. If `None`, there is no limit to the number of output features. - input_cols: str or Iterable [column_name], default=None + + input_cols: Optional[Union[str, List[str]]], default=None Single or multiple input columns. - output_cols: str or Iterable [column_name], default=None + + output_cols: Optional[Union[str, List[str]]], default=None Single or multiple output columns. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any - operations (such as train, transform, or inference). These specified column(s) - will remain untouched throughout the process. This option is helpful in scenarios - requiring automatic input_cols inference, but need to avoid using specific - columns, like index columns, during training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. 
+ + passthrough_cols: Optional[Union[str, List[str]]] + A string or a list of strings indicating column names to be excluded from any + operations (such as train, transform, or inference). These specified column(s) + will remain untouched throughout the process. This option is helpful in scenarios + requiring automatic input_cols inference, but need to avoid using specific + columns, like index columns, during training or inference. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. Attributes: categories_: dict {column_name: ndarray([category])} @@ -665,10 +676,6 @@ def map_encoding(row: pd.Series) -> int: project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform( self, dataset: Union[snowpark.DataFrame, pd.DataFrame] ) -> Union[snowpark.DataFrame, pd.DataFrame, sparse.csr_matrix]: diff --git a/snowflake/ml/modeling/preprocessing/ordinal_encoder.py b/snowflake/ml/modeling/preprocessing/ordinal_encoder.py index db914fd0..1da49922 100644 --- a/snowflake/ml/modeling/preprocessing/ordinal_encoder.py +++ b/snowflake/ml/modeling/preprocessing/ordinal_encoder.py @@ -45,31 +45,47 @@ class OrdinalEncoder(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OrdinalEncoder.html). Args: - categories: The string 'auto' (the default) causes the categories to be extracted from the input columns. + categories: Union[str, Dict[str, type_utils.LiteralNDArrayType]], default="auto" + The string 'auto' (the default) causes the categories to be extracted from the input columns. To specify the categories yourself, pass a dictionary mapping the column name to an ndarray containing the categories. - handle_unknown: Specifies how unknown categories are handled during transformation.
Applicable only if + + handle_unknown: str, default="error" + Specifies how unknown categories are handled during transformation. Applicable only if categories is not 'auto'. Valid values are: - 'error': Raise an error if an unknown category is present during transform (default). - 'use_encoded_value': When an unknown category is encountered during transform, the specified encoded_missing_value (below) is used. - unknown_value: When the parameter handle_unknown is set to 'use_encoded_value', this parameter is required and + + unknown_value: Optional[Union[int, float]], default=None + When the parameter handle_unknown is set to 'use_encoded_value', this parameter is required and will set the encoded value of unknown categories. It has to be distinct from the values used to encode any of the categories in `fit`. - encoded_missing_value: The value to be used to encode unknown categories. - input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be encoded. - output_cols: The name(s) of one or more columns in a DataFrame in which results will be stored. The number of + + encoded_missing_value: Union[int, float], default=np.nan + The value to be used to encode unknown categories. + + input_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame containing a feature to be encoded. + + output_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame in which results will be stored. The number of columns specified must match the number of input columns. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any + + passthrough_cols: Optional[Union[str, List[str]]], default=None + A string or a list of strings indicating column names to be excluded from any operations (such as train, transform, or inference). These specified column(s) will remain untouched throughout the process. 
This option is helpful in scenarios requiring automatic input_cols inference, but need to avoid using specific columns, like index columns, during training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. Attributes: - categories_ (dict of ndarray): The categories of each feature determined during fitting. Maps input column + categories_ (dict of ndarray): List[type_utils.LiteralNDArrayType] + The categories of each feature determined during fitting. Maps input column names to an array of the detected categories. Attributes are valid only after fit() has been called. """ @@ -429,10 +445,6 @@ def _validate_encoded_missing_value(self) -> None: project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Transform dataset to ordinal codes. diff --git a/snowflake/ml/modeling/preprocessing/robust_scaler.py b/snowflake/ml/modeling/preprocessing/robust_scaler.py index 3a7c0547..db1df6de 100644 --- a/snowflake/ml/modeling/preprocessing/robust_scaler.py +++ b/snowflake/ml/modeling/preprocessing/robust_scaler.py @@ -20,28 +20,46 @@ class RobustScaler(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html). Args: - with_centering: If True, center the data around zero before scaling. - with_scaling: If True, scale the data to interquartile range. - quantile_range: tuple like (q_min, q_max), where 0.0 < q_min < q_max < 100.0, default=(25.0, 75.0). Quantile + with_centering: bool, default=True + If True, center the data around zero before scaling. + + with_scaling: bool, default=True + If True, scale the data to interquartile range. 
+ + quantile_range: Tuple[float, float], default=(25.0, 75.0) + tuple like (q_min, q_max), where 0.0 < q_min < q_max < 100.0, default=(25.0, 75.0). Quantile range used to calculate scale_. By default, this is equal to the IQR, i.e., q_min is the first quantile and q_max is the third quantile. - unit_variance: If True, scale data so that normally-distributed features have a variance of 1. In general, if + + unit_variance: bool, default=False + If True, scale data so that normally-distributed features have a variance of 1. In general, if the difference between the x-values of q_max and q_min for a standard normal distribution is greater than 1, the dataset is scaled down. If less than 1, the dataset is scaled up. - input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be scaled. - output_cols: The name(s) of one or more columns in a DataFrame in which results will be stored. The number of + + input_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame containing a feature to be scaled. + + output_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame in which results will be stored. The number of columns specified must match the number of input columns. For dense output, the column names specified are used as base names for the columns created for each category. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any + + passthrough_cols: Optional[Union[str, List[str]]], default=None + A string or a list of strings indicating column names to be excluded from any operations (such as train, transform, or inference). These specified column(s) will remain untouched throughout the process. This option is helpful in scenarios requiring automatic input_cols inference, but need to avoid using specific columns, like index columns, during training or inference. 
- drop_input_cols: Remove input columns from output if set True. False by default. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. Attributes: - center_: Dictionary mapping input column name to the median value for that feature. - scale_: Dictionary mapping input column name to the (scaled) interquartile range for that feature. + center_: Dict[str, float] + Dictionary mapping input column name to the median value for that feature. + + scale_: Dict[str, float] + Dictionary mapping input column name to the (scaled) interquartile range for that feature. """ def __init__( @@ -199,10 +217,6 @@ def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None: project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Center and scale the data. diff --git a/snowflake/ml/modeling/preprocessing/standard_scaler.py b/snowflake/ml/modeling/preprocessing/standard_scaler.py index 45702e93..4dabac2a 100644 --- a/snowflake/ml/modeling/preprocessing/standard_scaler.py +++ b/snowflake/ml/modeling/preprocessing/standard_scaler.py @@ -19,24 +19,40 @@ class StandardScaler(base.BaseTransformer): (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html). Args: - with_mean: If True, center the data before scaling. - with_std: If True, scale the data unit variance (i.e. unit standard deviation). - input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be scaled. - output_cols: The name(s) of one or more columns in a DataFrame in which results will be stored. The number of + with_mean: bool, default=True + If True, center the data before scaling. + + with_std: bool, default=True + If True, scale the data to unit variance (i.e. unit standard deviation).
+ + input_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame containing a feature to be scaled. + + output_cols: Optional[Union[str, List[str]]], default=None + The name(s) of one or more columns in a DataFrame in which results will be stored. The number of columns specified must match the number of input columns. - passthrough_cols: A string or a list of strings indicating column names to be excluded from any + + passthrough_cols: Optional[Union[str, List[str]]], default=None + A string or a list of strings indicating column names to be excluded from any operations (such as train, transform, or inference). These specified column(s) will remain untouched throughout the process. This option is helpful in scenarios requiring automatic input_cols inference, but need to avoid using specific columns, like index columns, during training or inference. - drop_input_cols: Remove input columns from output if set True. False by default. + + drop_input_cols: Optional[bool], default=False + Remove input columns from output if set True. False by default. Attributes: - scale_: Dictionary mapping input column names to relative scaling factor to achieve zero mean and unit variance. + scale_: Optional[Dict[str, float]] = {} + Dictionary mapping input column names to relative scaling factor to achieve zero mean and unit variance. If a variance is zero, unit variance could not be achieved, and the data is left as-is, giving a scaling factor of 1. None if with_std is False. - mean_: Dictionary mapping input column name to the mean value for that feature. None if with_mean is False. - var_: Dictionary mapping input column name to the variance for that feature. Used to compute scale_. None if + + mean_: Optional[Dict[str, float]] = {} + Dictionary mapping input column name to the mean value for that feature. None if with_mean is False. 
+ + var_: Optional[Dict[str, float]] = {} + Dictionary mapping input column name to the variance for that feature. Used to compute scale_. None if with_std is False """ @@ -177,10 +193,6 @@ def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None: project=base.PROJECT, subproject=base.SUBPROJECT, ) - @telemetry.add_stmt_params_to_df( - project=base.PROJECT, - subproject=base.SUBPROJECT, - ) def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]: """ Perform standardization by centering and scaling. diff --git a/snowflake/ml/registry/BUILD.bazel b/snowflake/ml/registry/BUILD.bazel index 7d57ca50..89bb441b 100644 --- a/snowflake/ml/registry/BUILD.bazel +++ b/snowflake/ml/registry/BUILD.bazel @@ -15,11 +15,14 @@ py_library( "//snowflake/ml/_internal/utils:formatting", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:spcs_attribution_utils", "//snowflake/ml/_internal/utils:table_manager", "//snowflake/ml/_internal/utils:uri", "//snowflake/ml/dataset", "//snowflake/ml/model:_api", "//snowflake/ml/model:deploy_platforms", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model:type_hints", "//snowflake/ml/modeling/framework", ], ) @@ -74,10 +77,44 @@ py_test( ], ) +py_library( + name = "registry", + srcs = [ + "registry.py", + ], + deps = [ + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model:type_hints", + "//snowflake/ml/model/_client/model:model_impl", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/model/_client/ops:model_ops", + "//snowflake/ml/model/_model_composer:model_composer", + ], +) + +py_test( + name = "registry_test", + srcs = [ + "registry_test.py", + ], + deps = [ + ":registry", + 
"//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/model/_model_composer:model_composer", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) + py_package( name = "model_registry_pkg", packages = ["snowflake.ml"], deps = [ ":model_registry", + ":registry", ], ) diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py index 31c89529..0146e52c 100644 --- a/snowflake/ml/registry/model_registry.py +++ b/snowflake/ml/registry/model_registry.py @@ -24,6 +24,7 @@ formatting, identifier, query_result_checker, + spcs_attribution_utils, table_manager, uri, ) @@ -1767,6 +1768,7 @@ def delete_deployment(self, model_name: str, model_version: str, *, deployment_n service_name = identifier.get_schema_level_object_identifier( self._name, self._schema, f"service_{deployment['MODEL_ID']}" ) + spcs_attribution_utils.record_service_end(self._session, service_name) query_result_checker.SqlResultValidator( self._session, f"DROP SERVICE IF EXISTS {service_name}", diff --git a/snowflake/ml/registry/notebooks/Model Packaging SnowML Examples.ipynb b/snowflake/ml/registry/notebooks/Model Packaging SnowML Examples.ipynb deleted file mode 100644 index 452b1efa..00000000 --- a/snowflake/ml/registry/notebooks/Model Packaging SnowML Examples.ipynb +++ /dev/null @@ -1,939 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "5de3eb26", - "metadata": {}, - "source": [ - "# Model Packaging Example" - ] - }, - { - "cell_type": "markdown", - "id": "197efd00", - "metadata": {}, - "source": [ - "## Before Everything" - ] - }, - { - "cell_type": "markdown", - "id": "6ce97b36", - "metadata": {}, - "source": [ - "### Install `snowflake-ml-python` locally" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "1117c596", - "metadata": {}, - "source": [ - "Please refer to our [readme 
file](https://docs.google.com/document/d/10DmBHYFGKINQwyvJupfuhARDk-cyG5_Fn3Uy2OQcQPk) to install `snowflake-ml-python`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "da314158", - "metadata": {}, - "outputs": [], - "source": [ - "# Snowpark Connector, Snowpark Library, Session\n", - "import snowflake.connector\n", - "import snowflake.snowpark\n", - "import snowflake.ml.modeling.preprocessing as snowml\n", - "from snowflake.snowpark import Session\n", - "from snowflake.snowpark.version import VERSION\n", - "from snowflake.ml.utils import connection_params" - ] - }, - { - "cell_type": "markdown", - "id": "99e58d8c", - "metadata": {}, - "source": [ - "### Setup Notebook" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "afd16ff5", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d609ff44", - "metadata": {}, - "outputs": [], - "source": [ - "# Scale cell width with the browser window to accommodate .show() commands for wider tables.\n", - "from IPython.display import display, HTML\n", - "\n", - "display(HTML(\"\"))" - ] - }, - { - "cell_type": "markdown", - "id": "1ac32c6f", - "metadata": {}, - "source": [ - "### Start Snowpark Session\n", - "\n", - "To avoid exposing credentials in Github, we use a small utility `SnowflakeLoginOptions`. It allows you to score your default credentials in `~/.snowsql/config` in the following format:\n", - "```\n", - "[connections]\n", - "accountname = # Account identifier to connect to Snowflake.\n", - "username = # User name in the account. Optional.\n", - "password = # User password. Optional.\n", - "dbname = # Default database. Optional.\n", - "schemaname = # Default schema. Optional.\n", - "warehousename = # Default warehouse. Optional.\n", - "#rolename = # Default role. 
Optional.\n", - "#authenticator = # Authenticator: 'snowflake', 'externalbrowser', etc\n", - "```\n", - "Please follow [this](https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings) for more details." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b2efc0a8", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", - "from snowflake.snowpark import Session\n", - "\n", - "session = Session.builder.configs(SnowflakeLoginOptions()).create()" - ] - }, - { - "cell_type": "markdown", - "id": "dfa9ab88", - "metadata": {}, - "source": [ - "### Open/Create Model Registry" - ] - }, - { - "cell_type": "markdown", - "id": "b0a0c8a8", - "metadata": {}, - "source": [ - "A model registry needs to be created before it can be used. The creation will create a new database in the current account so the active role needs to have permissions to create a database. After the first creation, the model registry can be opened without the need to create it again." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a95e3431", - "metadata": {}, - "outputs": [], - "source": [ - "REGISTRY_DATABASE_NAME = \"MODEL_REGISTRY\"\n", - "REGISTRY_SCHEMA_NAME = \"PUBLIC\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7fff21bc", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.registry import model_registry\n", - "model_registry.create_model_registry(session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME)\n", - "registry = model_registry.ModelRegistry(session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME)" - ] - }, - { - "cell_type": "markdown", - "id": "ca0f443d", - "metadata": {}, - "source": [ - "## Use with snowml model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6271c9d1", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.modeling.xgboost import XGBClassifier\n", - "from sklearn.datasets import load_iris\n", - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "\n", - "iris = load_iris()\n", - "df = pd.DataFrame(data= np.c_[iris['data'], iris['target']],\n", - " columns= iris['feature_names'] + ['target'])\n", - "df.columns = [s.replace(\" (CM)\", '').replace(' ', '') for s in df.columns.str.upper()]\n", - "\n", - "INPUT_COLUMNS = ['SEPALLENGTH', 'SEPALWIDTH', 'PETALLENGTH', 'PETALWIDTH']\n", - "LABEL_COLUMNS = 'TARGET'\n", - "OUTPUT_COLUMNS = 'PREDICTED_TARGET'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c8de352", - "metadata": {}, - "outputs": [], - "source": [ - "df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ca901eb", - "metadata": {}, - "outputs": [], - "source": [ - "test_features = df[:10]\n", - "model_version = \"1_008\"" - ] - }, - { - "cell_type": "markdown", - "id": "b9441f7a", - "metadata": {}, - "source": [ - "### XGBoost model" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "4ac4c21e", - "metadata": {}, - "outputs": [], - "source": [ - "clf_xgb = XGBClassifier(input_cols=INPUT_COLUMNS,\n", - " output_cols=OUTPUT_COLUMNS,\n", - " label_cols=LABEL_COLUMNS,\n", - " drop_input_cols=True)\n", - "\n", - "clf_xgb.fit(df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd0ca646", - "metadata": {}, - "outputs": [], - "source": [ - "prediction = clf_xgb.predict(test_features)\n", - "prediction_proba = clf_xgb.predict_proba(test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3d872431", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"SIMPLE_XGB_MODEL\"\n", - "deploy_name = \"xgb_model_predict\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "523cc249", - "metadata": { - "code_folding": [] - }, - "outputs": [], - "source": [ - "# A name and model tags can be added to the model at registration time.\n", - "model = registry.log_model(\n", - " model_name=model_name,\n", - " model_version=model_version,\n", - " model=clf_xgb,\n", - " tags={\"stage\": \"testing\", \"classifier_type\": \"XGBClassifier\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "fe9e2081", - "metadata": {}, - "source": [ - "### Testing on deploy" - ] - }, - { - "cell_type": "markdown", - "id": "56834c5c", - "metadata": {}, - "source": [ - "#### Predict function match/mismatch? 
- comparison between deploy and local" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf55701d", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6f159a5f", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction[:10])\n", - "\n", - "print(\"Result comparison:\", np.array_equal(prediction, remote_prediction.values))" - ] - }, - { - "cell_type": "markdown", - "id": "65af7944", - "metadata": {}, - "source": [ - "#### Predict_proba function match/mismatch? - comparison between deploy and local" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0c77d583", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict_proba\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1216dbe8", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "remote_prediction_proba = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction_proba[:10])\n", - "\n", - "print(\"Result comparison:\", np.allclose(prediction_proba, remote_prediction_proba.values))" - ] - }, - { - "cell_type": "markdown", - "id": "e5f9c9b7", - "metadata": {}, - "source": [ - "### Random Forest model *from ensemble*\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "48780cb2", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.modeling.ensemble import RandomForestClassifier" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93d42010", - "metadata": {}, - "outputs": [], - "source": [ - "clf_rf = 
RandomForestClassifier(input_cols=INPUT_COLUMNS,\n", - " output_cols=OUTPUT_COLUMNS,\n", - " label_cols=LABEL_COLUMNS,\n", - " drop_input_cols=True)\n", - "\n", - "clf_rf.fit(df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cfe55d82", - "metadata": {}, - "outputs": [], - "source": [ - "prediction = clf_rf.predict(test_features)\n", - "prediction_proba = clf_rf.predict_proba(test_features)\n", - "prediction_log_proba = clf_rf.predict_log_proba(test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4ef91e18", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"SIMPLE_RF_MODEL\"\n", - "deploy_name = \"rf_model_predict\"\n", - "classifier_type = \"RFClassifier\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f9401b46", - "metadata": {}, - "outputs": [], - "source": [ - "# A name and model tags can be added to the model at registration time.\n", - "model = registry.log_model(\n", - " model_name=model_name,\n", - " model_version=model_version,\n", - " model=clf_rf,\n", - " tags={\"stage\": \"testing\", \"classifier_type\": classifier_type},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "cefbad30", - "metadata": {}, - "source": [ - "#### Comparison between deploy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f81f663e", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "726838d0", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction[:10])\n", - "\n", - "print(\"Result comparison:\", np.array_equal(prediction, remote_prediction.values))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33833e23", - 
"metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict_proba\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4e5d8d89", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction_proba = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction_proba[:10])\n", - "\n", - "print(\"Result comparison:\", np.array_equal(prediction_proba, remote_prediction_proba.values))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8ddc04e8", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict_log_proba\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf688655", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction_log_proba = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction_log_proba[:10])\n", - "\n", - "print(\"Result comparison:\", np.array_equal(prediction_log_proba, remote_prediction_log_proba.values))" - ] - }, - { - "cell_type": "markdown", - "id": "eb7b90fe", - "metadata": {}, - "source": [ - "### Logistic Regression model\n", - "\n", - "The reason to test w/ LR model is because, it has all the functions such as `predict, predict_log_proba, predict_proba, decision_function`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6b1d0b93", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.modeling.linear_model import LogisticRegression" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3280b02f", - "metadata": {}, - "outputs": [], - "source": [ - "clf_lr = LogisticRegression(input_cols=INPUT_COLUMNS,\n", - " output_cols=OUTPUT_COLUMNS,\n", - " label_cols=LABEL_COLUMNS,\n", - " 
drop_input_cols=True,\n", - " max_iter=1000)\n", - "\n", - "clf_lr.fit(df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a74cef89", - "metadata": {}, - "outputs": [], - "source": [ - "prediction = clf_lr.predict(test_features)\n", - "prediction_proba = clf_lr.predict_proba(test_features)\n", - "prediction_log_proba = clf_lr.predict_log_proba(test_features)\n", - "prediction_decision = clf_lr.decision_function(test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "de6fa3a0", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"SIMPLE_LR_MODEL\"\n", - "deploy_name = \"lr_model_predict\"\n", - "classifier_type = \"LogisticRegression\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35ca8aa6", - "metadata": {}, - "outputs": [], - "source": [ - "# A name and model tags can be added to the model at registration time.\n", - "model = registry.log_model(\n", - " model_name=model_name,\n", - " model_version=model_version,\n", - " model=clf_lr,\n", - " tags={\"stage\": \"testing\", \"classifier_type\": classifier_type},\n", - " options={\"embed_local_ml_library\": True}\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "83ff0a1b", - "metadata": {}, - "source": [ - "#### Comparison between deploy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25be7377", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict\",\n", - " options={\"relax_version\": True},\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "afd5f285", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction[:10])\n", - "\n", - "print(\"Result comparison:\", np.array_equal(prediction, remote_prediction.values))" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "id": "fa054c3c", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict_proba\",\n", - " options={\"relax_version\": True},\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec25c905", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction_proba = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction_proba[:10])\n", - "\n", - "print(\"Result comparison:\", np.allclose(prediction_proba, remote_prediction_proba.values))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a425e55", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict_log_proba\",\n", - " options={\"relax_version\": True},\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff4a4c54", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction_log_proba = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction_log_proba[:10])\n", - "\n", - "print(\"Result comparison:\", np.allclose(prediction_log_proba, remote_prediction_log_proba.values))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2904de8c", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"decision_function\",\n", - " options={\"relax_version\": True},\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "713806ec", - "metadata": {}, - "outputs": [], - "source": [ - "remote_prediction_decision_function = model.predict(deployment_name=deploy_name, data=test_features)\n", - "\n", - "print(\"Remote prediction:\", remote_prediction_decision_function[:10])\n", - "\n", 
- "print(\"Result comparison:\", np.allclose(prediction_decision, remote_prediction_decision_function.values))" - ] - }, - { - "cell_type": "markdown", - "id": "d6930720", - "metadata": {}, - "source": [ - "### Pipeline model\n", - "\n", - "It is important to see if the whole pipeline is stored" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "846db56a", - "metadata": {}, - "outputs": [], - "source": [ - "def add_simple_category(df):\n", - " bins = (-1, 4, 5, 6, 10)\n", - " group_names = ['Unknown', '1_quartile', '2_quartile', '3_quartile']\n", - " categories = pd.cut(df.SEPALLENGTH, bins, labels=group_names)\n", - " df['SIMPLE'] = categories\n", - " return df\n", - "df_cat = add_simple_category(df)\n", - "\n", - "numeric_features=['SEPALLENGTH', 'SEPALWIDTH', 'PETALLENGTH', 'PETALWIDTH']\n", - "categorical_features = ['SIMPLE']\n", - "numeric_features_output = [x + '_O' for x in numeric_features]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2033ef31", - "metadata": {}, - "outputs": [], - "source": [ - "# Define the Table and Cleanup Cols, have a work_schema for testing\n", - "\n", - "\n", - "############################################################################\n", - "# NOTE: \n", - "# Set work_schema variable to some schema that exists in your account.\n", - "# set data_dir to point to the directory that contains the diamonds.csv file.\n", - "############################################################################\n", - "work_schema = 'TEST'\n", - "demo_table = 'IRIS_UPPER'\n", - "\n", - "# write the DF to Snowflake and create a Snowflake DF\n", - "session.write_pandas(df_cat, demo_table, auto_create_table=True, table_type=\"temporary\", schema=work_schema)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6b150ff8", - "metadata": {}, - "outputs": [], - "source": [ - "# Diamonds Snowflake Table\n", - "input_tbl = 
f\"{session.get_current_database()}.{session.get_current_schema()}.{demo_table}\"\n", - "iris_df = session.table(input_tbl)\n", - "print(iris_df.limit(10).to_pandas())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "86f8b074", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.modeling.linear_model import LogisticRegression\n", - "from snowflake.ml.preprocessing import MinMaxScaler, StandardScaler, OneHotEncoder\n", - "from snowflake.ml.framework.pipeline import Pipeline\n", - "pipeline = Pipeline(\n", - " steps=[\n", - " ('OHEHOT', OneHotEncoder(input_cols=categorical_features, output_cols='cat_output', drop_input_cols=True), ),\n", - " ('SCALER', MinMaxScaler(clip=True, input_cols=numeric_features, output_cols=numeric_features_output, drop_input_cols=True), ),\n", - " ('CLASSIFIER', LogisticRegression(label_cols=LABEL_COLUMNS))\n", - " ])\n", - "pipeline.fit(iris_df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "94231eb1", - "metadata": {}, - "outputs": [], - "source": [ - "iris_df_test = iris_df.limit(10)\n", - "prediction = pipeline.predict(iris_df_test)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2720275f", - "metadata": {}, - "outputs": [], - "source": [ - "pipeline.fit(iris_df.to_pandas())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a3b5159a", - "metadata": {}, - "outputs": [], - "source": [ - "prediction = pipeline.predict(iris_df_test.to_pandas())\n", - "prediction_log_proba = pipeline.predict_log_proba(iris_df_test.to_pandas())\n", - "prediction_proba = pipeline.predict_proba(iris_df_test.to_pandas())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "85917118", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"SIMPLE_PP_MODEL\"\n", - "deploy_name = \"pp_model_predict\"\n", - "classifier_type = \"Pipeline\"\n", - "model_version = f\"{model_name}_007\"" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "735ff3ca", - "metadata": {}, - "outputs": [], - "source": [ - "# A name and model tags can be added to the model at registration time.\n", - "model = registry.log_model(\n", - " model_name=model_name,\n", - " model_version=model_version,\n", - " model=pipeline,\n", - " tags={\"stage\": \"testing\", \"classifier_type\": classifier_type},\n", - " options={\"embed_local_ml_library\": True}\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "2ca2e15e", - "metadata": {}, - "source": [ - "#### Comparison between deploy predict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7c1210c5", - "metadata": {}, - "outputs": [], - "source": [ - "model.deploy(\n", - " deployment_name=deploy_name,\n", - " target_method=\"predict\",\n", - " options={\"relax_version\": True},\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2ff838f", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "remote_prediction = model.predict(deployment_name=deploy_name, data=iris_df_test.to_pandas())\n", - "\n", - "print(\"Remote prediction:\", remote_prediction[:10])\n", - "\n", - "print(\"Result comparison:\", np.allclose(prediction, remote_prediction.values))" - ] - } - ], - "metadata": { - "hide_input": false, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - }, - "vscode": { - 
"interpreter": { - "hash": "fb0a62cbfaa59af7646af5a6672c5c3e72ec75fbadf6ff0336b6769523f221a5" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/snowflake/ml/registry/notebooks/Model Registry Demo.ipynb b/snowflake/ml/registry/notebooks/Model Registry Demo.ipynb deleted file mode 100644 index 794771c0..00000000 --- a/snowflake/ml/registry/notebooks/Model Registry Demo.ipynb +++ /dev/null @@ -1,760 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "id": "5de3eb26", - "metadata": {}, - "source": [ - "# Model Registry Demo" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "99e58d8c", - "metadata": {}, - "source": [ - "## Setup Notebook and Import Path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d609ff44", - "metadata": {}, - "outputs": [], - "source": [ - "# Scale cell width with the browser window to accommodate .show() commands for wider tables.\n", - "from IPython.display import display, HTML\n", - "display(HTML(\"\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "95cde1f7", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "\n", - "# Simplify reading from the local repository\n", - "cwd=os.getcwd()\n", - "REPO_PREFIX=\"snowflake/ml\"\n", - "LOCAL_REPO_PATH=cwd[:cwd.find(REPO_PREFIX)].rstrip('/')\n", - "\n", - "if LOCAL_REPO_PATH not in sys.path:\n", - " print(f\"Adding {LOCAL_REPO_PATH} to system path\")\n", - " sys.path.append(LOCAL_REPO_PATH)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "c592d46c", - "metadata": {}, - "source": [ - "## Train A Small Model" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "bec73215", - "metadata": {}, - "source": [ - "The cell below trains a small model for demonstration purposes. The nature of the model does not matter, it is purely used to demonstrate the usage of the Model Registry." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8cf44218", - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn import svm, linear_model\n", - "from sklearn.datasets import load_digits\n", - "\n", - "digits = load_digits()\n", - "target_digit = 6\n", - "num_training_examples = 10\n", - "svc_gamma = 0.001\n", - "svc_C = 10.\n", - "\n", - "clf = svm.SVC(gamma=svc_gamma, C=svc_C, probability=True)\n", - "\n", - "\n", - "def one_vs_all(dataset, digit):\n", - " return [x == digit for x in dataset]\n", - "\n", - "# Train a classifier using num_training_examples and use the last 100 examples for test.\n", - "train_features = digits.data[:num_training_examples]\n", - "train_labels = one_vs_all(digits.target[:num_training_examples], target_digit)\n", - "clf.fit(train_features, train_labels)\n", - "\n", - "test_features = digits.data[-100:]\n", - "test_labels = one_vs_all(digits.target[-100:], target_digit)\n", - "prediction = clf.predict(test_features)\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "20eef3b6", - "metadata": {}, - "source": [ - "## Start Snowpark Session\n", - "\n", - "To avoid exposing credentials in Github, we use a small utility `SnowflakeLoginOptions`. It allows you to score your default credentials in `~/.snowsql/config` in the following format:\n", - "```\n", - "[connections]\n", - "accountname = # Account identifier to connect to Snowflake.\n", - "username = # User name in the account. Optional.\n", - "password = # User password. Optional.\n", - "dbname = # Default database. Optional.\n", - "schemaname = # Default schema. Optional.\n", - "warehousename = # Default warehouse. Optional.\n", - "#rolename = # Default role. Optional.\n", - "#authenticator = # Authenticator: 'snowflake', 'externalbrowser', etc\n", - "```\n", - "Please follow [this](https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings) for more details." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "75282f6d", - "metadata": {}, - "outputs": [], - "source": [ - "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", - "from snowflake.snowpark import Session, Column, functions\n", - "\n", - "session = Session.builder.configs(SnowflakeLoginOptions()).create()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "dfa9ab88", - "metadata": {}, - "source": [ - "## Open/Create Model Registry" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "676b28b3", - "metadata": {}, - "source": [ - "A model registry needs to be created before it can be used. The creation will create a new database in the current account so the active role needs to have permissions to create a database. After the first creation, the model registry can be opened without the need to create it again." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5d37ad34", - "metadata": {}, - "outputs": [], - "source": [ - "import importlib\n", - "from snowflake.ml.registry import model_registry\n", - "# Force re-loading model_registry in case we updated the package during the runtime of this notebook.\n", - "importlib.reload(model_registry)\n", - "\n", - "registry_name = \"model_registry_zzhu\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "98dbe771", - "metadata": {}, - "outputs": [], - "source": [ - "# Create a new model registry. 
This will be a no-op if the registry already exists.\n", - "create_result = model_registry.create_model_registry(session=session, database_name=registry_name)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7fff21bc", - "metadata": {}, - "outputs": [], - "source": [ - "registry = model_registry.ModelRegistry(session=session, database_name=registry_name)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "5d6a85b3", - "metadata": {}, - "source": [ - "There are two functionally equivalent APIs to interact with the model registry.\n", - "\n", - "* A _relational API_ where all operations are performed as methods of the `registry` object and \n", - "* a _object API_ where operations on a specific model are performend as methods of a `ModelReference` object.\n", - "\n", - "The usage examples below will add some color to the two APIs and how they behave." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "317e7843", - "metadata": {}, - "source": [ - "## Register a new Model" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "3bdda91a", - "metadata": {}, - "source": [ - "Registering a new model is always performed through the relational API. \n", - "\n", - "The call to `log_model` executes a few steps:\n", - "1. The given model object is serialized and uploaded to a stage.\n", - "1. An entry in the Model Registry is created for the model, referencing the model stage location.\n", - "1. Additional metadata is updated for the model as provided in the call.\n", - "\n", - "For the serialization to work, the model object needs to be serializable in python." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d8ad06e", - "metadata": {}, - "outputs": [], - "source": [ - "# A name and model tags can be added to the model at registration time.\n", - "model_name = \"my_model\"\n", - "model_version = \"108.2.4\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84e10dc9", - "metadata": {}, - "outputs": [], - "source": [ - "model = registry.log_model(model_name=model_name, model_version=model_version, model=clf, tags={\n", - " \"stage\": \"testing\", \"classifier_type\": \"svm.SVC\", \"svc_gamma\": svc_gamma, \"svc_C\": svc_C}, sample_input_data=train_features, options={\"embed_local_ml_library\": True})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b463bad9", - "metadata": {}, - "outputs": [], - "source": [ - "# The object API can be used to reference a model after creation.\n", - "model = model_registry.ModelReference(registry=registry, model_name=model_name, model_version=model_version)\n", - "print(\"Registered new model id:\", model_id)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "27d1158d", - "metadata": {}, - "source": [ - "## Add Metrics" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "d6035ca5", - "metadata": {}, - "source": [ - "Metrics are a type of metadata annotation that can be associated with models stored in the Model Registry. Metrics often take the form of scalars but we also support more complex objects such as arrays or dictionaries to represent metrics. In the exmamples below, we add scalars, dictionaries, and a 2-dimensional numpy array as metrics." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c2b0cdbd", - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn import metrics\n", - "\n", - "test_accuracy = metrics.accuracy_score(test_labels, prediction)\n", - "print(\"Model test accuracy:\", test_accuracy)\n", - "\n", - "# Simple scalar metrics.\n", - "\n", - "# Relational API\n", - "registry.set_metric(model_name=model_name, model_version=model_version, metric_name=\"test_accuracy\", metric_value=test_accuracy)\n", - "\n", - "# Object API\n", - "model.set_metric(metric_name=\"num_training_examples\", metric_value=num_training_examples)\n", - "\n", - "# Hierarchical metric.\n", - "registry.set_metric(model_name=model_name, model_version=model_version, metric_name=\"dataset_test\", metric_value={\"accuracy\": test_accuracy})\n", - "\n", - "# Multivalent metric:\n", - "test_confusion_matrix = metrics.confusion_matrix(test_labels, prediction)\n", - "print(\"Confusion matrix:\", test_confusion_matrix)\n", - "\n", - "registry.set_metric(model_name=model_name, model_version=model_version, metric_name=\"confusion_matrix\", metric_value=test_confusion_matrix)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45b81834", - "metadata": {}, - "outputs": [], - "source": [ - "# Relational API\n", - "registry.get_metrics(model_name=model_name, model_version=model_version)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a2627c5", - "metadata": {}, - "outputs": [], - "source": [ - "# Object API\n", - "model.get_metrics()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "98164cb7", - "metadata": {}, - "source": [ - "## List Model in Registry" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "67eac368", - "metadata": {}, - "source": [ - "Listing models in the registry returns a SnowPark DataFrame. That allows the caller to select and filter the models as needed. 
In the example below, we list the name, tags, and metrics for the model we just added." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dc82b541", - "metadata": {}, - "outputs": [], - "source": [ - "model_list = registry.list_models()\n", - "\n", - "model_list.filter(model_list[\"VERSION\"] == model_version).select(\"NAME\",\"VERSION\",\"TAGS\",\"METRICS\").show()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "5706004c", - "metadata": {}, - "source": [ - "## Metadata: Tags and Descriptions" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "05cee94f", - "metadata": {}, - "source": [ - "Similar to how we changed metrics in the example above, we can also edit tags and descriptions of models both with the relational API and with the object API." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "88707ecd", - "metadata": {}, - "source": [ - "### Relational API" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f80f78da", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Old tags:\", registry.get_tags(model_name=model_name, model_version=model_version))\n", - "\n", - "registry.set_tag(model_name=model_name, model_version=model_version, tag_name=\"minor_version\", tag_value=\"23\")\n", - "print(\"Added tag:\", registry.get_tags(model_name=model_name, model_version=model_version,))\n", - "\n", - "registry.remove_tag(model_name=model_name, model_version=model_version, tag_name=\"minor_version\")\n", - "print(\"Removed tag\", registry.get_tags(model_name=model_name, model_version=model_version,))\n", - "\n", - "registry.set_tag(model_name, model_version,\"stage\",\"production\")\n", - "print(\"Updated tag:\", registry.get_tags(model_name=model_name, model_version=model_version,))\n", - "\n", - "registry.set_model_description(description=\"My model is better than talkgpt-5!\", model_name=model_name, model_version=model_version,)\n", - "print(\"Added 
description:\", registry.get_model_description(model_name=model_name, model_version=model_version,))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "47e80e1e", - "metadata": {}, - "source": [ - "### Object API" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7905d9c9", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Old tags:\", model.get_tags())\n", - "\n", - "model.set_tag(\"minor_version\", \"23\")\n", - "print(\"Added tag:\", model.get_tags())\n", - "\n", - "model.remove_tag(\"minor_version\")\n", - "print(\"Removed tag\", model.get_tags())\n", - "\n", - "model.set_tag(\"stage\", \"production\")\n", - "print(\"Updated tag:\", model.get_tags())\n", - "\n", - "model.set_model_description(description=\"My model is better than speakgpt-6!\")\n", - "print(\"New description:\", model.get_model_description())" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "72ade02b", - "metadata": {}, - "source": [ - "## List recent Models in Registry" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "645df90e", - "metadata": {}, - "source": [ - "Listing the models in the Model Registry returns a dataframe that allows us to conveniently manipulate the model list. In the example below, we show all models in the Model Registry sorted by recency." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eef6965d", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "model_list.select(\"ID\",\"NAME\",\"VERSION\",\"CREATION_TIME\",\"TAGS\").order_by(\"CREATION_TIME\", ascending=False).show(3)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "b2a42a8f", - "metadata": {}, - "source": [ - "## List all versions of a Model ordered by test set accuracy" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "833bfd54", - "metadata": {}, - "source": [ - "With a similar logic, we can also list all versions of a model with a given name sorted by a metric, in this case model accuracy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6df2eafc", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "model_list.select(\"ID\",\"NAME\",\"VERSION\",\"TAGS\",\"METRICS\").filter(\n", - " Column(\"NAME\") == model_name).order_by(Column(\"METRICS\")[\"test_accuracy\"], ascending=False \n", - ").show(3) " - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "7a0a1ef4", - "metadata": {}, - "source": [ - "## Model Deployment\n", - "Registry can be used to create deployment, which can be used for prediction. Deployment exists in the form of UDF. 
It could be either permanent or temporary.\n", - "\n", - "\n", - "### Permanent deployment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c8473d12", - "metadata": {}, - "outputs": [], - "source": [ - "# Create a permanent deployment\n", - "model.deploy(deployment_name=\"PERM_DEPLOY_1_0\", target_method=\"predict\", permanent=True, options={\"relax_version\": True})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09d18340", - "metadata": {}, - "outputs": [], - "source": [ - "# Create a permanent deployment with overridden UDF stage path\n", - "stage_path = f'\"{registry_name}\".PUBLIC._SYSTEM_REGISTRY_DEPLOYMENTS_VIEW_TEST'\n", - "session.sql(f\"CREATE STAGE IF NOT EXISTS {stage_path}\").collect()\n", - "model.deploy(deployment_name=\"PERM_DEPLOY_1_1\", target_method=\"predict\", permanent=True, options={\"permanent_udf_stage_location\":'@'+stage_path, \"relax_version\": True})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "99c1b5d7", - "metadata": {}, - "outputs": [], - "source": [ - "model.list_deployments().select(\"MODEL_NAME\", \"MODEL_VERSION\", \"DEPLOYMENT_NAME\").show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33660055", - "metadata": {}, - "outputs": [], - "source": [ - "# Create another permanent deployment\n", - "model.deploy(deployment_name=\"PERM_DEPLOY_1_2\", target_method=\"predict\", permanent=True, options={\"relax_version\": True})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6983128d", - "metadata": {}, - "outputs": [], - "source": [ - "model.predict(\"PERM_DEPLOY_1_0\", test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3ea350f9", - "metadata": {}, - "outputs": [], - "source": [ - "model.predict(\"PERM_DEPLOY_1_1\", test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f67db59", - "metadata": {}, - "outputs": [], - "source": [ - 
"model.predict(\"PERM_DEPLOY_1_2\", test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2732495f", - "metadata": {}, - "outputs": [], - "source": [ - "model.delete_deployment(deployment_name=\"PERM_DEPLOY_1_2\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "168f7817", - "metadata": {}, - "outputs": [], - "source": [ - "model.list_deployments().select(\"MODEL_NAME\", \"MODEL_VERSION\", \"DEPLOYMENT_NAME\").show()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "208c112d", - "metadata": {}, - "source": [ - "### Temporary deployments\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "052a1577", - "metadata": {}, - "source": [ - "The key distinction between permanent and temporary deployments lies in their lifespan. Temporary deployments are session-scoped and get removed when the session ends. As a result, the methods `delete_deployment()` and `list_deployments()` currently do not support temporary deployments." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c8027bf8", - "metadata": {}, - "outputs": [], - "source": [ - "# Create a temporary deployment\n", - "model.deploy(deployment_name=\"TEMP_DEPLOY_1_0\", target_method=\"predict\", permanent=False, options={\"relax_version\": True})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3a5cfd5f", - "metadata": {}, - "outputs": [], - "source": [ - "# Create another temporary deployment\n", - "model.deploy(deployment_name=\"TEMP_DEPLOY_1_1\", target_method=\"predict\", permanent=False, options={\"relax_version\": True})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb73482a", - "metadata": {}, - "outputs": [], - "source": [ - "model.predict(\"TEMP_DEPLOY_1_0\", test_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2ed5fcb", - "metadata": {}, - "outputs": [], - "source": [ - "model.predict(\"TEMP_DEPLOY_1_1\", test_features)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "73080975", - "metadata": {}, - "source": [ - "## Examine Model History" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "011d4d6f", - "metadata": {}, - "source": [ - "In addition to the current state of the model metadata, we also give access to the history of all changes to the model metadata. This includes the registration event itself but also changes to any metadata of the model, when they happened and who initiated them." 
- ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "025a0065", - "metadata": {}, - "source": [ - "### Relational API" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6620ead0", - "metadata": {}, - "outputs": [], - "source": [ - "registry.get_model_history(model_name=model_name, model_version=model_version).select(\"EVENT_TIMESTAMP\", \"ROLE\", \"ATTRIBUTE_NAME\",\"OPERATION\", \"VALUE[ATTRIBUTE_NAME]\").sort(\"EVENT_TIMESTAMP\", ascending=False).show()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "98bb5522", - "metadata": {}, - "source": [ - "### Object API" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33f284f6", - "metadata": {}, - "outputs": [], - "source": [ - "model.get_model_history().select(\"EVENT_TIMESTAMP\", \"ROLE\", \"ATTRIBUTE_NAME\",\"OPERATION\", \"VALUE[ATTRIBUTE_NAME]\").sort(\"EVENT_TIMESTAMP\", ascending=False).show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - }, - "vscode": { - "interpreter": { - "hash": "fb0a62cbfaa59af7646af5a6672c5c3e72ec75fbadf6ff0336b6769523f221a5" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/snowflake/ml/registry/notebooks/Using MODEL via Registry in Snowflake.ipynb b/snowflake/ml/registry/notebooks/Using MODEL via Registry in Snowflake.ipynb new file mode 100644 index 00000000..895cf8e9 --- /dev/null +++ b/snowflake/ml/registry/notebooks/Using MODEL via Registry in Snowflake.ipynb @@ -0,0 +1,835 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using MODEL via Registry in Snowflake\n" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## Before Everything\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Snowflake-ML-Python Installation\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Please refer to our [landing page](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index) to install `snowflake-ml-python` with the latest version.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup Notebook\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "\n", + "display(HTML(\"\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Start Snowpark Session\n", + "\n", + "To avoid exposing credentials in GitHub, we use a small utility `SnowflakeLoginOptions`. It allows you to store your default credentials in `~/.snowsql/config` in the following format:\n", + "\n", + "```\n", + "[connections]\n", + "accountname = # Account identifier to connect to Snowflake.\n", + "username = # User name in the account. Optional.\n", + "password = # User password. Optional.\n", + "dbname = # Default database. Optional.\n", + "schemaname = # Default schema. Optional.\n", + "warehousename = # Default warehouse. Optional.\n", + "#rolename = # Default role. 
Optional.\n", + "#authenticator = # Authenticator: 'snowflake', 'externalbrowser', etc\n", + "```\n", + "\n", + "Please follow [this](https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings) for more details.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", + "from snowflake.snowpark import Session\n", + "\n", + "session = Session.builder.configs(SnowflakeLoginOptions()).create()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Open A Registry\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To start we need to open a registry in a given **pre-created** database and schema, or the schema your session is actively using.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "REGISTRY_DATABASE_NAME = \"MY_REGISTRY\"\n", + "REGISTRY_SCHEMA_NAME = \"PUBLIC\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.registry import registry\n", + "\n", + "reg = registry.Registry(session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Walkthrough Registry with a Small Model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train a small model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The cell below trains a small model for demonstration purposes. 
The nature of the model does not matter, it is purely used to demonstrate the usage of the Registry.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import svm, datasets\n", + "\n", + "digits = datasets.load_digits()\n", + "target_digit = 6\n", + "num_training_examples = 10\n", + "svc_gamma = 0.001\n", + "svc_C = 10.0\n", + "\n", + "clf = svm.SVC(gamma=svc_gamma, C=svc_C, probability=True)\n", + "\n", + "\n", + "def one_vs_all(dataset, digit):\n", + " return [x == digit for x in dataset]\n", + "\n", + "\n", + "# Train a classifier using num_training_examples and use the last 100 examples for test.\n", + "train_features = digits.data[:num_training_examples]\n", + "train_labels = one_vs_all(digits.target[:num_training_examples], target_digit)\n", + "clf.fit(train_features, train_labels)\n", + "\n", + "test_features = digits.data[-100:]\n", + "test_labels = one_vs_all(digits.target[-100:], target_digit)\n", + "prediction = clf.predict(test_features)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Log the model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To keep the model for future use, we need to log the model. 
We need to provide a model name and a version name; with the following API, a SQL MODEL object will be created on your behalf.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"my_model\"\n", + "version_name = \"v1\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv = reg.log_model(clf, model_name=model_name, version_name=version_name, sample_input_data=train_features)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once logged, the model is ready to use in Snowflake with a warehouse!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "remote_prediction = mv.run(test_features, method_name=\"predict\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "print(\"Remote prediction:\", remote_prediction[:10])\n", + "\n", + "print(\"Result comparison:\", np.array_equal(prediction, remote_prediction[\"output_feature_0\"].values))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "All methods available in the original model can be run.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv.list_methods()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "remote_prediction_proba = mv.run(test_features, method_name=\"predict_proba\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prediction_proba = clf.predict_proba(test_features)\n", + "\n", + "print(\"Remote prediction:\", 
remote_prediction_proba[:10])\n", + "\n", + "print(\"Result comparison:\", np.allclose(prediction_proba, remote_prediction_proba.values))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get the model and version\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After the model is logged, besides using the returned object, there are other APIs for getting an object to operate on the model or a model version.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m = reg.get_model(model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv = m.version(version_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### List models and versions\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reg.list_models()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.list_versions()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model Description\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can set the description of a model or of a specific version of the model. 
They are backed by the COMMENT feature in SQL.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.description = \"This is my model.\"\n", + "print(m.description)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv.description = \"This is the first version of my model.\"\n", + "print(mv.description)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model Metrics\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Metrics are a type of metadata annotation that can be associated with a version of a model stored in the Registry. Metrics often take the form of scalars but we also support more complex objects such as arrays or dictionaries to represent metrics, as long as they are JSON serializable. In the examples below, we add scalars, dictionaries, and a 2-dimensional numpy array as metrics.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import metrics\n", + "\n", + "test_accuracy = metrics.accuracy_score(test_labels, prediction)\n", + "print(\"Model test accuracy:\", test_accuracy)\n", + "\n", + "test_confusion_matrix = metrics.confusion_matrix(test_labels, prediction)\n", + "print(\"Confusion matrix:\", test_confusion_matrix)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv.set_metric(metric_name=\"test_accuracy\", value=test_accuracy)\n", + "\n", + "mv.set_metric(metric_name=\"num_training_examples\", value=num_training_examples)\n", + "\n", + "mv.set_metric(metric_name=\"dataset_test\", value={\"accuracy\": test_accuracy})\n", + "\n", + "mv.set_metric(metric_name=\"confusion_matrix\", value=test_confusion_matrix.tolist())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "mv.get_metric(metric_name=\"confusion_matrix\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv.delete_metric(metric_name=\"confusion_matrix\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv.list_metrics()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Default version\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You could set a default version of a model\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.default = version_name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.default" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Delete model\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reg.delete_model(model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reg.list_models()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use with Snowpark ML Modeling Model and Snowpark DataFrame\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare Dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_TABLE_NAME = \"KDDCUP99_DATASET\"\n", + "\n", + "kddcup99_data = datasets.fetch_kddcup99(as_frame=True)\n", + "kddcup99_sp_df = session.create_dataframe(kddcup99_data.frame)\n", + "kddcup99_sp_df.write.mode(\"overwrite\").save_as_table(DATA_TABLE_NAME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.modeling.preprocessing import 
one_hot_encoder, ordinal_encoder, standard_scaler\n", + "from snowflake.ml.modeling.pipeline import pipeline\n", + "from snowflake.ml.modeling.xgboost import xgb_classifier\n", + "import snowflake.snowpark.functions as F\n", + "\n", + "quote_fn = lambda x: f'\"{x}\"'\n", + "\n", + "ONE_HOT_ENCODE_COL_NAMES = [\"protocol_type\", \"service\", \"flag\"]\n", + "ORDINAL_ENCODE_COL_NAMES = [\"labels\"]\n", + "STANDARD_SCALER_COL_NAMES = [\n", + " \"duration\",\n", + " \"src_bytes\",\n", + " \"dst_bytes\",\n", + " \"wrong_fragment\",\n", + " \"urgent\",\n", + " \"hot\",\n", + " \"num_failed_logins\",\n", + " \"num_compromised\",\n", + " \"num_root\",\n", + " \"num_file_creations\",\n", + " \"num_shells\",\n", + " \"num_access_files\",\n", + " \"num_outbound_cmds\",\n", + " \"count\",\n", + " \"srv_count\",\n", + " \"dst_host_count\",\n", + " \"dst_host_srv_count\",\n", + "]\n", + "\n", + "TRAIN_SIZE_K = 0.2\n", + "kddcup99_data = session.table(DATA_TABLE_NAME)\n", + "kddcup99_data = kddcup99_data.with_columns(\n", + " list(map(quote_fn, ONE_HOT_ENCODE_COL_NAMES + ORDINAL_ENCODE_COL_NAMES)),\n", + " [\n", + " F.to_char(col_name, \"utf-8\")\n", + " for col_name in list(map(quote_fn, ONE_HOT_ENCODE_COL_NAMES + ORDINAL_ENCODE_COL_NAMES))\n", + " ],\n", + ")\n", + "kddcup99_sp_df_train, kddcup99_sp_df_test = tuple(\n", + " kddcup99_data.random_split([TRAIN_SIZE_K, 1 - TRAIN_SIZE_K], seed=2568)\n", + ")\n", + "\n", + "pipe = pipeline.Pipeline(\n", + " steps=[\n", + " (\n", + " \"OHEHOT\",\n", + " one_hot_encoder.OneHotEncoder(\n", + " handle_unknown=\"ignore\",\n", + " input_cols=list(map(quote_fn, ONE_HOT_ENCODE_COL_NAMES)),\n", + " output_cols=ONE_HOT_ENCODE_COL_NAMES,\n", + " drop_input_cols=True,\n", + " ),\n", + " ),\n", + " (\n", + " \"ORDINAL\",\n", + " ordinal_encoder.OrdinalEncoder(\n", + " input_cols=list(map(quote_fn, ORDINAL_ENCODE_COL_NAMES)),\n", + " output_cols=['\"encoded_labels\"'],\n", + " drop_input_cols=True,\n", + " ),\n", + " ),\n", + " (\n", + " 
\"STD\",\n", + " standard_scaler.StandardScaler(\n", + " input_cols=list(map(quote_fn, STANDARD_SCALER_COL_NAMES)),\n", + " output_cols=list(map(quote_fn, STANDARD_SCALER_COL_NAMES)),\n", + " drop_input_cols=True,\n", + " ),\n", + " ),\n", + " (\"CLASSIFIER\", xgb_classifier.XGBClassifier(label_cols=['\"encoded_labels\"'])),\n", + " ]\n", + ")\n", + "pipe.fit(kddcup99_sp_df_train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"pipeline_model\"\n", + "version_name = \"v2\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv = reg.log_model(pipe, model_name=model_name, version_name=version_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv.run(kddcup99_sp_df_test, method_name=\"predict\").show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use with a custom model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download a GPT-2 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "model_name = \"gpt2-medium\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "model = AutoModelForCausalLM.from_pretrained(model_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Store GPT-2 Model components locally" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ARTIFACTS_DIR = \"/tmp/gpt-2/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.makedirs(os.path.join(ARTIFACTS_DIR, \"model\"), exist_ok=True)\n", + 
"os.makedirs(os.path.join(ARTIFACTS_DIR, \"tokenizer\"), exist_ok=True)\n", + "\n", + "model.save_pretrained(os.path.join(ARTIFACTS_DIR, \"model\"))\n", + "tokenizer.save_pretrained(os.path.join(ARTIFACTS_DIR, \"tokenizer\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a custom model using GPT-2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.model import custom_model\n", + "import pandas as pd\n", + "\n", + "\n", + "class GPT2Model(custom_model.CustomModel):\n", + " def __init__(self, context: custom_model.ModelContext) -> None:\n", + " super().__init__(context)\n", + "\n", + " self.model = AutoModelForCausalLM.from_pretrained(self.context.path(\"model\"))\n", + " self.tokenizer = AutoTokenizer.from_pretrained(self.context.path(\"tokenizer\"))\n", + "\n", + " @custom_model.inference_api\n", + " def predict(self, X: pd.DataFrame) -> pd.DataFrame:\n", + " def _generate(input_text: str) -> str:\n", + " input_ids = self.tokenizer.encode(input_text, return_tensors=\"pt\")\n", + "\n", + " output = self.model.generate(input_ids, max_length=50, do_sample=True, top_p=0.95, top_k=60)\n", + " generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)\n", + "\n", + " return generated_text\n", + "\n", + " res_df = pd.DataFrame({\"output\": pd.Series.apply(X[\"input\"], _generate)})\n", + " return res_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gpt_model = GPT2Model(\n", + " custom_model.ModelContext(\n", + " models={},\n", + " artifacts={\n", + " \"model\": os.path.join(ARTIFACTS_DIR, \"model\"),\n", + " \"tokenizer\": os.path.join(ARTIFACTS_DIR, \"tokenizer\"),\n", + " },\n", + " )\n", + ")\n", + "\n", + "gpt_model.predict(pd.DataFrame({\"input\": [\"Hello, are you GPT?\"]}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 
Register the custom model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, how to specify dependencies and model signature manually is shown." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"gpt2_medium\"\n", + "version_name = \"v1\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.model import model_signature\n", + "\n", + "mv = reg.log_model(\n", + " gpt_model,\n", + " model_name=model_name,\n", + " version_name=version_name,\n", + " conda_dependencies=[\"pytorch\", \"transformers\"],\n", + " signatures={\n", + " \"predict\": model_signature.ModelSignature(\n", + " inputs=[model_signature.FeatureSpec(name=\"input\", dtype=model_signature.DataType.STRING)],\n", + " outputs=[model_signature.FeatureSpec(name=\"output\", dtype=model_signature.DataType.STRING)],\n", + " )\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mv.run(pd.DataFrame({\"input\": [\"Hello, are you GPT?\"]}))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/snowflake/ml/registry/registry.py b/snowflake/ml/registry/registry.py new file mode 100644 index 00000000..f8031fac --- /dev/null +++ b/snowflake/ml/registry/registry.py @@ -0,0 +1,215 @@ +from types import ModuleType +from typing import Dict, List, Optional + +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import sql_identifier +from 
snowflake.ml.model import model_signature, type_hints as model_types +from snowflake.ml.model._client.model import model_impl, model_version_impl +from snowflake.ml.model._client.ops import model_ops +from snowflake.ml.model._model_composer import model_composer +from snowflake.snowpark import session + +_TELEMETRY_PROJECT = "MLOps" +_MODEL_TELEMETRY_SUBPROJECT = "ModelManagement" + + +class Registry: + def __init__( + self, + session: session.Session, + *, + database_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> None: + if database_name: + self._database_name = sql_identifier.SqlIdentifier(database_name) + else: + session_db = session.get_current_database() + if session_db: + self._database_name = sql_identifier.SqlIdentifier(session_db) + else: + raise ValueError("You need to provide a database to use registry.") + + if schema_name: + self._schema_name = sql_identifier.SqlIdentifier(schema_name) + elif database_name: + self._schema_name = sql_identifier.SqlIdentifier("PUBLIC") + else: + session_schema = session.get_current_schema() + self._schema_name = ( + sql_identifier.SqlIdentifier(session_schema) + if session_schema + else sql_identifier.SqlIdentifier("PUBLIC") + ) + + self._model_ops = model_ops.ModelOperator( + session, database_name=self._database_name, schema_name=self._schema_name + ) + + @property + def location(self) -> str: + return ".".join([self._database_name.identifier(), self._schema_name.identifier()]) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + def log_model( + self, + model: model_types.SupportedModelType, + *, + model_name: str, + version_name: str, + conda_dependencies: Optional[List[str]] = None, + pip_requirements: Optional[List[str]] = None, + python_version: Optional[str] = None, + signatures: Optional[Dict[str, model_signature.ModelSignature]] = None, + sample_input_data: Optional[model_types.SupportedDataType] = None, + code_paths: 
Optional[List[str]] = None, + ext_modules: Optional[List[ModuleType]] = None, + options: Optional[model_types.ModelSaveOption] = None, + ) -> model_version_impl.ModelVersion: + """Log a model. + + Args: + model: Model Python object + model_name: A string as name. + version_name: A string as version. model_name and version_name combination must be unique. + signatures: Model data signatures for inputs and output for every target methods. If it is None, + sample_input_data would be used to infer the signatures for those models that cannot automatically + infer the signature. If not None, sample_input should not be specified. Defaults to None. + sample_input_data: Sample input data to infer the model signatures from. If it is None, signatures must be + specified if the model cannot automatically infer the signature. If not None, signatures should not be + specified. Defaults to None. + conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to + specify a dependency. It is a recommended way to specify your dependencies using conda. When channel is + not specified, Snowflake Anaconda Channel will be used. + pip_requirements: List of Pip package specs. + python_version: A string of python version where model is run. Used for user override. If specified as None, + current version would be captured. Defaults to None. + code_paths: Directory of code to import. + ext_modules: External modules that user might want to get pickled with model object. Defaults to None. + options: Model specific kwargs. + + Returns: + A ModelVersion object corresponding to the model just get logged. 
+ """ + + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + model_name_id = sql_identifier.SqlIdentifier(model_name) + + version_name_id = sql_identifier.SqlIdentifier(version_name) + + stage_path = self._model_ops.prepare_model_stage_path( + statement_params=statement_params, + ) + + mc = model_composer.ModelComposer(self._model_ops._session, stage_path=stage_path) + mc.save( + name=model_name_id.resolved(), + model=model, + signatures=signatures, + sample_input=sample_input_data, + conda_dependencies=conda_dependencies, + pip_requirements=pip_requirements, + python_version=python_version, + code_paths=code_paths, + ext_modules=ext_modules, + options=options, + ) + self._model_ops.create_from_stage( + composed_model=mc, + model_name=model_name_id, + version_name=version_name_id, + statement_params=statement_params, + ) + + return model_version_impl.ModelVersion._ref( + self._model_ops, + model_name=model_name_id, + version_name=version_name_id, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + def get_model(self, model_name: str) -> model_impl.Model: + """Get the model object. + + Args: + model_name: The model name. + + Raises: + ValueError: Raised when the model requested does not exist. + + Returns: + The model object. 
+ """ + model_name_id = sql_identifier.SqlIdentifier(model_name) + + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + if self._model_ops.validate_existence( + model_name=model_name_id, + statement_params=statement_params, + ): + return model_impl.Model._ref( + self._model_ops, + model_name=model_name_id, + ) + else: + raise ValueError(f"Unable to find model {model_name}") + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + def list_models(self) -> List[model_impl.Model]: + """List all models in the schema where the registry is opened. + + Returns: + A List of Model= object representing all models in the schema where the registry is opened. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + model_names = self._model_ops.list_models_or_versions( + statement_params=statement_params, + ) + return [ + model_impl.Model._ref( + self._model_ops, + model_name=model_name, + ) + for model_name in model_names + ] + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + def delete_model(self, model_name: str) -> None: + """Delete the model. + + Args: + model_name: The model name, can be fully qualified one. + If not, use database name and schema name of the registry. 
+ """ + model_name_id = sql_identifier.SqlIdentifier(model_name) + + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + + self._model_ops.delete_model_or_version( + model_name=model_name_id, + statement_params=statement_params, + ) diff --git a/snowflake/ml/registry/registry_test.py b/snowflake/ml/registry/registry_test.py new file mode 100644 index 00000000..b59e27f9 --- /dev/null +++ b/snowflake/ml/registry/registry_test.py @@ -0,0 +1,298 @@ +from typing import cast +from unittest import mock + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.model import model_impl, model_version_impl +from snowflake.ml.model._model_composer import model_composer +from snowflake.ml.registry import registry +from snowflake.ml.test_utils import mock_session +from snowflake.snowpark import Session + + +class RegistryNameTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + + def test_location(self) -> None: + c_session = cast(Session, self.m_session) + r = registry.Registry(c_session, database_name="TEMP", schema_name="TEST") + self.assertEqual(r.location, "TEMP.TEST") + r = registry.Registry(c_session, database_name="TEMP", schema_name="test") + self.assertEqual(r.location, "TEMP.TEST") + r = registry.Registry(c_session, database_name="TEMP", schema_name='"test"') + self.assertEqual(r.location, 'TEMP."test"') + + with mock.patch.object(c_session, "get_current_schema", return_value='"CURRENT_TEMP"', create=True): + r = registry.Registry(c_session, database_name="TEMP") + self.assertEqual(r.location, "TEMP.PUBLIC") + r = registry.Registry(c_session, database_name="temp") + self.assertEqual(r.location, "TEMP.PUBLIC") + r = registry.Registry(c_session, database_name='"temp"') + self.assertEqual(r.location, '"temp".PUBLIC') + + with mock.patch.object(c_session, 
"get_current_schema", return_value=None, create=True): + r = registry.Registry(c_session, database_name="TEMP") + self.assertEqual(r.location, "TEMP.PUBLIC") + r = registry.Registry(c_session, database_name="temp") + self.assertEqual(r.location, "TEMP.PUBLIC") + r = registry.Registry(c_session, database_name='"temp"') + self.assertEqual(r.location, '"temp".PUBLIC') + + with mock.patch.object(c_session, "get_current_database", return_value='"CURRENT_TEMP"', create=True): + r = registry.Registry(c_session, schema_name="TEMP") + self.assertEqual(r.location, "CURRENT_TEMP.TEMP") + r = registry.Registry(c_session, schema_name="temp") + self.assertEqual(r.location, "CURRENT_TEMP.TEMP") + r = registry.Registry(c_session, schema_name='"temp"') + self.assertEqual(r.location, 'CURRENT_TEMP."temp"') + + with mock.patch.object(c_session, "get_current_database", return_value='"current_temp"', create=True): + r = registry.Registry(c_session, schema_name="TEMP") + self.assertEqual(r.location, '"current_temp".TEMP') + r = registry.Registry(c_session, schema_name="temp") + self.assertEqual(r.location, '"current_temp".TEMP') + r = registry.Registry(c_session, schema_name='"temp"') + self.assertEqual(r.location, '"current_temp"."temp"') + + with mock.patch.object(c_session, "get_current_database", return_value=None, create=True): + with self.assertRaisesRegex(ValueError, "You need to provide a database to use registry."): + r = registry.Registry(c_session, schema_name="TEMP") + + with mock.patch.object( + c_session, "get_current_database", return_value='"CURRENT_TEMP"', create=True + ), mock.patch.object(c_session, "get_current_schema", return_value='"CURRENT_TEMP"', create=True): + r = registry.Registry(c_session) + self.assertEqual(r.location, "CURRENT_TEMP.CURRENT_TEMP") + + with mock.patch.object( + c_session, "get_current_database", return_value='"CURRENT_TEMP"', create=True + ), mock.patch.object(c_session, "get_current_schema", return_value='"current_temp"', create=True): + r 
= registry.Registry(c_session) + self.assertEqual(r.location, 'CURRENT_TEMP."current_temp"') + + with mock.patch.object( + c_session, "get_current_database", return_value='"CURRENT_TEMP"', create=True + ), mock.patch.object(c_session, "get_current_schema", return_value=None, create=True): + r = registry.Registry(c_session) + self.assertEqual(r.location, "CURRENT_TEMP.PUBLIC") + + +class RegistryTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.c_session = cast(Session, self.m_session) + self.m_r = registry.Registry(self.c_session, database_name="TEMP", schema_name="TEST") + + def test_get_model_1(self) -> None: + m_model = model_impl.Model._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + ) + with mock.patch.object(self.m_r._model_ops, "validate_existence", return_value=True) as mock_validate_existence: + m = self.m_r.get_model("MODEL") + self.assertEqual(m, m_model) + mock_validate_existence.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_get_model_2(self) -> None: + with mock.patch.object( + self.m_r._model_ops, "validate_existence", return_value=False + ) as mock_validate_existence: + with self.assertRaisesRegex(ValueError, "Unable to find model MODEL"): + self.m_r.get_model("MODEL") + mock_validate_existence.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_list_models(self) -> None: + m_model_1 = model_impl.Model._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + ) + m_model_2 = model_impl.Model._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + ) + with mock.patch.object( + self.m_r._model_ops, + "list_models_or_versions", + return_value=[ + sql_identifier.SqlIdentifier("MODEL"), + sql_identifier.SqlIdentifier("Model", 
case_sensitive=True), + ], + ) as mock_list_models_or_versions: + m_list = self.m_r.list_models() + self.assertListEqual(m_list, [m_model_1, m_model_2]) + mock_list_models_or_versions.assert_called_once_with( + statement_params=mock.ANY, + ) + + def test_log_model_1(self) -> None: + m_model = mock.MagicMock() + m_conda_dependency = mock.MagicMock() + m_sample_input_data = mock.MagicMock() + m_stage_path = "@TEMP.TEST.MODEL/V1" + with mock.patch.object( + self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path + ) as mock_prepare_model_stage_path, mock.patch.object( + model_composer.ModelComposer, "save" + ) as mock_save, mock.patch.object( + self.m_r._model_ops, "create_from_stage" + ) as mock_create_from_stage: + mv = self.m_r.log_model( + model=m_model, + model_name="MODEL", + version_name="v1", + conda_dependencies=m_conda_dependency, + sample_input_data=m_sample_input_data, + ) + mock_prepare_model_stage_path.assert_called_once_with( + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + name="MODEL", + model=m_model, + signatures=None, + sample_input=m_sample_input_data, + conda_dependencies=m_conda_dependency, + pip_requirements=None, + python_version=None, + code_paths=None, + ext_modules=None, + options=None, + ) + mock_create_from_stage.assert_called_once_with( + composed_model=mock.ANY, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1"), + statement_params=mock.ANY, + ) + self.assertEqual( + mv, + model_version_impl.ModelVersion._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1"), + ), + ) + + def test_log_model_2(self) -> None: + m_model = mock.MagicMock() + m_pip_requirements = mock.MagicMock() + m_signatures = mock.MagicMock() + m_options = mock.MagicMock() + m_stage_path = "@TEMP.TEST.MODEL/V1" + with mock.patch.object( + self.m_r._model_ops, "prepare_model_stage_path", 
return_value=m_stage_path + ) as mock_prepare_model_stage_path, mock.patch.object( + model_composer.ModelComposer, "save" + ) as mock_save, mock.patch.object( + self.m_r._model_ops, "create_from_stage" + ) as mock_create_from_stage: + mv = self.m_r.log_model( + model=m_model, + model_name="MODEL", + version_name="V1", + pip_requirements=m_pip_requirements, + signatures=m_signatures, + options=m_options, + ) + mock_prepare_model_stage_path.assert_called_once_with( + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + name="MODEL", + model=m_model, + signatures=m_signatures, + sample_input=None, + conda_dependencies=None, + pip_requirements=m_pip_requirements, + python_version=None, + code_paths=None, + ext_modules=None, + options=m_options, + ) + mock_create_from_stage.assert_called_once_with( + composed_model=mock.ANY, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + self.assertEqual( + mv, + model_version_impl.ModelVersion._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + ), + ) + + def test_log_model_3(self) -> None: + m_model = mock.MagicMock() + m_python_version = mock.MagicMock() + m_code_paths = mock.MagicMock() + m_ext_modules = mock.MagicMock() + m_stage_path = "@TEMP.TEST.MODEL/V1" + with mock.patch.object( + self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path + ) as mock_prepare_model_stage_path, mock.patch.object( + model_composer.ModelComposer, "save" + ) as mock_save, mock.patch.object( + self.m_r._model_ops, "create_from_stage" + ) as mock_create_from_stage: + mv = self.m_r.log_model( + model=m_model, + model_name="MODEL", + version_name="V1", + python_version=m_python_version, + code_paths=m_code_paths, + ext_modules=m_ext_modules, + ) + mock_prepare_model_stage_path.assert_called_once_with( + statement_params=mock.ANY, + ) + 
mock_save.assert_called_once_with( + name="MODEL", + model=m_model, + signatures=None, + sample_input=None, + conda_dependencies=None, + pip_requirements=None, + python_version=m_python_version, + code_paths=m_code_paths, + ext_modules=m_ext_modules, + options=None, + ) + mock_create_from_stage.assert_called_once_with( + composed_model=mock.ANY, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + self.assertEqual( + mv, + model_version_impl.ModelVersion._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + ), + ) + + def test_delete_model(self) -> None: + with mock.patch.object(self.m_r._model_ops, "delete_model_or_version") as mock_delete_model_or_version: + self.m_r.delete_model( + model_name="MODEL", + ) + mock_delete_model_or_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/requirements.bzl b/snowflake/ml/requirements.bzl index 0761d610..8d5cd87b 100755 --- a/snowflake/ml/requirements.bzl +++ b/snowflake/ml/requirements.bzl @@ -48,8 +48,10 @@ REQUIREMENTS = [ "numpy>=1.23,<2", "packaging>=20.9,<24", "pandas>=1.0.0,<2", + "pyarrow", "pytimeparse>=1.1.8,<2", "pyyaml>=6.0,<7", + "retrying>=1.3.3,<2", "s3fs>=2022.11,<2024", "scikit-learn>=1.2.1,<1.4", "scipy>=1.9,<2", diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index 18c31510..d6dfd85a 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. 
-VERSION = "1.1.1" +VERSION = "1.1.2" diff --git a/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py b/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py index 762643ef..35813d87 100644 --- a/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py +++ b/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py @@ -14,12 +14,12 @@ def tearDown(self) -> None: self._session.close() def test_validate_requirement_in_snowflake_conda_channel(self) -> None: - res = env_utils.validate_requirements_in_snowflake_conda_channel( + res = env_utils.validate_requirements_in_information_schema( session=self._session, reqs=[requirements.Requirement("xgboost")], python_version=snowml_env.PYTHON_VERSION ) self.assertNotEmpty(res) - res = env_utils.validate_requirements_in_snowflake_conda_channel( + res = env_utils.validate_requirements_in_information_schema( session=self._session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, @@ -27,7 +27,7 @@ def test_validate_requirement_in_snowflake_conda_channel(self) -> None: self.assertNotEmpty(res) self.assertIsNone( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=self._session, reqs=[requirements.Requirement("xgboost==1.0.*")], python_version=snowml_env.PYTHON_VERSION, @@ -35,7 +35,7 @@ def test_validate_requirement_in_snowflake_conda_channel(self) -> None: ) self.assertIsNone( - env_utils.validate_requirements_in_snowflake_conda_channel( + env_utils.validate_requirements_in_information_schema( session=self._session, reqs=[requirements.Requirement("python-package")], python_version=snowml_env.PYTHON_VERSION, diff --git a/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py b/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py index d61c8039..ef038bd9 100644 --- a/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py +++ 
b/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py @@ -7,11 +7,7 @@ from sklearn.datasets import load_diabetes from sklearn.linear_model import LinearRegression as SkLinearRegression -from snowflake.ml._internal.exceptions import exceptions -from snowflake.ml.modeling._internal.snowpark_handlers import ( - SklearnWrapperProvider, - SnowparkHandlers, -) +from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers from tests.integ.snowflake.ml.test_utils import common_test_base @@ -19,9 +15,7 @@ class SnowparkHandlersTest(common_test_base.CommonTestBase): def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" super().setUp() - self._handlers = SnowparkHandlers( - class_name="test", subproject="subproject", wrapper_provider=SklearnWrapperProvider() - ) + self._handlers = SnowparkHandlers(class_name="test", subproject="subproject") def _get_test_dataset(self) -> Tuple[pd.DataFrame, List[str], List[str]]: """Constructs input dataset to be used in the integration test. @@ -54,41 +48,6 @@ def _get_test_dataset(self) -> Tuple[pd.DataFrame, List[str], List[str]]: return (input_df_pandas, input_cols, label_cols) - @common_test_base.CommonTestBase.sproc_test() - def test_fit_snowpark(self) -> None: - input_df_pandas, input_cols, label_cols = self._get_test_dataset() - input_df = self.session.create_dataframe(input_df_pandas) - - sklearn_estimator = SkLinearRegression() - - fit_estimator = self._handlers.fit_snowpark( - dataset=input_df, - session=self.session, - estimator=sklearn_estimator, - dependencies=["snowflake-snowpark-python", "numpy", "scikit-learn", "cloudpickle"], - input_cols=input_cols, - label_cols=label_cols, - sample_weight_col=None, - ) - - pandas_fit_estimator = sklearn_estimator.fit( - X=input_df_pandas[input_cols], y=input_df_pandas[label_cols].squeeze() - ) - - # Confirm that sproc was stored in session._FIT_WRAPPER_SPROCS for reuse. 
- assert "SklearnWrapperProvider" in self.session._FIT_WRAPPER_SPROCS - - fit_estimator = self._handlers.fit_snowpark( - dataset=input_df, - session=self.session, - estimator=sklearn_estimator, - dependencies=["snowflake-snowpark-python", "numpy", "scikit-learn", "cloudpickle"], - input_cols=input_cols, - label_cols=label_cols, - sample_weight_col=None, - ) - np.testing.assert_allclose(fit_estimator.coef_, pandas_fit_estimator.coef_) - @common_test_base.CommonTestBase.sproc_test() def test_batch_inference(self) -> None: sklearn_estimator = SkLinearRegression() @@ -139,28 +98,6 @@ def test_score_snowpark(self) -> None: np.testing.assert_allclose(score, sklearn_score) - @common_test_base.CommonTestBase.sproc_test() - def test_fit_snowpark_no_label_cols(self) -> None: - input_df_pandas, input_cols, _ = self._get_test_dataset() - label_cols = [] - input_df = self.session.create_dataframe(input_df_pandas) - - sklearn_estimator = SkLinearRegression() - - with self.assertRaises(exceptions.SnowflakeMLException) as e: - self._handlers.fit_snowpark( - dataset=input_df, - session=self.session, - estimator=sklearn_estimator, - dependencies=["snowflake-snowpark-python", "numpy", "scikit-learn", "cloudpickle"], - input_cols=input_cols, - label_cols=label_cols, - sample_weight_col=None, - ) - - self.assertIsInstance(e.exception.original_exception, RuntimeError) - self.assertIn("label_cols", str(e.exception)) - if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/model/BUILD.bazel b/tests/integ/snowflake/ml/model/BUILD.bazel index fffbeb68..ff20119d 100644 --- a/tests/integ/snowflake/ml/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/model/BUILD.bazel @@ -9,6 +9,7 @@ py_library( "//snowflake/ml/model:deploy_platforms", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_signatures:snowpark_handler", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", "//tests/integ/snowflake/ml/test_utils:db_manager", 
"//tests/integ/snowflake/ml/test_utils:test_env_utils", ], @@ -25,6 +26,7 @@ py_test( "//snowflake/ml/model:deploy_platforms", "//snowflake/ml/model:type_hints", "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", "//tests/integ/snowflake/ml/test_utils:db_manager", "//tests/integ/snowflake/ml/test_utils:test_env_utils", ], @@ -41,6 +43,7 @@ py_test( "//snowflake/ml/model/_signatures:pytorch_handler", "//snowflake/ml/model/_signatures:snowpark_handler", "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", "//tests/integ/snowflake/ml/test_utils:db_manager", "//tests/integ/snowflake/ml/test_utils:model_factory", ], @@ -57,6 +60,7 @@ py_test( "//snowflake/ml/model/_signatures:snowpark_handler", "//snowflake/ml/model/_signatures:tensorflow_handler", "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", "//tests/integ/snowflake/ml/test_utils:db_manager", "//tests/integ/snowflake/ml/test_utils:model_factory", ], @@ -71,6 +75,7 @@ py_test( ":warehouse_model_integ_test_utils", "//snowflake/ml/model:type_hints", "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", "//tests/integ/snowflake/ml/test_utils:db_manager", ], ) diff --git a/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel new file mode 100644 index 00000000..d5bfd8b7 --- /dev/null +++ b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel @@ -0,0 +1,31 @@ +load("//bazel:py_rules.bzl", "py_test") + +py_test( + name = "model_impl_integ_test", + timeout = "long", + srcs = ["model_impl_integ_test.py"], + shard_count = 6, + deps = [ + "//snowflake/ml/registry", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:db_manager", + "//tests/integ/snowflake/ml/test_utils:model_factory", + 
"//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) + +py_test( + name = "model_version_impl_integ_test", + timeout = "long", + srcs = ["model_version_impl_integ_test.py"], + shard_count = 6, + deps = [ + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/registry", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:db_manager", + "//tests/integ/snowflake/ml/test_utils:model_factory", + "//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) diff --git a/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py new file mode 100644 index 00000000..48760ac5 --- /dev/null +++ b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py @@ -0,0 +1,84 @@ +import unittest +import uuid + +from absl.testing import absltest, parameterized +from packaging import version + +from snowflake.ml.registry import registry +from snowflake.ml.utils import connection_params +from snowflake.snowpark import Session +from tests.integ.snowflake.ml.test_utils import ( + db_manager, + model_factory, + test_env_utils, +) + +MODEL_NAME = "TEST_MODEL" +VERSION_NAME = "V1" +VERSION_NAME2 = "V2" + + +class TestModelImplInteg(parameterized.TestCase): + @classmethod + def setUpClass(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + login_options = connection_params.SnowflakeLoginOptions() + + self._run_id = uuid.uuid4().hex + self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() + self._test_schema = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( + self._run_id, "schema" + ).upper() + + self._session = Session.builder.configs( + { + **login_options, + **{"database": self._test_db, "schema": self._test_schema}, + } + ).create() + + current_sf_version = 
test_env_utils.get_current_snowflake_version(self._session) + + if current_sf_version < version.parse("8.0.0"): + raise unittest.SkipTest("This test requires Snowflake Version 8.0.0 or higher.") + + self._db_manager = db_manager.DBManager(self._session) + self._db_manager.create_database(self._test_db) + self._db_manager.create_schema(self._test_schema) + self._db_manager.cleanup_databases(expire_hours=6) + self.registry = registry.Registry(self._session) + + model, test_features, _ = model_factory.ModelFactory.prepare_sklearn_model() + self._mv = self.registry.log_model( + model=model, + model_name=MODEL_NAME, + version_name=VERSION_NAME, + sample_input_data=test_features, + ) + self._mv2 = self.registry.log_model( + model=model, + model_name=MODEL_NAME, + version_name=VERSION_NAME2, + sample_input_data=test_features, + ) + self._model = self.registry.get_model(model_name=MODEL_NAME) + + @classmethod + def tearDownClass(self) -> None: + self._db_manager.drop_database(self._test_db) + self._session.close() + + def test_description(self) -> None: + description = "test description" + self._model.description = description + self.assertEqual(self._model.description, description) + + def test_default(self) -> None: + self.assertEqual(self._model.default.version_name, VERSION_NAME) + + self._model.default = VERSION_NAME2 + self.assertEqual(self._model.default.version_name, VERSION_NAME2) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py new file mode 100644 index 00000000..ca7b367d --- /dev/null +++ b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py @@ -0,0 +1,85 @@ +import unittest +import uuid + +from absl.testing import absltest, parameterized +from packaging import version + +from snowflake.ml.registry import registry +from snowflake.ml.utils import connection_params 
+from snowflake.snowpark import Session +from tests.integ.snowflake.ml.test_utils import ( + db_manager, + model_factory, + test_env_utils, +) + +MODEL_NAME = "TEST_MODEL" +VERSION_NAME = "V1" + + +class TestModelVersionImplInteg(parameterized.TestCase): + @classmethod + def setUpClass(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + login_options = connection_params.SnowflakeLoginOptions() + + self._run_id = uuid.uuid4().hex + self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() + self._test_schema = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( + self._run_id, "schema" + ).upper() + + self._session = Session.builder.configs( + { + **login_options, + **{"database": self._test_db, "schema": self._test_schema}, + } + ).create() + + current_sf_version = test_env_utils.get_current_snowflake_version(self._session) + + if current_sf_version < version.parse("8.0.0"): + raise unittest.SkipTest("This test requires Snowflake Version 8.0.0 or higher.") + + self._db_manager = db_manager.DBManager(self._session) + self._db_manager.create_database(self._test_db) + self._db_manager.create_schema(self._test_schema) + self._db_manager.cleanup_databases(expire_hours=6) + self.registry = registry.Registry(self._session) + + model, test_features, _ = model_factory.ModelFactory.prepare_sklearn_model() + self._mv = self.registry.log_model( + model=model, + model_name=MODEL_NAME, + version_name=VERSION_NAME, + sample_input_data=test_features, + ) + + @classmethod + def tearDownClass(self) -> None: + self._db_manager.drop_database(self._test_db) + self._session.close() + + def test_description(self) -> None: + description = "test description" + self._mv.description = description + self.assertEqual(self._mv.description, description) + + def test_metrics(self) -> None: + self._mv.set_metric("a", 1) + expected_metrics = {"a": 2, "b": 1.0, "c": True} + for k, v in 
expected_metrics.items(): + self._mv.set_metric(k, v) + + self.assertEqual(self._mv.get_metric("a"), expected_metrics["a"]) + self.assertDictEqual(self._mv.list_metrics(), expected_metrics) + + expected_metrics.pop("b") + self._mv.delete_metric("b") + self.assertDictEqual(self._mv.list_metrics(), expected_metrics) + with self.assertRaises(KeyError): + self._mv.get_metric("b") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/model/spcs_llm_model_integ_test.py b/tests/integ/snowflake/ml/model/spcs_llm_model_integ_test.py index c740b4b9..6639d3da 100644 --- a/tests/integ/snowflake/ml/model/spcs_llm_model_integ_test.py +++ b/tests/integ/snowflake/ml/model/spcs_llm_model_integ_test.py @@ -54,7 +54,7 @@ def test_text_generation_pipeline( model=model, options={"embed_local_ml_library": True}, conda_dependencies=[ - test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python"), + test_env_utils.get_latest_package_version_spec_in_conda("snowflake-snowpark-python"), ], ) svc_func_name = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( diff --git a/tests/integ/snowflake/ml/model/warehouse_custom_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_custom_model_integ_test.py index 7877b310..40efd52f 100644 --- a/tests/integ/snowflake/ml/model/warehouse_custom_model_integ_test.py +++ b/tests/integ/snowflake/ml/model/warehouse_custom_model_integ_test.py @@ -12,7 +12,7 @@ from snowflake.ml.utils import connection_params from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import db_manager +from tests.integ.snowflake.ml.test_utils import dataframe_utils, db_manager class DemoModel(custom_model.CustomModel): @@ -170,7 +170,7 @@ def test_custom_demo_model_sp( deploy_params={ "": ( {}, - lambda res: 
warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), ), }, permanent_deploy=permanent_deploy, @@ -202,6 +202,31 @@ def test_custom_demo_model_sp_quote( permanent_deploy=permanent_deploy, ) + @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] + def test_custom_demo_model_sp_quote_norm_1( + self, + permanent_deploy: Optional[bool] = False, + ) -> None: + lm = DemoModelSPQuote(custom_model.ModelContext()) + arr = [[1, 2, 3], [4, 2, 5]] + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = self._session.create_dataframe(arr, schema=['"""c1"""', '"""c2"""', '"""c3"""']) + sp_df_1 = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.concat([pd_df, pd_df[["c1"]].rename(columns={"c1": "output"})], axis=1) + self.base_test_case( + name="custom_demo_model_sp_quote", + model=lm, + sample_input=sp_df, + test_input=sp_df_1, + deploy_params={ + "": ( + {}, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ), + }, + permanent_deploy=permanent_deploy, + ) + @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] def test_custom_demo_model_sp_mix_1( self, @@ -220,7 +245,33 @@ def test_custom_demo_model_sp_mix_1( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ), + }, + permanent_deploy=permanent_deploy, + ) + + @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] + def test_custom_demo_model_sp_mix_1_norm( + self, + permanent_deploy: Optional[bool] = False, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [4, 2, 5]] + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = 
self._session.create_dataframe(arr, schema=["c1", "c2", "c3"]) + y_df_expected = pd.concat( + [pd_df.rename(columns=str.upper), pd_df[["c1"]].rename(columns={"c1": "OUTPUT"})], axis=1 + ) + self.base_test_case( + name="custom_demo_model_sp1", + model=lm, + sample_input=pd_df, + test_input=sp_df, + deploy_params={ + "": ( + {}, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), ), }, permanent_deploy=permanent_deploy, @@ -252,6 +303,31 @@ def test_custom_demo_model_sp_mix_2( permanent_deploy=permanent_deploy, ) + @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] + def test_custom_demo_model_sp_mix_2_norm( + self, + permanent_deploy: Optional[bool] = False, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [4, 2, 5]] + sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + sp_df_1 = self._session.create_dataframe(arr, schema=["c1", "c2", "c3"]) + pd_df = pd.DataFrame(arr, columns=["C1", "C2", "C3"]) + y_df_expected = pd.concat([pd_df, pd_df[["C1"]].rename(columns={"C1": "OUTPUT"})], axis=1) + self.base_test_case( + name="custom_demo_model_sp2", + model=lm, + sample_input=sp_df, + test_input=sp_df_1, + deploy_params={ + "": ( + {}, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ), + }, + permanent_deploy=permanent_deploy, + ) + @parameterized.product(permanent_deploy=[True, False]) # type: ignore[misc] def test_custom_demo_model_array( self, @@ -319,7 +395,7 @@ def test_custom_demo_model_array_sp( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), ) }, permanent_deploy=permanent_deploy, @@ -342,7 +418,7 @@ def test_custom_demo_model_str_sp( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected), + 
lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), ) }, permanent_deploy=permanent_deploy, @@ -425,9 +501,7 @@ def test_custom_model_bool_sp( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res( - res, y_df_expected, check_dtype=False - ), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), ) }, permanent_deploy=permanent_deploy, diff --git a/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py b/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py index 9fd7e5b1..14ecd17b 100644 --- a/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py +++ b/tests/integ/snowflake/ml/model/warehouse_model_integ_test_utils.py @@ -1,8 +1,6 @@ import posixpath -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np -import numpy.typing as npt import pandas as pd from snowflake.ml.model import ( @@ -10,7 +8,6 @@ deploy_platforms, type_hints as model_types, ) -from snowflake.ml.model._signatures import snowpark_handler from snowflake.snowpark import DataFrame as SnowparkDataFrame from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils @@ -81,39 +78,3 @@ def base_test_case( if permanent_deploy: db.drop_function(function_name=function_name, args=["OBJECT"]) - - -def check_sp_df_res( - res_sp_df: SnowparkDataFrame, - expected_pd_df: pd.DataFrame, - *, - check_dtype: bool = True, - check_index_type: Union[bool, Literal["equiv"]] = "equiv", - check_column_type: Union[bool, Literal["equiv"]] = "equiv", - check_frame_type: bool = True, - check_names: bool = True, -) -> None: - res_pd_df = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(res_sp_df) - - def totuple(a: Union[npt.ArrayLike, Tuple[object], object]) -> Union[Tuple[object], object]: - try: - return tuple(totuple(i) for i in a) # type: ignore[union-attr] - except 
TypeError: - return a - - for df in [res_pd_df, expected_pd_df]: - for col in df.columns: - if isinstance(df[col][0], list): - df[col] = df[col].apply(tuple) - elif isinstance(df[col][0], np.ndarray): - df[col] = df[col].apply(totuple) - - pd.testing.assert_frame_equal( - res_pd_df.sort_values(by=res_pd_df.columns.tolist()).reset_index(drop=True), - expected_pd_df.sort_values(by=expected_pd_df.columns.tolist()).reset_index(drop=True), - check_dtype=check_dtype, - check_index_type=check_index_type, - check_column_type=check_column_type, - check_frame_type=check_frame_type, - check_names=check_names, - ) diff --git a/tests/integ/snowflake/ml/model/warehouse_pytorch_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_pytorch_model_integ_test.py index 1d435b76..d1ce5937 100644 --- a/tests/integ/snowflake/ml/model/warehouse_pytorch_model_integ_test.py +++ b/tests/integ/snowflake/ml/model/warehouse_pytorch_model_integ_test.py @@ -10,7 +10,11 @@ from snowflake.ml.utils import connection_params from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import db_manager, model_factory +from tests.integ.snowflake.ml.test_utils import ( + dataframe_utils, + db_manager, + model_factory, +) class TestWarehousePytorchModelINteg(parameterized.TestCase): @@ -138,7 +142,7 @@ def test_pytorch_sp( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), ), }, permanent_deploy=permanent_deploy, @@ -219,7 +223,7 @@ def test_torchscript_sp( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), ), }, permanent_deploy=permanent_deploy, diff --git 
a/tests/integ/snowflake/ml/model/warehouse_sklearn_xgboost_model_integ_test.py b/tests/integ/snowflake/ml/model/warehouse_sklearn_xgboost_model_integ_test.py index ab467dfa..459879bf 100644 --- a/tests/integ/snowflake/ml/model/warehouse_sklearn_xgboost_model_integ_test.py +++ b/tests/integ/snowflake/ml/model/warehouse_sklearn_xgboost_model_integ_test.py @@ -1,6 +1,7 @@ import uuid from typing import Any, Callable, Dict, Optional, Tuple, Union, cast +import inflection import numpy as np import pandas as pd import xgboost @@ -11,7 +12,7 @@ from snowflake.ml.utils import connection_params from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import db_manager +from tests.integ.snowflake.ml.test_utils import dataframe_utils, db_manager class TestWarehouseSKLearnXGBoostModelInteg(parameterized.TestCase): @@ -155,6 +156,7 @@ def test_xgb( cal_data = datasets.load_breast_cancer(as_frame=True) cal_X = cal_data.data cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) regressor.fit(cal_X_train, cal_y_train) @@ -179,8 +181,9 @@ def test_xgb_sp( self, permanent_deploy: Optional[bool] = False, ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_data_sp_df = self._session.create_dataframe(cal_data.frame) + cal_data = datasets.load_breast_cancer(as_frame=True).frame + cal_data.columns = [inflection.parameterize(c, "_") for c in cal_data] + cal_data_sp_df = self._session.create_dataframe(cal_data) cal_data_sp_df_train, cal_data_sp_df_test = tuple(cal_data_sp_df.random_split([0.25, 0.75], seed=2568)) regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) 
cal_data_pd_df_train = cal_data_sp_df_train.to_pandas() @@ -202,7 +205,7 @@ def test_xgb_sp( deploy_params={ "predict": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), ), }, permanent_deploy=permanent_deploy, @@ -216,6 +219,7 @@ def test_xgb_booster( cal_data = datasets.load_breast_cancer(as_frame=True) cal_X = cal_data.data cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") regressor = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train)) @@ -239,8 +243,9 @@ def test_xgb_booster_sp( self, permanent_deploy: Optional[bool] = False, ) -> None: - cal_data = datasets.load_breast_cancer(as_frame=True) - cal_data_sp_df = self._session.create_dataframe(cal_data.frame) + cal_data = datasets.load_breast_cancer(as_frame=True).frame + cal_data.columns = [inflection.parameterize(c, "_") for c in cal_data] + cal_data_sp_df = self._session.create_dataframe(cal_data) cal_data_sp_df_train, cal_data_sp_df_test = tuple(cal_data_sp_df.random_split([0.25, 0.75], seed=2568)) cal_data_pd_df_train = cal_data_sp_df_train.to_pandas() params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") @@ -267,7 +272,7 @@ def test_xgb_booster_sp( deploy_params={ "predict": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), ), }, permanent_deploy=permanent_deploy, diff --git a/tests/integ/snowflake/ml/model/warehouse_tensorflow_model_integ_test.py 
b/tests/integ/snowflake/ml/model/warehouse_tensorflow_model_integ_test.py index d0305125..72c15b92 100644 --- a/tests/integ/snowflake/ml/model/warehouse_tensorflow_model_integ_test.py +++ b/tests/integ/snowflake/ml/model/warehouse_tensorflow_model_integ_test.py @@ -16,7 +16,11 @@ from snowflake.ml.utils import connection_params from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session from tests.integ.snowflake.ml.model import warehouse_model_integ_test_utils -from tests.integ.snowflake.ml.test_utils import db_manager, model_factory +from tests.integ.snowflake.ml.test_utils import ( + dataframe_utils, + db_manager, + model_factory, +) class SimpleModule(tf.Module): @@ -164,7 +168,7 @@ def test_tf_sp( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), ), }, permanent_deploy=permanent_deploy, @@ -249,7 +253,7 @@ def test_keras_sp( deploy_params={ "": ( {}, - lambda res: warehouse_model_integ_test_utils.check_sp_df_res(res, y_df_expected), + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), ), }, permanent_deploy=permanent_deploy, diff --git a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py index 3c008849..c152748a 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py @@ -60,7 +60,7 @@ def _compare_cv_results(self, cv_result_1, cv_result_2) -> None: np.testing.assert_allclose(v, cv_result_2[k], rtol=1.0e-1, atol=1.0e-2) # Do not compare the fit time - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_fit_and_compare_results(self, mock_is_single_node) -> None: 
mock_is_single_node.return_value = True # falls back to HPO implementation @@ -114,7 +114,7 @@ def test_fit_and_compare_results(self, mock_is_single_node) -> None: "estimator_kwargs": dict(seed=42), }, ) - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_fit_and_compare_results_distributed( self, mock_is_single_node, is_single_node, skmodel, model, params, kwargs, estimator_kwargs ) -> None: @@ -214,7 +214,7 @@ def test_fit_and_compare_results_distributed( actual_pandas_result.flatten(), sklearn_decision_function.flatten(), rtol=1.0e-1, atol=1.0e-2 ) - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_transform(self, mock_is_single_node) -> None: mock_is_single_node.return_value = False diff --git a/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py index 666b9ebb..dc6391bb 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py @@ -86,7 +86,7 @@ def _compare_cv_results(self, cv_result_1, cv_result_2) -> None: "estimator_kwargs": dict(seed=42), }, ) - @mock.patch("snowflake.ml.modeling.model_selection.randomized_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_fit_and_compare_results( self, mock_is_single_node, is_single_node, skmodel, model, params, kwargs, estimator_kwargs ) -> None: @@ -191,7 +191,7 @@ def test_fit_and_compare_results( actual_pandas_result.flatten(), sklearn_decision_function.flatten(), rtol=1.0e-1, atol=1.0e-2 ) - 
@mock.patch("snowflake.ml.modeling.model_selection.randomized_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_transform(self, mock_is_single_node) -> None: mock_is_single_node.return_value = False diff --git a/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py b/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py index f6fd1e4f..4ebb3534 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/search_single_node_test.py @@ -17,7 +17,7 @@ def setUp(self) -> None: def tearDown(self) -> None: self._session.close() - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_single_node_grid(self, mock_is_single_node) -> None: mock_is_single_node.return_value = True input_df_pandas = load_iris(as_frame=True).frame @@ -43,7 +43,7 @@ def test_single_node_grid(self, mock_is_single_node) -> None: self.assertEqual(reg._sklearn_object.n_jobs, -1) - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_single_node_random(self, mock_is_single_node) -> None: mock_is_single_node.return_value = True input_df_pandas = load_iris(as_frame=True).frame @@ -69,7 +69,7 @@ def test_single_node_random(self, mock_is_single_node) -> None: self.assertEqual(reg._sklearn_object.n_jobs, -1) - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_not_single_node_grid(self, mock_is_single_node) -> None: mock_is_single_node.return_value = False input_df_pandas = load_iris(as_frame=True).frame @@ -93,7 +93,7 @@ def 
test_not_single_node_grid(self, mock_is_single_node) -> None: self.assertEqual(reg._sklearn_object.estimator.n_jobs, 3) - @mock.patch("snowflake.ml.modeling.model_selection.grid_search_cv.is_single_node") + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_not_single_node_random(self, mock_is_single_node) -> None: mock_is_single_node.return_value = False input_df_pandas = load_iris(as_frame=True).frame diff --git a/tests/integ/snowflake/ml/registry/model/BUILD.bazel b/tests/integ/snowflake/ml/registry/model/BUILD.bazel new file mode 100644 index 00000000..4dd54706 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/BUILD.bazel @@ -0,0 +1,106 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = ["//tests/integ/snowflake/ml:__subpackages__"]) + +py_library( + name = "registry_model_test_base", + testonly = True, + srcs = ["registry_model_test_base.py"], + deps = [ + "//snowflake/ml/model:type_hints", + "//snowflake/ml/registry", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/test_utils:db_manager", + "//tests/integ/snowflake/ml/test_utils:test_env_utils", + ], +) + +py_test( + name = "registry_sklearn_model_test", + srcs = ["registry_sklearn_model_test.py"], + shard_count = 2, + deps = [ + ":registry_model_test_base", + ], +) + +py_test( + name = "registry_xgboost_model_test", + srcs = ["registry_xgboost_model_test.py"], + shard_count = 2, + deps = [ + ":registry_model_test_base", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", + ], +) + +py_test( + name = "registry_custom_model_test", + srcs = ["registry_custom_model_test.py"], + shard_count = 4, + deps = [ + ":registry_model_test_base", + "//snowflake/ml/model:custom_model", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", + ], +) + +py_test( + name = "registry_pytorch_model_test", + srcs = ["registry_pytorch_model_test.py"], + shard_count = 4, + deps = [ + 
":registry_model_test_base", + "//snowflake/ml/model/_signatures:pytorch_handler", + "//snowflake/ml/model/_signatures:snowpark_handler", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", + "//tests/integ/snowflake/ml/test_utils:model_factory", + ], +) + +py_test( + name = "registry_tensorflow_model_test", + srcs = ["registry_tensorflow_model_test.py"], + shard_count = 4, + deps = [ + ":registry_model_test_base", + "//snowflake/ml/model/_signatures:numpy_handler", + "//snowflake/ml/model/_signatures:snowpark_handler", + "//snowflake/ml/model/_signatures:tensorflow_handler", + "//tests/integ/snowflake/ml/test_utils:dataframe_utils", + "//tests/integ/snowflake/ml/test_utils:model_factory", + ], +) + +py_test( + name = "registry_modeling_model_test", + srcs = ["registry_modeling_model_test.py"], + shard_count = 2, + deps = [ + ":registry_model_test_base", + "//snowflake/ml/modeling/lightgbm:lgbm_regressor", + "//snowflake/ml/modeling/linear_model:logistic_regression", + "//snowflake/ml/modeling/xgboost:xgb_regressor", + ], +) + +py_test( + name = "registry_mlflow_model_test", + srcs = ["registry_mlflow_model_test.py"], + shard_count = 2, + deps = [ + ":registry_model_test_base", + "//snowflake/ml/_internal:env", + "//snowflake/ml/model/_signatures:numpy_handler", + ], +) + +py_test( + name = "registry_huggingface_pipeline_model_test", + srcs = ["registry_huggingface_pipeline_model_test.py"], + shard_count = 6, + deps = [ + ":registry_model_test_base", + "//snowflake/ml/_internal:env_utils", + ], +) diff --git a/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py new file mode 100644 index 00000000..8f68db1f --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py @@ -0,0 +1,317 @@ +import asyncio +import os +import tempfile + +import numpy as np +import pandas as pd +from absl.testing import absltest + +from snowflake.ml.model import 
custom_model +from tests.integ.snowflake.ml.registry.model import registry_model_test_base +from tests.integ.snowflake.ml.test_utils import dataframe_utils + + +class DemoModel(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + def predict(self, input: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame({"output": input["c1"]}) + + +class DemoModelSPQuote(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + def predict(self, input: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame({'"output"': input['"c1"']}) + + +class DemoModelArray(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + def predict(self, input: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame({"output": input.values.tolist()}) + + +class AsyncComposeModel(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + async def predict(self, input: pd.DataFrame) -> pd.DataFrame: + res1 = await self.context.model_ref("m1").predict.async_run(input) + res_sum = res1["output"] + self.context.model_ref("m2").predict(input)["output"] + return pd.DataFrame({"output": res_sum / 2}) + + +class DemoModelWithArtifacts(custom_model.CustomModel): + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + with open(context.path("bias"), encoding="utf-8") as f: + v = int(f.read()) + self.bias = v + + @custom_model.inference_api + def predict(self, input: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame({"output": (input["c1"] + self.bias) > 12}) + + +class TestRegistryCustomModelInteg(registry_model_test_base.RegistryModelTestBase): + def test_async_model_composition( + 
self, + ) -> None: + async def _test(self: "TestRegistryCustomModelInteg") -> None: + arr = np.random.randint(100, size=(10000, 3)) + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + clf = DemoModel(custom_model.ModelContext()) + model_context = custom_model.ModelContext( + models={ + "m1": clf, + "m2": clf, + } + ) + acm = AsyncComposeModel(model_context) + self._test_registry_model( + model=acm, + sample_input=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(arr[:, 0], columns=["output"], dtype=float), + ), + ), + }, + ) + + asyncio.get_event_loop().run_until_complete(_test(self)) + + def test_custom_demo_model_sp( + self, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [4, 2, 5]] + sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1, 2, 3, 1], [4, 2, 5, 4]], columns=["c1", "c2", "c3", "output"]) + self._test_registry_model( + model=lm, + sample_input=sp_df, + prediction_assert_fns={ + "predict": (sp_df, lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False)) + }, + ) + + def test_custom_demo_model_sp_quote( + self, + ) -> None: + lm = DemoModelSPQuote(custom_model.ModelContext()) + arr = [[1, 2, 3], [4, 2, 5]] + sp_df = self._session.create_dataframe(arr, schema=['"""c1"""', '"""c2"""', '"""c3"""']) + pd_df = pd.DataFrame(arr, columns=['"c1"', '"c2"', '"c3"']) + self._test_registry_model( + model=lm, + sample_input=sp_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame([1, 4], columns=['"output"'], dtype=np.int8), + ), + ) + }, + ) + + def test_custom_demo_model_sp_mix_1( + self, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [4, 2, 5]] + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', 
'"c3"']) + y_df_expected = pd.concat([pd_df, pd_df[["c1"]].rename(columns={"c1": "output"})], axis=1) + self._test_registry_model( + model=lm, + sample_input=pd_df, + prediction_assert_fns={ + "predict": ( + sp_df, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ) + }, + ) + + def test_custom_demo_model_sp_mix_2( + self, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + arr = [[1, 2, 3], [4, 2, 5]] + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + self._test_registry_model( + model=lm, + sample_input=sp_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame([1, 4], columns=["output"], dtype=np.int8), + ), + ) + }, + ) + + def test_custom_demo_model_array( + self, + ) -> None: + lm = DemoModelArray(custom_model.ModelContext()) + arr = np.array([[1, 2, 3], [4, 2, 5]]) + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + self._test_registry_model( + model=lm, + sample_input=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(data={"output": [[1, 2, 3], [4, 2, 5]]}), + ), + ) + }, + ) + + def test_custom_demo_model_str( + self, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + pd_df = pd.DataFrame([["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]], columns=["c1", "c2", "c3"]) + self._test_registry_model( + model=lm, + sample_input=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(data={"output": ["Yogiri", "Artia"]}), + ), + ) + }, + ) + + def test_custom_demo_model_array_sp( + self, + ) -> None: + lm = DemoModelArray(custom_model.ModelContext()) + arr = np.array([[1, 2, 3], [4, 2, 5]]) + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = self._session.create_dataframe(pd_df) + 
y_df_expected = pd.concat([pd_df, pd.DataFrame(data={"output": [[1, 2, 3], [4, 2, 5]]})], axis=1) + self._test_registry_model( + model=lm, + sample_input=sp_df, + prediction_assert_fns={ + "predict": ( + sp_df, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ) + }, + ) + + def test_custom_demo_model_str_sp( + self, + ) -> None: + lm = DemoModel(custom_model.ModelContext()) + pd_df = pd.DataFrame([["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]], columns=["c1", "c2", "c3"]) + sp_df = self._session.create_dataframe(pd_df) + y_df_expected = pd.concat([pd_df, pd.DataFrame(data={"output": ["Yogiri", "Artia"]})], axis=1) + self._test_registry_model( + model=lm, + sample_input=sp_df, + prediction_assert_fns={ + "predict": ( + sp_df, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), + ) + }, + ) + + def test_custom_demo_model_array_str( + self, + ) -> None: + lm = DemoModelArray(custom_model.ModelContext()) + pd_df = pd.DataFrame([["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]], columns=["c1", "c2", "c3"]) + self._test_registry_model( + model=lm, + sample_input=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(data={"output": [["Yogiri", "Civia", "Echo"], ["Artia", "Doris", "Rosalyn"]]}), + ), + ) + }, + ) + + def test_custom_model_with_artifacts( + self, + ) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "bias"), "w", encoding="utf-8") as f: + f.write("10") + lm = DemoModelWithArtifacts( + custom_model.ModelContext(models={}, artifacts={"bias": os.path.join(tmpdir, "bias")}) + ) + arr = np.array([[1, 2, 3], [4, 2, 5]]) + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + self._test_registry_model( + model=lm, + sample_input=pd_df, + prediction_assert_fns={ + "predict": ( + pd_df, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame([False, True], 
columns=["output"]), + ), + ) + }, + ) + + def test_custom_model_bool_sp( + self, + ) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "bias"), "w", encoding="utf-8") as f: + f.write("10") + lm = DemoModelWithArtifacts( + custom_model.ModelContext(models={}, artifacts={"bias": os.path.join(tmpdir, "bias")}) + ) + arr = np.array([[1, 2, 3], [4, 2, 5]]) + pd_df = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) + sp_df = self._session.create_dataframe(pd_df) + y_df_expected = pd.concat([pd_df, pd.DataFrame([False, True], columns=["output"])], axis=1) + self._test_registry_model( + model=lm, + sample_input=sp_df, + prediction_assert_fns={ + "predict": ( + sp_df, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ) + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py new file mode 100644 index 00000000..3e5f2870 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_huggingface_pipeline_model_test.py @@ -0,0 +1,548 @@ +import json +import os +import tempfile + +import numpy as np +import pandas as pd +from absl.testing import absltest +from packaging import requirements + +from snowflake.ml._internal import env_utils +from tests.integ.snowflake.ml.registry.model import registry_model_test_base + + +class TestRegistryHuggingFacePipelineModelInteg(registry_model_test_base.RegistryModelTestBase): + @classmethod + def setUpClass(self) -> None: + self.cache_dir = tempfile.TemporaryDirectory() + self._original_cache_dir = os.getenv("TRANSFORMERS_CACHE", None) + os.environ["TRANSFORMERS_CACHE"] = self.cache_dir.name + + @classmethod + def tearDownClass(self) -> None: + if self._original_cache_dir: + os.environ["TRANSFORMERS_CACHE"] = self._original_cache_dir + self.cache_dir.cleanup() + + def 
test_conversational_pipeline( + self, + ) -> None: + # We have to import here due to cache location issue. + # Only by doing so can we make the cache dir setting effective. + import transformers + + model = transformers.pipeline(task="conversational", model="ToddGoldfarb/Cadet-Tiny") + + x_df = pd.DataFrame( + [ + { + "user_inputs": [ + "Do you speak French?", + "Do you know how to say Snowflake in French?", + ], + "generated_responses": ["Yes I do."], + }, + ] + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["generated_responses"])) + + for row in res["generated_responses"]: + self.assertIsInstance(row, list) + for resp in row: + self.assertIsInstance(resp, str) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_fill_mask_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline( + task="fill-mask", + model="sshleifer/tiny-distilroberta-base", + top_k=1, + ) + + x_df = pd.DataFrame( + [ + ["LynYuu is the of the Grand Duchy of Yu."], + ] + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) + + for row in res["outputs"]: + self.assertIsInstance(row, str) + resp = json.loads(row) + self.assertIsInstance(resp, list) + self.assertIn("score", resp[0]) + self.assertIn("token", resp[0]) + self.assertIn("token_str", resp[0]) + self.assertIn("sequence", resp[0]) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_ner_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline(task="ner", model="hf-internal-testing/tiny-bert-for-token-classification") + + x_df = pd.DataFrame( + [ + ["My name is Izumi and I live in Tokyo, Japan."], + ] + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) 
+ + for row in res["outputs"]: + self.assertIsInstance(row, str) + resp = json.loads(row) + self.assertIsInstance(resp, list) + self.assertIn("entity", resp[0]) + self.assertIn("score", resp[0]) + self.assertIn("index", resp[0]) + self.assertIn("word", resp[0]) + self.assertIn("start", resp[0]) + self.assertIn("end", resp[0]) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_question_answering_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline( + task="question-answering", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + top_k=1, + ) + + x_df = pd.DataFrame( + [ + { + "question": "What did Doris want to do?", + "context": ( + "Doris is a cheerful mermaid from the ocean depths. She transformed into a bipedal creature " + 'and came to see everyone because she wanted to "learn more about the world of athletics."' + " She dislikes cuisines with seafood." + ), + } + ], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["score", "start", "end", "answer"])) + + self.assertEqual(res["score"].dtype.type, np.float64) + self.assertEqual(res["start"].dtype.type, np.int64) + self.assertEqual(res["end"].dtype.type, np.int64) + self.assertEqual(res["answer"].dtype.type, np.object_) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_question_answering_pipeline_multiple_output( + self, + ) -> None: + import transformers + + model = transformers.pipeline( + task="question-answering", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + top_k=3, + ) + + x_df = pd.DataFrame( + [ + { + "question": "What did Doris want to do?", + "context": ( + "Doris is a cheerful mermaid from the ocean depths. 
She transformed into a bipedal creature " + 'and came to see everyone because she wanted to "learn more about the world of athletics."' + " She dislikes cuisines with seafood." + ), + } + ], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) + + for row in res["outputs"]: + self.assertIsInstance(row, str) + resp = json.loads(row) + self.assertIsInstance(resp, list) + self.assertIn("score", resp[0]) + self.assertIn("start", resp[0]) + self.assertIn("end", resp[0]) + self.assertIn("answer", resp[0]) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_summarization_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline(task="summarization", model="sshleifer/tiny-mbart") + + x_df = pd.DataFrame( + [ + [ + ( + "Neuro-sama is a chatbot styled after a female VTuber that hosts live streams on the Twitch " + 'channel "vedal987". Her speech and personality are generated by an artificial intelligence' + " (AI) system which utilizes a large language model, allowing her to communicate with " + "viewers in a live chat. She was created by a computer programmer and AI-developer named " + "Jack Vedal, who decided to build upon the concept of an AI VTuber by combining interactions " + "between AI game play and a computer-generated avatar. She debuted on Twitch on December 19, " + "2022 after four years of development." 
+ ) + ] + ], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["summary_text"])) + + self.assertEqual(res["summary_text"].dtype.type, np.object_) + + self._test_registry_model( + model=model, + additional_dependencies=[ + str(env_utils.get_local_installed_version_of_pip_package(requirements.Requirement("sentencepiece"))) + ], + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_table_question_answering_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline(task="table-question-answering", model="google/tapas-tiny-finetuned-wtq") + + x_df = pd.DataFrame( + [ + { + "query": "Which channel has the most subscribers?", + "table": json.dumps( + { + "Channel": [ + "A.I.Channel", + "Kaguya Luna", + "Mirai Akari", + "Siro", + ], + "Subscribers": [ + "3,020,000", + "872,000", + "694,000", + "660,000", + ], + "Videos": ["1,200", "113", "639", "1,300"], + "Created At": [ + "Jun 30 2016", + "Dec 4 2017", + "Feb 28 2014", + "Jun 23 2017", + ], + } + ), + } + ], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["answer", "coordinates", "cells", "aggregator"])) + + self.assertEqual(res["answer"].dtype.type, np.object_) + self.assertEqual(res["coordinates"].dtype.type, np.object_) + self.assertIsInstance(res["coordinates"][0], list) + self.assertEqual(res["cells"].dtype.type, np.object_) + self.assertIsInstance(res["cells"][0], list) + self.assertEqual(res["aggregator"].dtype.type, np.object_) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_text_classification_pair_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline(task="text-classification", model="cross-encoder/ms-marco-MiniLM-L-12-v2") + + x_df = pd.DataFrame( + [{"text": "I like you.", "text_pair": "I love you, too."}], + ) + + def check_res(res: 
pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["label", "score"])) + + self.assertEqual(res["label"].dtype.type, np.object_) + self.assertEqual(res["score"].dtype.type, np.float64) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_text_classification_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline( + task="text-classification", + model="hf-internal-testing/tiny-random-distilbert", + top_k=1, + ) + + x_df = pd.DataFrame( + [ + { + "text": "I am wondering if I should have udon or rice for lunch", + "text_pair": "", + } + ], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) + + for row in res["outputs"]: + self.assertIsInstance(row, str) + resp = json.loads(row) + self.assertIsInstance(resp, list) + self.assertIn("label", resp[0]) + self.assertIn("score", resp[0]) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_text_generation_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline( + task="text-generation", + model="sshleifer/tiny-ctrl", + ) + + x_df = pd.DataFrame( + [['A descendant of the Lost City of Atlantis, who swam to Earth while saying, "']], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["outputs"])) + + for row in res["outputs"]: + self.assertIsInstance(row, str) + resp = json.loads(row) + self.assertIsInstance(resp, list) + self.assertIn("generated_text", resp[0]) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_text2text_generation_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline( + task="text2text-generation", + model="patrickvonplaten/t5-tiny-random", + ) + + 
x_df = pd.DataFrame( + [['A descendant of the Lost City of Atlantis, who swam to Earth while saying, "']], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["generated_text"])) + self.assertEqual(res["generated_text"].dtype.type, np.object_) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_translation_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline(task="translation_en_to_ja", model="patrickvonplaten/t5-tiny-random") + + x_df = pd.DataFrame( + [ + [ + ( + "Snowflake's Data Cloud is powered by an advanced data platform provided as a self-managed " + "service. Snowflake enables data storage, processing, and analytic solutions that are faster, " + "easier to use, and far more flexible than traditional offerings. The Snowflake data platform " + "is not built on any existing database technology or “big data” software platforms such as " + "Hadoop. Instead, Snowflake combines a completely new SQL query engine with an innovative " + "architecture natively designed for the cloud. To the user, Snowflake provides all of the " + "functionality of an enterprise analytic database, along with many additional special features " + "and unique capabilities." 
+ ) + ] + ], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["translation_text"])) + self.assertEqual(res["translation_text"].dtype.type, np.object_) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + def test_zero_shot_classification_pipeline( + self, + ) -> None: + import transformers + + model = transformers.pipeline( + task="zero-shot-classification", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + ) + + x_df = pd.DataFrame( + [ + { + "sequences": "I have a problem with Snowflake that needs to be resolved asap!!", + "candidate_labels": ["urgent", "not urgent"], + }, + { + "sequences": "I have a problem with Snowflake that needs to be resolved asap!!", + "candidate_labels": ["English", "Japanese"], + }, + ], + ) + + def check_res(res: pd.DataFrame) -> None: + pd.testing.assert_index_equal(res.columns, pd.Index(["sequence", "labels", "scores"])) + self.assertEqual(res["sequence"].dtype.type, np.object_) + self.assertEqual( + res["sequence"][0], + "I have a problem with Snowflake that needs to be resolved asap!!", + ) + self.assertEqual( + res["sequence"][1], + "I have a problem with Snowflake that needs to be resolved asap!!", + ) + self.assertEqual(res["labels"].dtype.type, np.object_) + self.assertListEqual(sorted(res["labels"][0]), sorted(["urgent", "not urgent"])) + self.assertListEqual(sorted(res["labels"][1]), sorted(["English", "Japanese"])) + self.assertEqual(res["scores"].dtype.type, np.object_) + self.assertIsInstance(res["labels"][0], list) + self.assertIsInstance(res["labels"][1], list) + + self._test_registry_model( + model=model, + prediction_assert_fns={ + "": ( + x_df, + check_res, + ), + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_mlflow_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_mlflow_model_test.py new file 
mode 100644 index 00000000..1940ba94 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_mlflow_model_test.py @@ -0,0 +1,114 @@ +from importlib import metadata as importlib_metadata + +import mlflow +import numpy as np +from absl.testing import absltest +from sklearn import datasets, ensemble, model_selection + +from snowflake.ml._internal import env +from snowflake.ml.model._signatures import numpy_handler +from tests.integ.snowflake.ml.registry.model import registry_model_test_base + + +class TestRegistryMLFlowModelInteg(registry_model_test_base.RegistryModelTestBase): + def test_mlflow_model_deploy_sklearn_df( + self, + ) -> None: + db = datasets.load_diabetes(as_frame=True) + X_train, X_test, y_train, y_test = model_selection.train_test_split(db.data, db.target) + with mlflow.start_run() as run: + rf = ensemble.RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3) + rf.fit(X_train, y_train) + + # Use the model to make predictions on the test dataset. + predictions = rf.predict(X_test) + signature = mlflow.models.signature.infer_signature(X_test, predictions) + mlflow.sklearn.log_model( + rf, + "model", + signature=signature, + metadata={"author": "halu", "version": "1"}, + conda_env={ + "dependencies": [f"python=={env.PYTHON_VERSION}"] + + list( + map( + lambda pkg: f"{pkg}=={importlib_metadata.distribution(pkg).version}", + [ + "mlflow", + "cloudpickle", + "numpy", + "scikit-learn", + "scipy", + "typing-extensions", + ], + ) + ), + "name": "mlflow-env", + }, + ) + + run_id = run.info.run_id + + self._test_registry_model( + model=mlflow.pyfunc.load_model(f"runs:/{run_id}/model"), + prediction_assert_fns={ + "predict": ( + X_test, + lambda res: np.testing.assert_allclose(np.expand_dims(predictions, axis=1), res.to_numpy()), + ), + }, + ) + + def test_mlflow_model_deploy_sklearn( + self, + ) -> None: + db = datasets.load_diabetes() + X_train, X_test, y_train, y_test = model_selection.train_test_split(db.data, db.target) + with 
mlflow.start_run() as run: + rf = ensemble.RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3) + rf.fit(X_train, y_train) + + # Use the model to make predictions on the test dataset. + predictions = rf.predict(X_test) + signature = mlflow.models.signature.infer_signature(X_test, predictions) + mlflow.sklearn.log_model( + rf, + "model", + signature=signature, + metadata={"author": "halu", "version": "1"}, + conda_env={ + "dependencies": [f"python=={env.PYTHON_VERSION}"] + + list( + map( + lambda pkg: f"{pkg}=={importlib_metadata.distribution(pkg).version}", + [ + "mlflow", + "cloudpickle", + "numpy", + "scikit-learn", + "scipy", + "typing-extensions", + ], + ) + ), + "name": "mlflow-env", + }, + ) + + run_id = run.info.run_id + + X_test_df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df([X_test]) + + self._test_registry_model( + model=mlflow.pyfunc.load_model(f"runs:/{run_id}/model"), + prediction_assert_fns={ + "predict": ( + X_test_df, + lambda res: np.testing.assert_allclose(np.expand_dims(predictions, axis=1), res.to_numpy()), + ), + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py new file mode 100644 index 00000000..27311c9f --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py @@ -0,0 +1,86 @@ +import inspect +import unittest +import uuid +from typing import Any, Callable, Dict, List, Optional, Tuple + +from absl.testing import absltest +from packaging import version + +from snowflake.ml.model import type_hints as model_types +from snowflake.ml.registry import registry +from snowflake.ml.utils import connection_params +from snowflake.snowpark import Session +from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils + + +class RegistryModelTestBase(absltest.TestCase): + def setUp(self) -> None: + """Creates Snowpark and Snowflake 
environments for testing.""" + login_options = connection_params.SnowflakeLoginOptions() + + self._run_id = uuid.uuid4().hex + self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() + self._test_schema = db_manager.TestObjectNameGenerator.get_snowml_test_object_name( + self._run_id, "schema" + ).upper() + + self._session = Session.builder.configs( + { + **login_options, + **{"database": self._test_db, "schema": self._test_schema}, + } + ).create() + + current_sf_version = test_env_utils.get_current_snowflake_version(self._session) + + if current_sf_version < version.parse("8.0.0"): + raise unittest.SkipTest("This test requires Snowflake Version 8.0.0 or higher.") + + self._db_manager = db_manager.DBManager(self._session) + self._db_manager.create_database(self._test_db) + self._db_manager.create_schema(self._test_schema) + self._db_manager.cleanup_databases(expire_hours=6) + self.registry = registry.Registry(self._session) + + def tearDown(self) -> None: + self._db_manager.drop_database(self._test_db) + self._session.close() + + def _test_registry_model( + self, + model: model_types.SupportedModelType, + prediction_assert_fns: Dict[str, Tuple[Any, Callable[[Any], Any]]], + sample_input: Optional[model_types.SupportedDataType] = None, + additional_dependencies: Optional[List[str]] = None, + options: Optional[model_types.ModelSaveOption] = None, + ) -> None: + conda_dependencies = [ + test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python") + ] + if additional_dependencies: + conda_dependencies.extend(additional_dependencies) + + # Get the name of the caller as the model name + name = f"model_{inspect.stack()[1].function}" + version = f"ver_{self._run_id}" + mv = self.registry.log_model( + model=model, + model_name=name, + version_name=version, + sample_input_data=sample_input, + conda_dependencies=conda_dependencies, + options=options, + ) + + for target_method, 
(test_input, check_func) in prediction_assert_fns.items(): + res = mv.run(test_input, method_name=target_method) + + check_func(res) + + self.registry.delete_model(model_name=name) + + self.assertNotIn(mv.model_name, [m.name for m in self.registry.list_models()]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py new file mode 100644 index 00000000..626584af --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py @@ -0,0 +1,89 @@ +import numpy as np +from absl.testing import absltest +from sklearn import datasets + +from snowflake.ml.modeling.lightgbm import LGBMRegressor +from snowflake.ml.modeling.linear_model import LogisticRegression +from snowflake.ml.modeling.xgboost import XGBRegressor +from tests.integ.snowflake.ml.registry.model import registry_model_test_base + + +class TestRegistryModelingModelInteg(registry_model_test_base.RegistryModelTestBase): + def test_snowml_model_deploy_snowml_sklearn( + self, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X + regr.fit(test_features) + + self._test_registry_model( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values + ), + ), + }, + ) + + def test_snowml_model_deploy_xgboost( + self, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", 
"").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = XGBRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X[:10] + regr.fit(test_features) + + self._test_registry_model( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values + ), + ), + }, + ) + + def test_snowml_model_deploy_lightgbm( + self, + ) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = LGBMRegressor(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + test_features = iris_X[:10] + regr.fit(test_features) + + self._test_registry_model( + model=regr, + prediction_assert_fns={ + "predict": ( + test_features, + lambda res: np.testing.assert_allclose( + res[OUTPUT_COLUMNS].values, regr.predict(test_features)[OUTPUT_COLUMNS].values + ), + ), + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_pytorch_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_pytorch_model_test.py new file mode 100644 index 00000000..91410c84 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_pytorch_model_test.py @@ -0,0 +1,142 @@ +import pandas as pd +import torch +from absl.testing import absltest + +from snowflake.ml.model._signatures import pytorch_handler, snowpark_handler +from tests.integ.snowflake.ml.registry.model import registry_model_test_base +from tests.integ.snowflake.ml.test_utils 
import dataframe_utils, model_factory + + +class TestRegistryPytorchModelInteg(registry_model_test_base.RegistryModelTestBase): + def test_pytorch_tensor_as_sample( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_torch_model() + x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) + y_pred = model.forward(data_x).detach() + + self._test_registry_model( + model=model, + sample_input=[data_x], + prediction_assert_fns={ + "": ( + x_df, + lambda res: torch.testing.assert_close( + pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred, check_dtype=False + ), + ), + }, + ) + + def test_pytorch_df_as_sample( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_torch_model(torch.float64) + x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) + y_pred = model.forward(data_x).detach() + + self._test_registry_model( + model=model, + sample_input=x_df, + prediction_assert_fns={ + "": ( + x_df, + lambda res: torch.testing.assert_close( + pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred + ), + ), + }, + ) + + def test_pytorch_sp( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_torch_model(torch.float64) + x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) + x_df.columns = ["col_0"] + y_pred = model.forward(data_x) + x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, x_df) + y_pred_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([y_pred]) + y_pred_df.columns = ["output_feature_0"] + y_df_expected = pd.concat([x_df, y_pred_df], axis=1) + + self._test_registry_model( + model=model, + sample_input=x_df, + prediction_assert_fns={ + "": ( + x_df_sp, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), + ), + }, + ) + + def 
test_torchscript_tensor_as_sample( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model() + x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) + model_script = torch.jit.script(model) # type:ignore[attr-defined] + y_pred = model_script.forward(data_x).detach() + + self._test_registry_model( + model=model_script, + sample_input=[data_x], + prediction_assert_fns={ + "": ( + x_df, + lambda res: torch.testing.assert_close( + pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred, check_dtype=False + ), + ), + }, + ) + + def test_torchscript_df_as_sample( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model(torch.float64) + x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) + model_script = torch.jit.script(model) # type:ignore[attr-defined] + y_pred = model_script.forward(data_x).detach() + + self._test_registry_model( + model=model_script, + sample_input=x_df, + prediction_assert_fns={ + "": ( + x_df, + lambda res: torch.testing.assert_close( + pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(res)[0], y_pred + ), + ), + }, + ) + + def test_torchscript_sp( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_jittable_torch_model(torch.float64) + x_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([data_x], ensure_serializable=False) + x_df.columns = ["col_0"] + model_script = torch.jit.script(model) # type:ignore[attr-defined] + y_pred = model_script.forward(data_x) + x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, x_df) + y_pred_df = pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df([y_pred]) + y_pred_df.columns = ["output_feature_0"] + y_df_expected = pd.concat([x_df, y_pred_df], axis=1) + + self._test_registry_model( + model=model_script, + sample_input=x_df, + 
prediction_assert_fns={ + "": ( + x_df_sp, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), + ), + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py new file mode 100644 index 00000000..ebb237e5 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py @@ -0,0 +1,88 @@ +from typing import cast + +import numpy as np +import pandas as pd +from absl.testing import absltest +from sklearn import datasets, ensemble, linear_model, multioutput + +from tests.integ.snowflake.ml.registry.model import registry_model_test_base + + +class TestRegistrySKLearnModelInteg(registry_model_test_base.RegistryModelTestBase): + def test_skl_model( + self, + ) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + # LogisticRegression is for classfication task, such as iris + regr = linear_model.LogisticRegression() + regr.fit(iris_X, iris_y) + self._test_registry_model( + model=regr, + sample_input=iris_X, + prediction_assert_fns={ + "predict": ( + iris_X, + lambda res: np.testing.assert_allclose(res["output_feature_0"].values, regr.predict(iris_X)), + ), + "predict_proba": ( + iris_X[:10], + lambda res: np.testing.assert_allclose(res.values, regr.predict_proba(iris_X[:10])), + ), + }, + ) + + def test_skl_model_case_sensitive( + self, + ) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + # LogisticRegression is for classfication task, such as iris + regr = linear_model.LogisticRegression() + regr.fit(iris_X, iris_y) + self._test_registry_model( + model=regr, + sample_input=iris_X, + options={ + "method_options": {"predict": {"case_sensitive": True}, "predict_proba": {"case_sensitive": True}}, + "target_methods": ["predict", "predict_proba"], + }, + prediction_assert_fns={ + '"predict"': ( + iris_X, + lambda res: 
np.testing.assert_allclose(res["output_feature_0"].values, regr.predict(iris_X)), + ), + '"predict_proba"': ( + iris_X[:10], + lambda res: np.testing.assert_allclose(res.values, regr.predict_proba(iris_X[:10])), + ), + }, + ) + + def test_skl_multiple_output_model( + self, + ) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + target2 = np.random.randint(0, 6, size=iris_y.shape) + dual_target = np.vstack([iris_y, target2]).T + model = multioutput.MultiOutputClassifier(ensemble.RandomForestClassifier(random_state=42)) + model.fit(iris_X[:10], dual_target[:10]) + self._test_registry_model( + model=model, + sample_input=iris_X, + prediction_assert_fns={ + "predict": ( + iris_X[-10:], + lambda res: np.testing.assert_allclose(res.to_numpy(), model.predict(iris_X[-10:])), + ), + "predict_proba": ( + iris_X[-10:], + lambda res: np.testing.assert_allclose( + np.hstack([np.array(res[col].to_list()) for col in cast(pd.DataFrame, res)]), + np.hstack(model.predict_proba(iris_X[-10:])), + ), + ), + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py new file mode 100644 index 00000000..48398502 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py @@ -0,0 +1,170 @@ +from typing import Optional + +import numpy as np +import pandas as pd +import tensorflow as tf +from absl.testing import absltest + +from snowflake.ml.model._signatures import ( + numpy_handler, + snowpark_handler, + tensorflow_handler, +) +from tests.integ.snowflake.ml.registry.model import registry_model_test_base +from tests.integ.snowflake.ml.test_utils import dataframe_utils, model_factory + + +class SimpleModule(tf.Module): + def __init__(self, name: Optional[str] = None) -> None: + super().__init__(name=name) + self.a_variable = tf.Variable(5.0, name="train_me") + 
self.non_trainable_variable = tf.Variable(5.0, trainable=False, name="do_not_train_me") + + @tf.function(input_signature=[tf.TensorSpec(shape=(None, 1), dtype=tf.float32)]) # type: ignore[misc] + def __call__(self, tensor: tf.Tensor) -> tf.Tensor: + return self.a_variable * tensor + self.non_trainable_variable + + +class TestRegistryTensorflowModelInteg(registry_model_test_base.RegistryModelTestBase): + def test_tf_tensor_as_sample( + self, + ) -> None: + model = SimpleModule(name="simple") + data_x = tf.constant([[5.0], [10.0]]) + x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) + y_pred = model(data_x) + + self._test_registry_model( + model=model, + sample_input=[data_x], + prediction_assert_fns={ + "": ( + x_df, + lambda res: np.testing.assert_allclose( + tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), + y_pred.numpy(), + ), + ), + }, + ) + + def test_tf_df_as_sample( + self, + ) -> None: + model = SimpleModule(name="simple") + data_x = tf.constant([[5.0], [10.0]]) + x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) + y_pred = model(data_x) + + self._test_registry_model( + model=model, + sample_input=x_df, + prediction_assert_fns={ + "": ( + x_df, + lambda res: np.testing.assert_allclose( + tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), + y_pred.numpy(), + ), + ), + }, + ) + + def test_tf_sp( + self, + ) -> None: + model = SimpleModule(name="simple") + data_x = tf.constant([[5.0], [10.0]]) + x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) + x_df.columns = ["col_0"] + y_pred = model(data_x) + x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( + self._session, + x_df, + ) + y_pred_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([y_pred]) + y_pred_df.columns = ["output_feature_0"] + 
y_df_expected = pd.concat([x_df, y_pred_df], axis=1) + + self._test_registry_model( + model=model, + sample_input=x_df, + prediction_assert_fns={ + "": ( + x_df_sp, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), + ), + }, + ) + + def test_keras_tensor_as_sample( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_keras_model() + x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) + y_pred = model.predict(data_x) + self._test_registry_model( + model=model, + sample_input=[data_x], + prediction_assert_fns={ + "": ( + x_df, + lambda res: np.testing.assert_allclose( + tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), + y_pred, + atol=1e-6, + ), + ), + }, + ) + + def test_keras_df_as_sample( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_keras_model() + x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) + y_pred = model.predict(data_x) + self._test_registry_model( + model=model, + sample_input=x_df, + prediction_assert_fns={ + "": ( + x_df, + lambda res: np.testing.assert_allclose( + tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(res)[0].numpy(), + y_pred, + atol=1e-6, + ), + ), + }, + ) + + def test_keras_sp( + self, + ) -> None: + model, data_x, data_y = model_factory.ModelFactory.prepare_keras_model() + x_df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df([data_x], ensure_serializable=False) + x_df.columns = ["col_0"] + y_pred = model.predict(data_x) + x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( + self._session, + x_df, + ) + y_pred_df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df([y_pred]) + y_pred_df.columns = ["output_feature_0"] + y_df_expected = pd.concat([x_df, y_pred_df], axis=1) + + self._test_registry_model( + model=model, + sample_input=x_df, + prediction_assert_fns={ + 
"": ( + x_df_sp, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected), + ), + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py new file mode 100644 index 00000000..a97b0cf4 --- /dev/null +++ b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py @@ -0,0 +1,125 @@ +import inflection +import numpy as np +import pandas as pd +import xgboost +from absl.testing import absltest +from sklearn import datasets, model_selection + +from tests.integ.snowflake.ml.registry.model import registry_model_test_base +from tests.integ.snowflake.ml.test_utils import dataframe_utils + + +class TestRegistryXGBoostModelInteg(registry_model_test_base.RegistryModelTestBase): + def test_xgb( + self, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) + regressor.fit(cal_X_train, cal_y_train) + self._test_registry_model( + model=regressor, + sample_input=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + lambda res: np.testing.assert_allclose( + res.values, np.expand_dims(regressor.predict(cal_X_test), axis=1) + ), + ), + }, + ) + + def test_xgb_sp( + self, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True).frame + cal_data.columns = [inflection.parameterize(c, "_") for c in cal_data] + cal_data_sp_df = self._session.create_dataframe(cal_data) + cal_data_sp_df_train, cal_data_sp_df_test = tuple(cal_data_sp_df.random_split([0.25, 0.75], seed=2568)) + regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) + 
cal_data_pd_df_train = cal_data_sp_df_train.to_pandas() + regressor.fit(cal_data_pd_df_train.drop(columns=["target"]), cal_data_pd_df_train["target"]) + cal_data_sp_df_test_X = cal_data_sp_df_test.drop('"target"') + + y_df_expected = pd.concat( + [ + cal_data_sp_df_test_X.to_pandas(), + pd.DataFrame(regressor.predict(cal_data_sp_df_test_X.to_pandas()), columns=["output_feature_0"]), + ], + axis=1, + ) + self._test_registry_model( + model=regressor, + sample_input=cal_data_sp_df_train.drop('"target"'), + prediction_assert_fns={ + "predict": ( + cal_data_sp_df_test_X, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ), + }, + ) + + def test_xgb_booster( + self, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, objective="binary:logistic") + regressor = xgboost.train(params, xgboost.DMatrix(data=cal_X_train, label=cal_y_train)) + y_pred = regressor.predict(xgboost.DMatrix(data=cal_X_test)) + self._test_registry_model( + model=regressor, + sample_input=cal_X_test, + prediction_assert_fns={ + "predict": ( + cal_X_test, + lambda res: np.testing.assert_allclose(res.values, np.expand_dims(y_pred, axis=1), rtol=1e-6), + ), + }, + ) + + def test_xgb_booster_sp( + self, + ) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True).frame + cal_data.columns = [inflection.parameterize(c, "_") for c in cal_data] + cal_data_sp_df = self._session.create_dataframe(cal_data) + cal_data_sp_df_train, cal_data_sp_df_test = tuple(cal_data_sp_df.random_split([0.25, 0.75], seed=2568)) + cal_data_pd_df_train = cal_data_sp_df_train.to_pandas() + params = dict(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3, 
objective="binary:logistic") + regressor = xgboost.train( + params, + xgboost.DMatrix(data=cal_data_pd_df_train.drop(columns=["target"]), label=cal_data_pd_df_train["target"]), + ) + cal_data_sp_df_test_X = cal_data_sp_df_test.drop('"target"') + y_df_expected = pd.concat( + [ + cal_data_sp_df_test_X.to_pandas(), + pd.DataFrame( + regressor.predict(xgboost.DMatrix(data=cal_data_sp_df_test_X.to_pandas())), + columns=["output_feature_0"], + ), + ], + axis=1, + ) + self._test_registry_model( + model=regressor, + sample_input=cal_data_sp_df_train.drop('"target"'), + prediction_assert_fns={ + "predict": ( + cal_data_sp_df_test_X, + lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False), + ), + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_snowservice_merge_gate_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_snowservice_merge_gate_integ_test.py index 8fe11e65..bc01667c 100644 --- a/tests/integ/snowflake/ml/registry/model_registry_snowservice_merge_gate_integ_test.py +++ b/tests/integ/snowflake/ml/registry/model_registry_snowservice_merge_gate_integ_test.py @@ -29,7 +29,6 @@ def _run_deployment() -> None: "enable_remote_image_build": True, }, }, - omit_target_method_when_deploy=True, ) # First deployment @@ -60,7 +59,6 @@ def _run_deployment() -> None: "model_in_image": True, }, }, - omit_target_method_when_deploy=True, ) _run_deployment() diff --git a/tests/integ/snowflake/ml/test_utils/BUILD.bazel b/tests/integ/snowflake/ml/test_utils/BUILD.bazel index f6aa7920..e55d4da4 100644 --- a/tests/integ/snowflake/ml/test_utils/BUILD.bazel +++ b/tests/integ/snowflake/ml/test_utils/BUILD.bazel @@ -30,6 +30,15 @@ py_library( ], ) +py_library( + name = "dataframe_utils", + testonly = True, + srcs = ["dataframe_utils.py"], + deps = [ + "//snowflake/ml/model/_signatures:snowpark_handler", + ], +) + py_library( name = "common_test_base", testonly = True, @@ -38,7 +47,7 
@@ py_library( ], deps = [ ":_snowml_requirements", - ":test_env_utils", + "//snowflake/ml/_internal:env_utils", "//snowflake/ml/_internal:file_utils", "//snowflake/ml/utils:connection_params", ], @@ -64,6 +73,7 @@ py_library( srcs = ["test_env_utils.py"], deps = [ "//snowflake/ml/_internal:env", + "//snowflake/ml/_internal:env_utils", "//snowflake/ml/_internal/utils:query_result_checker", ], ) diff --git a/tests/integ/snowflake/ml/test_utils/common_test_base.py b/tests/integ/snowflake/ml/test_utils/common_test_base.py index d92460fa..5695c2b2 100644 --- a/tests/integ/snowflake/ml/test_utils/common_test_base.py +++ b/tests/integ/snowflake/ml/test_utils/common_test_base.py @@ -6,13 +6,14 @@ import cloudpickle from absl.testing import absltest, parameterized +from packaging import requirements from typing_extensions import Concatenate, ParamSpec -from snowflake.ml._internal import file_utils +from snowflake.ml._internal import env_utils, file_utils from snowflake.ml.utils import connection_params from snowflake.snowpark import functions as F, session from snowflake.snowpark._internal import udf_utils, utils as snowpark_utils -from tests.integ.snowflake.ml.test_utils import _snowml_requirements, test_env_utils +from tests.integ.snowflake.ml.test_utils import _snowml_requirements _V = TypeVar("_V", bound="CommonTestBase") _T_args = ParamSpec("_T_args") @@ -40,9 +41,9 @@ class CommonTestBase(parameterized.TestCase): def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" self.session = ( - session.Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - if not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call] # - else session._get_active_session() + session._get_active_session() + if snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call] # + else session.Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() ) def tearDown(self) -> None: @@ -108,7 +109,7 @@ 
def _in_sproc_test(execute_as: Literal["owner", "caller"] = "owner") -> None: req for req in _snowml_requirements.REQUIREMENTS # Remove "_" not in req once Snowpark 1.11.0 available, it is a workaround for their bug. - if "snowflake-connector-python" not in req and "_" not in req + if not any(offending in req for offending in ["snowflake-connector-python", "pyarrow", "_"]) ] cloudpickle.register_pickle_by_value(test_module) @@ -242,7 +243,9 @@ def {func_name}({first_arg_name}: snowflake.snowpark.Session, {", ".join(arg_lis additional_cases = [ {"_snowml_pkg_ver": pkg_ver} - for pkg_ver in test_env_utils.get_package_versions_in_conda(f"snowflake-ml-python{version_range}") + for pkg_ver in env_utils.get_matched_package_versions_in_snowflake_conda_channel( + req=requirements.Requirement(f"snowflake-ml-python{version_range}") + ) ] modified_test_cases = [{**t1, **t2} for t1 in test_cases for t2 in additional_cases] diff --git a/tests/integ/snowflake/ml/test_utils/dataframe_utils.py b/tests/integ/snowflake/ml/test_utils/dataframe_utils.py new file mode 100644 index 00000000..1cf425a6 --- /dev/null +++ b/tests/integ/snowflake/ml/test_utils/dataframe_utils.py @@ -0,0 +1,44 @@ +from typing import Literal, Tuple, Union + +import numpy as np +import numpy.typing as npt +import pandas as pd + +from snowflake.ml.model._signatures import snowpark_handler +from snowflake.snowpark import DataFrame as SnowparkDataFrame + + +def check_sp_df_res( + res_sp_df: SnowparkDataFrame, + expected_pd_df: pd.DataFrame, + *, + check_dtype: bool = True, + check_index_type: Union[bool, Literal["equiv"]] = "equiv", + check_column_type: Union[bool, Literal["equiv"]] = "equiv", + check_frame_type: bool = True, + check_names: bool = True, +) -> None: + res_pd_df = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(res_sp_df) + + def totuple(a: Union[npt.ArrayLike, Tuple[object], object]) -> Union[Tuple[object], object]: + try: + return tuple(totuple(i) for i in a) # type: ignore[union-attr] + 
except TypeError: + return a + + for df in [res_pd_df, expected_pd_df]: + for col in df.columns: + if isinstance(df[col][0], list): + df[col] = df[col].apply(tuple) + elif isinstance(df[col][0], np.ndarray): + df[col] = df[col].apply(totuple) + + pd.testing.assert_frame_equal( + res_pd_df.sort_values(by=res_pd_df.columns.tolist()).reset_index(drop=True), + expected_pd_df.sort_values(by=expected_pd_df.columns.tolist()).reset_index(drop=True), + check_dtype=check_dtype, + check_index_type=check_index_type, + check_column_type=check_column_type, + check_frame_type=check_frame_type, + check_names=check_names, + ) diff --git a/tests/integ/snowflake/ml/test_utils/test_env_utils.py b/tests/integ/snowflake/ml/test_utils/test_env_utils.py index 248e6987..f84c241a 100644 --- a/tests/integ/snowflake/ml/test_utils/test_env_utils.py +++ b/tests/integ/snowflake/ml/test_utils/test_env_utils.py @@ -2,15 +2,23 @@ import textwrap from typing import List -import requests from packaging import requirements, version import snowflake.connector -from snowflake.ml._internal import env +from snowflake.ml._internal import env, env_utils from snowflake.ml._internal.utils import query_result_checker from snowflake.snowpark import session +def get_current_snowflake_version(session: session.Session) -> version.Version: + res = session.sql("SELECT CURRENT_VERSION() AS CURRENT_VERSION").collect()[0] + version_str = res.CURRENT_VERSION + assert isinstance(version_str, str) + + version_str = "+".join(version_str.split()) + return version.parse(version_str) + + @functools.lru_cache def get_package_versions_in_server( session: session.Session, @@ -62,47 +70,12 @@ def get_latest_package_version_spec_in_server( return f"{package_req.name}=={max(available_version_list)}" -@functools.lru_cache -def get_package_versions_in_conda( - package_req_str: str, python_version: str = env.PYTHON_VERSION -) -> List[version.Version]: - package_req = requirements.Requirement(package_req_str) - repodata_url = 
"https://repo.anaconda.com/pkgs/snowflake/linux-64/repodata.json" - - parsed_python_version = version.Version(python_version) - python_version_build_str = f"py{parsed_python_version.major}{parsed_python_version.minor}" - - max_retry = 3 - - exc_list = [] - - while max_retry > 0: - try: - version_list = [] - repodata = requests.get(repodata_url).json() - assert isinstance(repodata, dict) - packages_info = repodata["packages"] - assert isinstance(packages_info, dict) - for package_info in packages_info.values(): - if package_info["name"] == package_req.name and python_version_build_str in package_info["build"]: - version_list.append(version.parse(package_info["version"])) - available_version_list = list(package_req.specifier.filter(version_list)) - return available_version_list - except Exception as e: - max_retry -= 1 - exc_list.append(e) - - raise RuntimeError( - f"Failed to get latest version of package {package_req} in Snowflake Anaconda Channel. " - + "Exceptions are " - + ", ".join(map(str, exc_list)) - ) - - @functools.lru_cache def get_latest_package_version_spec_in_conda(package_req_str: str, python_version: str = env.PYTHON_VERSION) -> str: package_req = requirements.Requirement(package_req_str) - available_version_list = get_package_versions_in_conda(package_req_str, python_version) + available_version_list = env_utils.get_matched_package_versions_in_snowflake_conda_channel( + req=requirements.Requirement(package_req_str), python_version=python_version + ) if len(available_version_list) == 0: return str(package_req) return f"{package_req.name}=={max(available_version_list)}"