diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9094b250..9d4b8780 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,67 @@
 # Release History

-## 1.4.1
+## 1.5.0
+
+### Bug Fixes
+
+- Registry: Fix invalid parameter 'SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL' error.
+
+### Behavior Changes
+
+- Model Development: The behavior of `fit_transform` has changed for all estimators.
+  First, it is now available on every estimator that implements this method;
+  second, its return type is now `Union[snowpark.DataFrame, pandas.DataFrame]`.
+
+#### Model Registry (PrPr)
+
+`snowflake.ml.registry.artifact` and the related `snowflake.ml.model_registry.ModelRegistry` APIs have been removed.
+
+- Removed the `snowflake.ml.registry.artifact` module.
+- Removed `ModelRegistry.log_artifact()`, `ModelRegistry.list_artifacts()`, and `ModelRegistry.get_artifact()`.
+- Removed the `artifacts` argument from `ModelRegistry.log_model()`.
+
+#### Dataset (PrPr)
+
+`snowflake.ml.dataset.Dataset` has been redesigned to be backed by Snowflake Dataset entities.
+
+- New `Dataset`s can be created with `Dataset.create()` and existing `Dataset`s may be loaded
+  with `Dataset.load()`.
+- `Dataset`s now maintain an immutable `selected_version` state. The `Dataset.create_version()` and
+  `Dataset.load_version()` APIs return new `Dataset` objects with the requested `selected_version` state.
+- Added `dataset.create_from_dataframe()` and `dataset.load_dataset()` convenience APIs as shortcuts
+  for creating and loading `Dataset`s with a pre-selected version.
+- `Dataset.materialized_table` and `Dataset.snapshot_table` no longer exist; `Dataset.fully_qualified_name`
+  is the closest equivalent.
+- `Dataset.df` no longer exists. Instead, use `Dataset.read.to_snowpark_dataframe()`.
+- `Dataset.owner` has been moved to `Dataset.selected_version.owner`.
+- `Dataset.desc` has been moved to `Dataset.selected_version.comment`.
+- `Dataset.timestamp_col`, `Dataset.label_cols`, `Dataset.feature_store_metadata`, and
+  `Dataset.schema_version` have been removed.
+
+#### Feature Store (PrPr)
+
+The `FeatureStore.generate_dataset` argument list has been changed to match the new
+`snowflake.ml.dataset.Dataset` definition:
+
+- `materialized_table` has been removed and replaced with `name` and `version`.
+- `name` has been moved to the first positional argument.
+- `save_mode` has been removed, as `merge` behavior is no longer supported. The new behavior is always `errorifexists`.
+
+### New Features
+
+- Registry: Add an `export` method to `ModelVersion` instances to export model files.
+- Registry: Add a `load` method to `ModelVersion` instances to load the underlying object from the model.
+- Registry: Add a `Model.rename` method to rename or move a model.
+
+#### Dataset (PrPr)
+
+- Added Snowpark DataFrame integration using `Dataset.read.to_snowpark_dataframe()`.
+- Added Pandas DataFrame integration using `Dataset.read.to_pandas()`.
+- Added PyTorch and TensorFlow integrations using `Dataset.read.to_torch_datapipe()`
+  and `Dataset.read.to_tf_dataset()` respectively.
+- Added `fsspec` style file integration using `Dataset.read.files()` and `Dataset.read.filesystem()` + +## 1.4.1 (2024-04-18) ### New Features diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 6bdaef43..997e0f05 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.4.1 + version: 1.5.0 requirements: build: - python diff --git a/codegen/sklearn_wrapper_generator.py b/codegen/sklearn_wrapper_generator.py index 64d53c2e..b30b75ab 100644 --- a/codegen/sklearn_wrapper_generator.py +++ b/codegen/sklearn_wrapper_generator.py @@ -13,6 +13,7 @@ NP_CONSTANTS = [c for c in dir(np) if type(getattr(np, c, None)) == float or type(getattr(np, c, None)) == int] LOAD_BREAST_CANCER = "load_breast_cancer" LOAD_IRIS = "load_iris" +LOAD_DIGITS = "load_digits" LOAD_DIABETES = "load_diabetes" @@ -278,6 +279,7 @@ def _is_deterministic(class_object: Tuple[str, type]) -> bool: return not ( WrapperGeneratorFactory._is_class_of_type(class_object[1], "LinearDiscriminantAnalysis") or WrapperGeneratorFactory._is_class_of_type(class_object[1], "BernoulliRBM") + or WrapperGeneratorFactory._is_class_of_type(class_object[1], "TSNE") ) @staticmethod @@ -739,6 +741,7 @@ def _populate_function_doc_fields(self) -> None: _METHODS = [ "fit", "fit_predict", + "fit_transform", "predict", "predict_log_proba", "predict_proba", @@ -775,6 +778,7 @@ def _populate_function_doc_fields(self) -> None: self.transform_docstring = self.estimator_function_docstring["transform"] self.predict_docstring = self.estimator_function_docstring["predict"] self.fit_predict_docstring = self.estimator_function_docstring["fit_predict"] + self.fit_transform_docstring = self.estimator_function_docstring["fit_transform"] self.predict_proba_docstring = self.estimator_function_docstring["predict_proba"] self.score_samples_docstring = self.estimator_function_docstring["score_samples"] self.predict_log_proba_docstring = self.estimator_function_docstring["predict_log_proba"] @@ -898,6 +902,8 @@ def _populate_integ_test_fields(self) -> None: self.test_dataset_func = LOAD_BREAST_CANCER elif self._is_regressor: self.test_dataset_func = LOAD_DIABETES + elif WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "SpectralEmbedding"): + self.test_dataset_func = LOAD_DIGITS else: self.test_dataset_func = LOAD_IRIS diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index e1e62d9f..f541baa0 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -54,12 +54,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "{transform.root_module_name}".re DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame] -def _is_fit_transform_method_enabled() -> Callable[[Any], bool]: - def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]: - return {transform.fit_transform_manifold_function_support} and callable(getattr(self._sklearn_object, "fit_transform", None)) - return check - - class {transform.original_class_name}(BaseTransformer): r"""{transform.estimator_class_docstring} """ @@ -174,20 +168,17 @@ class {transform.original_class_name}(BaseTransformer): self, dataset: DataFrame, inference_method: str, - ) -> List[str]: - """Util method to run validate that batch inference can be run on a snowpark dataframe and - return the available package that exists in the snowflake anaconda channel + ) -> None: + """Util method to run validate that 
batch inference can be run on a snowpark dataframe. Args: dataset: snowpark dataframe inference_method: the inference method such as predict, score... - + Raises: SnowflakeMLException: If the estimator is not fitted, raise error SnowflakeMLException: If the session is None, raise error - Returns: - A list of available package that exists in the snowflake anaconda channel """ if not self._is_fitted: raise exceptions.SnowflakeMLException( @@ -205,9 +196,7 @@ class {transform.original_class_name}(BaseTransformer): "Session must not specified for snowpark dataset." ), ) - # Validate that key package version in user workspace are supported in snowflake conda channel - return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT) + @available_if(original_estimator_has_callable("predict")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -246,7 +235,8 @@ class {transform.original_class_name}(BaseTransformer): expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type()) - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -321,10 +311,8 @@ class {transform.original_class_name}(BaseTransformer): if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols): expected_dtype = convert_sp_to_sf_type(output_types[0]) - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() transform_kwargs = dict( @@ -383,16 +371,32 @@ class {transform.original_class_name}(BaseTransformer): self._is_fitted = True return output_result - - @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc] - def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]: + + @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc] + def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]: """ {transform.fit_transform_docstring} + output_cols_prefix: Prefix for the response columns Returns: Transformed dataset. 
""" - self.fit(dataset) - assert self._sklearn_object is not None - return self._sklearn_object.embedding_ + self._infer_input_output_cols(dataset) + super()._check_dataset_type(dataset) + model_trainer = ModelTrainerBuilder.build_fit_transform( + estimator=self._sklearn_object, + dataset=dataset, + input_cols=self.input_cols, + label_cols=self.label_cols, + sample_weight_col=self.sample_weight_col, + autogenerated=self._autogenerated, + subproject=_SUBPROJECT, + ) + output_result, fitted_estimator = model_trainer.train_fit_transform( + drop_input_cols=self._drop_input_cols, + expected_output_cols_list=self.output_cols, + ) + self._sklearn_object = fitted_estimator + self._is_fitted = True + return output_result def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]: @@ -475,10 +479,8 @@ class {transform.original_class_name}(BaseTransformer): expected_output_cols = self._get_output_column_names(output_cols_prefix) if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -535,10 +537,8 @@ class {transform.original_class_name}(BaseTransformer): transform_kwargs: BatchInferenceKwargsTypedDict = dict() if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -592,10 +592,8 @@ class {transform.original_class_name}(BaseTransformer): expected_output_cols = self._get_output_column_names(output_cols_prefix) if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -653,10 +651,8 @@ class {transform.original_class_name}(BaseTransformer): expected_output_cols = self._get_output_column_names(output_cols_prefix) if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() transform_kwargs = dict( session=dataset._session, @@ -710,17 +706,15 @@ class {transform.original_class_name}(BaseTransformer): transform_kwargs: ScoreKwargsTypedDict = dict() if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method="score", - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score") + self._deps = self._get_dependencies() selected_cols = self._get_active_columns() if 
len(selected_cols) > 0: dataset = dataset.select(selected_cols) assert isinstance(dataset._session, Session) # keep mypy happy transform_kwargs = dict( session=dataset._session, - dependencies=["snowflake-snowpark-python"] + self._deps, + dependencies=self._deps, score_sproc_imports={transform.score_sproc_imports}, ) elif isinstance(dataset, pd.DataFrame): @@ -778,11 +772,8 @@ class {transform.original_class_name}(BaseTransformer): if isinstance(dataset, DataFrame): # TODO: Solve inconsistent neigh_ind with sklearn due to different precisions in case of close distances. - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark() transform_kwargs = dict( session = dataset._session, diff --git a/codegen/transformer_autogen_test_template.py_template b/codegen/transformer_autogen_test_template.py_template index 65174ee9..a35cdbbb 100644 --- a/codegen/transformer_autogen_test_template.py_template +++ b/codegen/transformer_autogen_test_template.py_template @@ -13,7 +13,6 @@ from snowflake.ml.utils.connection_params import SnowflakeLoginOptions from snowflake.snowpark import Session, DataFrame - class {transform.test_class_name}(TestCase): def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" @@ -125,7 +124,7 @@ class {transform.test_class_name}(TestCase): sklearn_reg.fit(**args) - inference_methods = ["transform", "predict", "fit_predict"] + inference_methods = ["transform", "predict", "fit_predict", "fit_transform"] for m in inference_methods: if callable(getattr(sklearn_reg, m, None)): if m == 'predict': @@ -151,7 +150,7 @@ class {transform.test_class_name}(TestCase): # TODO(snandamuri): Implement type inference for transform and predict methods to return results with # correct datatype. - if m == 'transform': + if m == 'transform' or m == 'fit_transform': actual_arr = output_df_pandas.astype("float64").to_numpy() else: actual_arr = output_df_pandas.to_numpy() @@ -163,7 +162,12 @@ class {transform.test_class_name}(TestCase): ] actual_arr = output_df_pandas[actual_output_cols].to_numpy() - sklearn_numpy_arr = getattr(sklearn_reg, m)(input_df_pandas[input_cols]) + if m == 'fit_transform': + sklearn_numpy_arr = sklearn_reg.fit_transform(**args) + else: + sklearn_numpy_arr = getattr(sklearn_reg, m)(input_df_pandas[input_cols]) + + if len(sklearn_numpy_arr.shape) == 3: # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes) # when voting = "soft" and flatten_transform = False. 
We can't handle unflatten transforms, diff --git a/snowflake/ml/_internal/BUILD.bazel b/snowflake/ml/_internal/BUILD.bazel index 89d559ed..32de630f 100644 --- a/snowflake/ml/_internal/BUILD.bazel +++ b/snowflake/ml/_internal/BUILD.bazel @@ -41,7 +41,6 @@ py_library( srcs = ["env_utils.py"], deps = [ ":env", - "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:query_result_checker", "//snowflake/ml/_internal/utils:retryable_http", ], diff --git a/snowflake/ml/_internal/env_utils.py b/snowflake/ml/_internal/env_utils.py index a3a87a4d..50298bf3 100644 --- a/snowflake/ml/_internal/env_utils.py +++ b/snowflake/ml/_internal/env_utils.py @@ -13,10 +13,6 @@ import snowflake.connector from snowflake.ml._internal import env as snowml_env -from snowflake.ml._internal.exceptions import ( - error_codes, - exceptions as snowml_exceptions, -) from snowflake.ml._internal.utils import query_result_checker from snowflake.snowpark import context, exceptions, session from snowflake.snowpark._internal import utils as snowpark_utils @@ -237,6 +233,72 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement return new_pip_req +class IncorrectLocalEnvironmentError(Exception): + ... + + +def validate_local_installed_version_of_pip_package(pip_req: requirements.Requirement) -> None: + """Validate if the package is locally installed, and the local version meet the specifier of the requirements. + + Args: + pip_req: A requirements.Requirement object showing the requirement. + + Raises: + IncorrectLocalEnvironmentError: Raised when cannot find the local installation of the requested package. + IncorrectLocalEnvironmentError: Raised when the local installed version cannot meet the requirement. + """ + try: + local_dist = importlib_metadata.distribution(pip_req.name) + local_dist_version = version.parse(local_dist.version) + except importlib_metadata.PackageNotFoundError: + raise IncorrectLocalEnvironmentError(f"Cannot find the local installation of the requested package {pip_req}.") + + if not pip_req.specifier.contains(local_dist_version): + raise IncorrectLocalEnvironmentError( + f"The local installed version {local_dist_version} cannot meet the requirement {pip_req}." + ) + + +CONDA_PKG_NAME_TO_PYPI_MAP = {"pytorch": "torch"} + + +def try_convert_conda_requirement_to_pip(conda_req: requirements.Requirement) -> requirements.Requirement: + """Return a new requirements.Requirement object whose name has been attempted to convert to name in pypi from conda. + + Args: + conda_req: A requirements.Requirement object showing the requirement in conda. + + Returns: + A new requirements.Requirement object showing the requirement in pypi. + """ + pip_req = copy.deepcopy(conda_req) + pip_req.name = CONDA_PKG_NAME_TO_PYPI_MAP.get(conda_req.name, conda_req.name) + return pip_req + + +def validate_py_runtime_version(provided_py_version_str: str) -> None: + """Validate the provided python version string with python version in current runtime. + If the major or minor is different, errors out. + + Args: + provided_py_version_str: the provided python version string. + + Raises: + IncorrectLocalEnvironmentError: Raised when the provided python version has different major or minor. 
+ """ + if provided_py_version_str != snowml_env.PYTHON_VERSION: + provided_py_version = version.parse(provided_py_version_str) + current_py_version = version.parse(snowml_env.PYTHON_VERSION) + if ( + provided_py_version.major != current_py_version.major + or provided_py_version.minor != current_py_version.minor + ): + raise IncorrectLocalEnvironmentError( + f"Requested python version is {provided_py_version_str} " + f"while current Python version is {snowml_env.PYTHON_VERSION}. " + ) + + def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> requirements.Requirement: """Get the package spec with supported ops only including ==, >=, <=, > and < @@ -568,33 +630,6 @@ def parse_python_version_string(dep: str) -> Optional[str]: return None -def validate_py_runtime_version(provided_py_version_str: str) -> None: - """Validate the provided python version string with python version in current runtime. - If the major or minor is different, errors out. - - Args: - provided_py_version_str: the provided python version string. - - Raises: - SnowflakeMLException: Raised when the provided python version has different major or minor. - """ - if provided_py_version_str != snowml_env.PYTHON_VERSION: - provided_py_version = version.parse(provided_py_version_str) - current_py_version = version.parse(snowml_env.PYTHON_VERSION) - if ( - provided_py_version.major != current_py_version.major - or provided_py_version.minor != current_py_version.minor - ): - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.LOCAL_ENVIRONMENT_ERROR, - original_exception=RuntimeError( - f"Unable to load model which is saved with Python {provided_py_version_str} " - f"while current Python version is {snowml_env.PYTHON_VERSION}. " - "To load model metadata only, set meta_only to True." 
- ), - ) - - def _find_conda_dep_spec( conda_chan_deps: DefaultDict[str, List[requirements.Requirement]], pkg_name: str ) -> Optional[Tuple[str, requirements.Requirement]]: diff --git a/snowflake/ml/_internal/env_utils_test.py b/snowflake/ml/_internal/env_utils_test.py index 10095ea2..f69ec27b 100644 --- a/snowflake/ml/_internal/env_utils_test.py +++ b/snowflake/ml/_internal/env_utils_test.py @@ -6,6 +6,7 @@ import textwrap from importlib import metadata as importlib_metadata from typing import DefaultDict, List, cast +from unittest import mock import yaml from absl.testing import absltest @@ -967,6 +968,63 @@ def test_generate_requirements_file(self) -> None: loaded_rl = env_utils.load_requirements_file(pip_file_path) self.assertEqual(rl, loaded_rl) + def test_validate_local_installed_version_of_pip_package(self) -> None: + with mock.patch.object( + importlib_metadata, "distribution", side_effect=importlib_metadata.PackageNotFoundError() + ) as mock_distribution: + with self.assertRaisesRegex( + env_utils.IncorrectLocalEnvironmentError, "Cannot find the local installation of the requested package" + ): + env_utils.validate_local_installed_version_of_pip_package( + requirements.Requirement("my_package >=1.3,<2") + ) + mock_distribution.assert_called_once_with("my_package") + + m_distribution = mock.MagicMock() + m_distribution.version = "2.3.0" + with mock.patch.object(importlib_metadata, "distribution", return_value=m_distribution) as mock_distribution: + with self.assertRaisesRegex( + env_utils.IncorrectLocalEnvironmentError, + "The local installed version 2.3.0 cannot meet the requirement", + ): + env_utils.validate_local_installed_version_of_pip_package( + requirements.Requirement("my_package >=1.3,<2") + ) + mock_distribution.assert_called_once_with("my_package") + + m_distribution = mock.MagicMock() + m_distribution.version = "1.3.0" + with mock.patch.object(importlib_metadata, "distribution", return_value=m_distribution) as mock_distribution: + env_utils.validate_local_installed_version_of_pip_package(requirements.Requirement("my_package >=1.3,<2")) + mock_distribution.assert_called_once_with("my_package") + + def test_try_convert_conda_requirement_to_pip(self) -> None: + self.assertEqual( + env_utils.try_convert_conda_requirement_to_pip(requirements.Requirement("my_package==1.3")), + requirements.Requirement("my_package==1.3"), + ) + + self.assertEqual( + env_utils.try_convert_conda_requirement_to_pip(requirements.Requirement("pytorch==1.3")), + requirements.Requirement("torch==1.3"), + ) + + self.assertEqual( + env_utils.try_convert_conda_requirement_to_pip(requirements.Requirement("numpy==1.3")), + requirements.Requirement("numpy==1.3"), + ) + + def test_validate_py_runtime_version(self) -> None: + with mock.patch.object(snowml_env, "PYTHON_VERSION", "3.8.13"): + env_utils.validate_py_runtime_version("3.8.13") + env_utils.validate_py_runtime_version("3.8.18") + + with self.assertRaisesRegex( + env_utils.IncorrectLocalEnvironmentError, + "Requested python version is 3.10.10 while current Python version is 3.8.13", + ): + env_utils.validate_py_runtime_version("3.10.10") + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/_internal/exceptions/BUILD.bazel b/snowflake/ml/_internal/exceptions/BUILD.bazel index 4dafa771..5f740d20 100644 --- a/snowflake/ml/_internal/exceptions/BUILD.bazel +++ b/snowflake/ml/_internal/exceptions/BUILD.bazel @@ -26,7 +26,7 @@ py_library( srcs = ["dataset_errors.py"], visibility = [ "//bazel:snowml_public_common", - 
"//snowflake/ml/beta/dataset:__pkg__", + "//snowflake/ml/dataset:__pkg__", ], ) @@ -35,7 +35,7 @@ py_library( srcs = ["dataset_error_messages.py"], visibility = [ "//bazel:snowml_public_common", - "//snowflake/ml/beta/dataset:__pkg__", + "//snowflake/ml/dataset:__pkg__", ], ) diff --git a/snowflake/ml/_internal/exceptions/dataset_errors.py b/snowflake/ml/_internal/exceptions/dataset_errors.py index 33c19e3d..af2e744f 100644 --- a/snowflake/ml/_internal/exceptions/dataset_errors.py +++ b/snowflake/ml/_internal/exceptions/dataset_errors.py @@ -4,6 +4,8 @@ ERRNO_FILES_ALREADY_EXISTING = "001030" ERRNO_VERSION_ALREADY_EXISTS = "092917" ERRNO_DATASET_NOT_EXIST = "399019" +ERRNO_DATASET_VERSION_NOT_EXIST = "399012" +ERRNO_DATASET_VERSION_ALREADY_EXISTS = "399020" class DatasetError(Exception): diff --git a/snowflake/ml/_internal/exceptions/error_codes.py b/snowflake/ml/_internal/exceptions/error_codes.py index 02c276a7..a5507015 100644 --- a/snowflake/ml/_internal/exceptions/error_codes.py +++ b/snowflake/ml/_internal/exceptions/error_codes.py @@ -105,3 +105,6 @@ # Missing required client side dependency. CLIENT_DEPENDENCY_MISSING_ERROR = "2511" + +# Current client side snowpark-ml-python version is outdated and may have forward compatibility issue +SNOWML_PACKAGE_OUTDATED = "2700" diff --git a/snowflake/ml/_internal/lineage/BUILD.bazel b/snowflake/ml/_internal/lineage/BUILD.bazel new file mode 100644 index 00000000..ba1525b2 --- /dev/null +++ b/snowflake/ml/_internal/lineage/BUILD.bazel @@ -0,0 +1,11 @@ +load("//bazel:py_rules.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "dataset_dataframe", + srcs = [ + "data_source.py", + "dataset_dataframe.py", + ], +) diff --git a/snowflake/ml/_internal/lineage/data_source.py b/snowflake/ml/_internal/lineage/data_source.py new file mode 100644 index 00000000..62f15122 --- /dev/null +++ b/snowflake/ml/_internal/lineage/data_source.py @@ -0,0 +1,10 @@ +import dataclasses +from typing import List, Optional + + +@dataclasses.dataclass(frozen=True) +class DataSource: + fully_qualified_name: str + version: str + url: str + exclude_cols: Optional[List[str]] = None diff --git a/snowflake/ml/_internal/lineage/dataset_dataframe.py b/snowflake/ml/_internal/lineage/dataset_dataframe.py new file mode 100644 index 00000000..a415dbe0 --- /dev/null +++ b/snowflake/ml/_internal/lineage/dataset_dataframe.py @@ -0,0 +1,44 @@ +import copy +from typing import List + +from snowflake import snowpark +from snowflake.ml._internal.lineage import data_source + + +class DatasetDataFrame(snowpark.DataFrame): + """ + Represents a lazily-evaluated dataset. It extends :class:`snowpark.DataFrame` so all + :class:`snowpark.DataFrame` operations can be applied to it. It holds additional information + related to the :class`Dataset`. + + It will be created by dataset.read.to_snowpark_dataframe() API and by the transformations + that produce a new dataframe. + """ + + @staticmethod + def from_dataframe( + df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False + ) -> "DatasetDataFrame": + """ + Create a new DatasetDataFrame instance from a snowpark.DataFrame instance with + additional source information. + + Args: + df (snowpark.DataFrame): The Snowpark DataFrame to be converted. + data_sources (List[DataSource]): A list of data sources to associate with the DataFrame. + inplace (bool): If True, modifies the DataFrame in place; otherwise, returns a new DatasetDataFrame. 
+ + Returns: + DatasetDataFrame: A new or modified DatasetDataFrame depending on the 'inplace' argument. + """ + if not inplace: + df = copy.deepcopy(df) + df.__class__ = DatasetDataFrame + df._data_sources = data_sources # type:ignore[attr-defined] + return df # type: ignore[return-value] + + def _get_sources(self) -> List[data_source.DataSource]: + """ + Returns the data sources associated with the DataFrame. + """ + return self._data_sources # type: ignore[no-any-return] diff --git a/snowflake/ml/_internal/utils/BUILD.bazel b/snowflake/ml/_internal/utils/BUILD.bazel index 423ee9ff..2189d7fd 100644 --- a/snowflake/ml/_internal/utils/BUILD.bazel +++ b/snowflake/ml/_internal/utils/BUILD.bazel @@ -4,7 +4,9 @@ package(default_visibility = ["//visibility:public"]) py_library( name = "snowpark_dataframe_utils", - srcs = ["snowpark_dataframe_utils.py"], + srcs = [ + "snowpark_dataframe_utils.py", + ], ) py_library( diff --git a/snowflake/ml/dataset/BUILD.bazel b/snowflake/ml/dataset/BUILD.bazel index 6dc0b288..32ee3396 100644 --- a/snowflake/ml/dataset/BUILD.bazel +++ b/snowflake/ml/dataset/BUILD.bazel @@ -1,14 +1,58 @@ -load("//bazel:py_rules.bzl", "py_library") +load("//bazel:py_rules.bzl", "py_library", "py_package", "py_test") package(default_visibility = ["//visibility:public"]) +py_package( + name = "dataset_pkg", + packages = ["snowflake.ml"], + deps = [ + ":dataset", + ], +) + +py_library( + name = "dataset_reader", + srcs = [ + "dataset_reader.py", + ], + deps = [ + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/lineage:dataset_dataframe", + "//snowflake/ml/fileset:snowfs", + "//snowflake/ml/fileset:tf_dataset", + "//snowflake/ml/fileset:torch_datapipe", + ], +) + +py_library( + name = "dataset_metadata", + srcs = ["dataset_metadata.py"], +) + +py_test( + name = "dataset_metadata_test", + srcs = ["dataset_metadata_test.py"], + deps = [ + ":dataset", + "//snowflake/ml/feature_store:feature_store_lib", + ], +) + py_library( name = "dataset", srcs = [ + "__init__.py", "dataset.py", + "dataset_factory.py", ], deps = [ - "//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/registry:artifact_manager", + ":dataset_metadata", + ":dataset_reader", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/_internal/exceptions:dataset_error_messages", + "//snowflake/ml/_internal/exceptions:dataset_errors", + "//snowflake/ml/_internal/utils:import_utils", + "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", ], ) diff --git a/snowflake/ml/dataset/__init__.py b/snowflake/ml/dataset/__init__.py new file mode 100644 index 00000000..a0550d09 --- /dev/null +++ b/snowflake/ml/dataset/__init__.py @@ -0,0 +1,10 @@ +from .dataset import Dataset +from .dataset_factory import create_from_dataframe, load_dataset +from .dataset_reader import DatasetReader + +__all__ = [ + "Dataset", + "DatasetReader", + "create_from_dataframe", + "load_dataset", +] diff --git a/snowflake/ml/dataset/dataset.py b/snowflake/ml/dataset/dataset.py index b51e6e84..05ab7dc6 100644 --- a/snowflake/ml/dataset/dataset.py +++ b/snowflake/ml/dataset/dataset.py @@ -1,161 +1,486 @@ import json -import time -from dataclasses import dataclass -from typing import Any, Dict, List, Optional +import warnings +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, Union -from snowflake.ml.registry.artifact import Artifact, ArtifactType -from snowflake.snowpark import DataFrame, Session +from snowflake import snowpark +from 
snowflake.ml._internal import telemetry +from snowflake.ml._internal.exceptions import ( + dataset_error_messages, + dataset_errors, + error_codes, + exceptions as snowml_exceptions, +) +from snowflake.ml._internal.lineage import data_source +from snowflake.ml._internal.utils import ( + formatting, + identifier, + query_result_checker, + snowpark_dataframe_utils, +) +from snowflake.ml.dataset import dataset_metadata, dataset_reader +from snowflake.snowpark import exceptions as snowpark_exceptions, functions +_PROJECT = "Dataset" +_TELEMETRY_STATEMENT_PARAMS = telemetry.get_function_usage_statement_params(_PROJECT) +_METADATA_MAX_QUERY_LENGTH = 10000 +_DATASET_VERSION_NAME_COL = "version" -def _get_val_or_null(val: Any) -> Any: - return val if val is not None else "null" +class DatasetVersion: + """Represents a version of a Snowflake Dataset""" -def _wrap_embedded_str(s: str) -> str: - s = s.replace("\\", "\\\\") - s = s.replace('"', '\\"') - return s + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def __init__( + self, + dataset: "Dataset", + version: str, + ) -> None: + """Initialize a DatasetVersion object. + Args: + dataset: The parent Snowflake Dataset. + version: Dataset version name. + """ + self._parent = dataset + self._version = version + self._session: snowpark.Session = self._parent._session -DATASET_SCHEMA_VERSION = "1" + self._properties: Optional[Dict[str, Any]] = None + self._raw_metadata: Optional[Dict[str, Any]] = None + self._metadata: Optional[dataset_metadata.DatasetMetadata] = None + @property + def name(self) -> str: + return self._version -@dataclass(frozen=True) -class FeatureStoreMetadata: - """ - Feature store metadata. + @property + def created_on(self) -> datetime: + timestamp = self._get_property("created_on") + assert isinstance(timestamp, datetime) + return timestamp - Properties: - spine_query: The input query on source table which will be joined with features. - connection_params: a config contains feature store metadata. - features: A list of feature serialized object in the feature store. 
+ @property + def comment(self) -> Optional[str]: + comment: Optional[str] = self._get_property("comment") + return comment - """ + def _get_property(self, property_name: str, default: Any = None) -> Any: + if self._properties is None: + sql_result = ( + query_result_checker.SqlResultValidator( + self._session, + f"SHOW VERSIONS LIKE '{self._version}' IN DATASET {self._parent.fully_qualified_name}", + statement_params=_TELEMETRY_STATEMENT_PARAMS, + ) + .has_dimensions(expected_rows=1) + .validate() + ) + self._properties = sql_result[0].as_dict(True) + return self._properties.get(property_name, default) + + def _get_metadata(self) -> Optional[dataset_metadata.DatasetMetadata]: + if self._raw_metadata is None: + self._raw_metadata = json.loads(self._get_property("metadata", "{}")) + try: + self._metadata = ( + dataset_metadata.DatasetMetadata.from_json(self._raw_metadata) if self._raw_metadata else None + ) + except ValueError as e: + warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2) + return self._metadata - spine_query: str - connection_params: Dict[str, str] - features: List[str] + def _get_exclude_cols(self) -> List[str]: + metadata = self._get_metadata() + if metadata is None: + return [] + cols = [] + if metadata.exclude_cols: + cols.extend(metadata.exclude_cols) + if metadata.label_cols: + cols.extend(metadata.label_cols) + return cols - def to_json(self) -> str: - state_dict = { - # TODO(zhe): Additional wrap is needed because ml_.artifact.ad_artifact takes a dict - # but we retrieve it as an object. Snowpark serialization is inconsistent with - # our deserialization. A fix is let artifact table stores string and callers - # handles both serialization and deserialization. - "spine_query": self.spine_query, - "connection_params": json.dumps(self.connection_params), - "features": json.dumps(self.features), - } - return json.dumps(state_dict) + def url(self) -> str: + """Returns the URL of the DatasetVersion contents in Snowflake. + + Returns: + Snowflake URL string. 
+ """ + path = f"snow://dataset/{self._parent.fully_qualified_name}/versions/{self._version}/" + return path - @classmethod - def from_json(cls, json_str: str) -> "FeatureStoreMetadata": - json_dict = json.loads(json_str) - return cls( - spine_query=json_dict["spine_query"], - connection_params=json.loads(json_dict["connection_params"]), - features=json.loads(json_dict["features"]), + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def list_files(self, subdir: Optional[str] = None) -> List[snowpark.Row]: + """Get the list of remote file paths for the current DatasetVersion.""" + return self._session.sql(f"LIST {self.url()}{subdir or ''}").collect( + statement_params=_TELEMETRY_STATEMENT_PARAMS ) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(dataset='{self._parent.fully_qualified_name}', version='{self.name}')" -class Dataset(Artifact): - """Metadata of dataset.""" +class Dataset: + """Represents a Snowflake Dataset which is organized into versions.""" + + @telemetry.send_api_usage_telemetry(project=_PROJECT) def __init__( self, - session: Session, - df: DataFrame, - generation_timestamp: Optional[float] = None, - materialized_table: Optional[str] = None, - snapshot_table: Optional[str] = None, - timestamp_col: Optional[str] = None, - label_cols: Optional[List[str]] = None, - feature_store_metadata: Optional[FeatureStoreMetadata] = None, - desc: str = "", + session: snowpark.Session, + database: str, + schema: str, + name: str, + selected_version: Optional[str] = None, ) -> None: - """Initialize dataset object. + """Initialize a lazily evaluated Dataset object""" + self._session = session + self._db = database + self._schema = schema + self._name = name + self._fully_qualified_name = identifier.get_schema_level_object_identifier(database, schema, name) + + self._version = DatasetVersion(self, selected_version) if selected_version else None + self._reader: Optional[dataset_reader.DatasetReader] = None + + @property + def fully_qualified_name(self) -> str: + return self._fully_qualified_name + + @property + def selected_version(self) -> Optional[DatasetVersion]: + return self._version + + @property + def read(self) -> dataset_reader.DatasetReader: + if not self.selected_version: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ATTRIBUTE, + original_exception=RuntimeError("No Dataset version selected."), + ) + if self._reader is None: + v = self.selected_version + self._reader = dataset_reader.DatasetReader( + self._session, + [ + data_source.DataSource( + fully_qualified_name=self._fully_qualified_name, + version=v.name, + url=v.url(), + exclude_cols=v._get_exclude_cols(), + ) + ], + ) + return self._reader + + @staticmethod + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def load(session: snowpark.Session, name: str) -> "Dataset": + """ + Load an existing Snowflake Dataset. DatasetVersions can be created from the Dataset object + using `Dataset.create_version()` and loaded with `Dataset.version()`. Args: - session: An active snowpark session. - df: A dataframe object representing the dataset generation. - generation_timestamp: The timestamp when this dataset is generated. It will use current time if - not provided. - materialized_table: The destination table name which data will writes into. - snapshot_table: A snapshot table name on the materialized table. - timestamp_col: Timestamp column which was used for point-in-time correct feature lookup. - label_cols: Name of column(s) in materialized_table that contains labels. 
- feature_store_metadata: A feature store metadata object. - desc: A description about this dataset. + session: Snowpark Session to interact with Snowflake backend. + name: Name of dataset to load. May optionally be a schema-level identifier. + + Returns: + Dataset object representing loaded dataset + + Raises: + ValueError: name is not a valid Snowflake identifier + DatasetNotExistError: Specified Dataset does not exist + + # noqa: DAR402 """ - self.df = df - self.generation_timestamp = generation_timestamp if generation_timestamp is not None else time.time() - self.materialized_table = materialized_table - self.snapshot_table = snapshot_table - self.timestamp_col = timestamp_col - self.label_cols = label_cols - self.feature_store_metadata = feature_store_metadata - self.desc = desc - self.owner = session.sql("SELECT CURRENT_USER()").collect()[0]["CURRENT_USER()"] - self.schema_version = DATASET_SCHEMA_VERSION - - super().__init__(type=ArtifactType.DATASET, spec=self.to_json()) - - def load_features(self) -> Optional[List[str]]: - if self.feature_store_metadata is not None: - return self.feature_store_metadata.features - else: - return None - - def features_df(self) -> DataFrame: - result = self.df - if self.timestamp_col is not None: - result = result.drop(self.timestamp_col) - if self.label_cols is not None: - result = result.drop(self.label_cols) - return result - - def to_json(self) -> str: - if len(self.df.queries["queries"]) != 1: - raise ValueError( - f"""df dataframe must contain only 1 query. -Got {len(self.df.queries['queries'])}: {self.df.queries['queries']} -""" + db, schema, ds_name = _get_schema_level_identifier(session, name) + _validate_dataset_exists(session, db, schema, ds_name) + return Dataset(session, db, schema, ds_name) + + @staticmethod + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def create(session: snowpark.Session, name: str, exist_ok: bool = False) -> "Dataset": + """ + Create a new Snowflake Dataset. DatasetVersions can created from the Dataset object + using `Dataset.create_version()` and loaded with `Dataset.version()`. + + Args: + session: Snowpark Session to interact with Snowflake backend. + name: Name of dataset to create. May optionally be a schema-level identifier. + exist_ok: If False, raises an exception if specified Dataset already exists + + Returns: + Dataset object representing created dataset + + Raises: + ValueError: name is not a valid Snowflake identifier + DatasetExistError: Specified Dataset already exists + DatasetError: Dataset creation failed + + # noqa: DAR401 + # noqa: DAR402 + """ + db, schema, ds_name = _get_schema_level_identifier(session, name) + ds_fqn = identifier.get_schema_level_object_identifier(db, schema, ds_name) + query = f"CREATE DATASET{' IF NOT EXISTS' if exist_ok else ''} {ds_fqn}" + try: + session.sql(query).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS) + return Dataset(session, db, schema, ds_name) + except snowpark_exceptions.SnowparkClientException as e: + # Snowpark wraps the Python Connector error code in the head of the error message. 
+ if e.message.startswith(dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.OBJECT_ALREADY_EXISTS, + original_exception=dataset_errors.DatasetExistError( + dataset_error_messages.DATASET_ALREADY_EXISTS.format(name) + ), + ) from e + else: + raise + + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def list_versions(self, detailed: bool = False) -> Union[List[str], List[snowpark.Row]]: + """Return list of versions""" + versions = self._list_versions() + versions.sort(key=lambda r: r[_DATASET_VERSION_NAME_COL]) + if not detailed: + return [r[_DATASET_VERSION_NAME_COL] for r in versions] + return versions + + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def select_version(self, version: str) -> "Dataset": + """Return a new Dataset instance with the specified version selected. + + Args: + version: Dataset version name. + + Returns: + Dataset object. + """ + self._validate_version_exists(version) + return Dataset(self._session, self._db, self._schema, self._name, version) + + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def create_version( + self, + version: str, + input_dataframe: snowpark.DataFrame, + shuffle: bool = False, + exclude_cols: Optional[List[str]] = None, + label_cols: Optional[List[str]] = None, + properties: Optional[dataset_metadata.DatasetPropertiesType] = None, + partition_by: Optional[str] = None, + comment: Optional[str] = None, + ) -> "Dataset": + """Create a new version of the current Dataset. + + The result Dataset object captures the query result deterministically as stage files. + + Args: + version: Dataset version name. Data contents are materialized to the Dataset entity. + input_dataframe: A Snowpark DataFrame which yields the Dataset contents. + shuffle: A boolean represents whether the data should be shuffled globally. Default to be false. + exclude_cols: Name of column(s) in dataset to be excluded during training/testing (e.g. timestamp). + label_cols: Name of column(s) in dataset that contains labels. + properties: Custom metadata properties, saved under `DatasetMetadata.properties` + partition_by: Optional partitioning scheme within the new Dataset version. + comment: A descriptive comment about this dataset. + + Returns: + A Dataset object with the newly created version selected. + + Raises: + SnowflakeMLException: The Dataset no longer exists. + SnowflakeMLException: The specified Dataset version already exists. + snowpark_exceptions.SnowparkClientException: An error occurred during Dataset creation. + + Note: During the generation of stage files, data casting will occur. The casting rules are as follows:: + - Data casting: + - DecimalType(NUMBER): + - If its scale is zero, cast to BIGINT + - If its scale is non-zero, cast to FLOAT + - DoubleType(DOUBLE): Cast to FLOAT. + - ByteType(TINYINT): Cast to SMALLINT. + - ShortType(SMALLINT):Cast to SMALLINT. + - IntegerType(INT): Cast to INT. + - LongType(BIGINT): Cast to BIGINT. + - No action: + - FloatType(FLOAT): No action. + - StringType(String): No action. + - BinaryType(BINARY): No action. + - BooleanType(BOOLEAN): No action. + - Not supported: + - ArrayType(ARRAY): Not supported. A warning will be logged. + - MapType(OBJECT): Not supported. A warning will be logged. + - TimestampType(TIMESTAMP): Not supported. A warning will be logged. + - TimeType(TIME): Not supported. A warning will be logged. + - DateType(DATE): Not supported. A warning will be logged. + - VariantType(VARIANT): Not supported. 
A warning will be logged. + """ + casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe) + + if shuffle: + casted_df = casted_df.order_by(functions.random()) + + source_query = json.dumps(input_dataframe.queries) + if len(source_query) > _METADATA_MAX_QUERY_LENGTH: + warnings.warn( + "Source query exceeded max query length, dropping from metadata (limit=%d, actual=%d)" + % (_METADATA_MAX_QUERY_LENGTH, len(source_query)), + stacklevel=2, ) + source_query = "" - state_dict = { - "df_query": _wrap_embedded_str(self.df.queries["queries"][0]), - "generation_timestamp": self.generation_timestamp, - "owner": self.owner, - "materialized_table": _wrap_embedded_str(_get_val_or_null(self.materialized_table)), - "snapshot_table": _wrap_embedded_str(_get_val_or_null(self.snapshot_table)), - "timestamp_col": _wrap_embedded_str(_get_val_or_null(self.timestamp_col)), - "label_cols": _get_val_or_null(self.label_cols), - "feature_store_metadata": _wrap_embedded_str(self.feature_store_metadata.to_json()) - if self.feature_store_metadata is not None - else "null", - "schema_version": self.schema_version, - "desc": self.desc, - } - return json.dumps(state_dict) - - @classmethod - def from_json(cls, json_str: str, session: Session) -> "Dataset": - json_dict = json.loads(json_str, strict=False) - json_dict["df"] = session.sql(json_dict.pop("df_query")) - - fs_meta_json = json_dict["feature_store_metadata"] - json_dict["feature_store_metadata"] = ( - FeatureStoreMetadata.from_json(fs_meta_json) if fs_meta_json != "null" else None + metadata = dataset_metadata.DatasetMetadata( + source_query=source_query, + owner=self._session.sql("SELECT CURRENT_USER()").collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)[0][ + "CURRENT_USER()" + ], + exclude_cols=exclude_cols, + label_cols=label_cols, + properties=properties, ) - schema_version = json_dict.pop("schema_version") - owner = json_dict.pop("owner") + post_actions = casted_df._plan.post_actions + try: + # Execute all but the last query, final query gets passed to ALTER DATASET ADD VERSION + query = casted_df._plan.queries[-1].sql.strip() + if len(casted_df._plan.queries) > 1: + casted_df._plan.queries = casted_df._plan.queries[:-1] + casted_df._plan.post_actions = [] + casted_df.collect(statement_params=_TELEMETRY_STATEMENT_PARAMS) + sql_command = "ALTER DATASET {} ADD VERSION '{}' FROM ({})".format( + self.fully_qualified_name, + version, + query, + ) + if partition_by: + sql_command += f" PARTITION BY {partition_by}" + if comment: + sql_command += f" COMMENT={formatting.format_value_for_select(comment)}" + sql_command += f" METADATA=$${metadata.to_json()}$$" + self._session.sql(sql_command).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS) + + return Dataset(self._session, self._db, self._schema, self._name, version) - result = cls(session, **json_dict) - result.schema_version = schema_version - result.owner = owner + except snowpark_exceptions.SnowparkClientException as e: + if e.message.startswith(dataset_errors.ERRNO_DATASET_NOT_EXIST): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=dataset_errors.DatasetNotExistError( + dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name) + ), + ) from e + elif ( + e.message.startswith(dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS) + or e.message.startswith(dataset_errors.ERRNO_VERSION_ALREADY_EXISTS) + or e.message.startswith(dataset_errors.ERRNO_FILES_ALREADY_EXISTING) + ): + raise 
snowml_exceptions.SnowflakeMLException( + error_code=error_codes.OBJECT_ALREADY_EXISTS, + original_exception=dataset_errors.DatasetExistError( + dataset_error_messages.DATASET_VERSION_ALREADY_EXISTS.format(self.fully_qualified_name, version) + ), + ) from e + else: + raise + finally: + for action in post_actions: + self._session.sql(action.sql.strip()).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS) - return result + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def delete_version(self, version_name: str) -> None: + """Delete the Dataset version - def __eq__(self, other: object) -> bool: - return isinstance(other, Dataset) and self.to_json() == other.to_json() + Args: + version_name: Name of version to delete from Dataset + + Raises: + SnowflakeMLException: An error occurred when the DatasetVersion cannot get deleted. + """ + delete_sql = f"ALTER DATASET {self.fully_qualified_name} DROP VERSION '{version_name}'" + try: + self._session.sql(delete_sql).collect( + statement_params=_TELEMETRY_STATEMENT_PARAMS, + ) + except snowpark_exceptions.SnowparkClientException as e: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.SNOWML_DELETE_FAILED, + original_exception=dataset_errors.DatasetCannotDeleteError(str(e)), + ) from e + return + + @telemetry.send_api_usage_telemetry(project=_PROJECT) + def delete(self) -> None: + """Delete Dataset and all contained versions""" + # TODO: Check and warn if any versions exist + self._session.sql(f"DROP DATASET {self.fully_qualified_name}").collect( + statement_params=_TELEMETRY_STATEMENT_PARAMS + ) + + def _list_versions(self, pattern: Optional[str] = None) -> List[snowpark.Row]: + """Return list of versions""" + try: + pattern_clause = f" LIKE '{pattern}'" if pattern else "" + return ( + query_result_checker.SqlResultValidator( + self._session, + f"SHOW VERSIONS{pattern_clause} IN DATASET {self.fully_qualified_name}", + statement_params=_TELEMETRY_STATEMENT_PARAMS, + ) + .has_column(_DATASET_VERSION_NAME_COL, allow_empty=True) + .validate() + ) + except snowpark_exceptions.SnowparkClientException as e: + # Snowpark wraps the Python Connector error code in the head of the error message. + if e.message.startswith(dataset_errors.ERRNO_OBJECT_NOT_EXIST): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=dataset_errors.DatasetNotExistError( + dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name) + ), + ) from e + else: + raise + + def _validate_version_exists(self, version: str) -> None: + """Verify that the requested version exists. 
Raises DatasetNotExist if version not found""" + matches = self._list_versions(version) + matches = [m for m in matches if m[_DATASET_VERSION_NAME_COL] == version] # Case sensitive match + if len(matches) == 0: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=dataset_errors.DatasetNotExistError( + dataset_error_messages.DATASET_VERSION_NOT_EXIST.format(self.fully_qualified_name, version) + ), + ) + + +# Utility methods + + +def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> Tuple[str, str, str]: + """Resolve a dataset name into a validated schema-level location identifier""" + db, schema, object_name, others = identifier.parse_schema_level_object_identifier(dataset_name) + if others: + raise ValueError(f"Invalid identifier: unexpected '{others}'") + db = db or session.get_current_database() + schema = schema or session.get_current_schema() + return str(db), str(schema), str(object_name) + + +def _validate_dataset_exists(session: snowpark.Session, db: str, schema: str, dataset_name: str) -> None: + # FIXME: Once we switch version to SQL Identifiers we can just use version check with version='' + dataset_name = identifier.resolve_identifier(dataset_name) + if len(dataset_name) > 0 and dataset_name[0] == '"' and dataset_name[-1] == '"': + dataset_name = identifier.get_unescaped_names(dataset_name) + # Case sensitive match + query = f"show datasets like '{dataset_name}' in schema {db}.{schema} starts with '{dataset_name}'" + ds_matches = session.sql(query).count() + if ds_matches == 0: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=dataset_errors.DatasetNotExistError( + dataset_error_messages.DATASET_NOT_EXIST.format(dataset_name) + ), + ) diff --git a/snowflake/ml/dataset/dataset_factory.py b/snowflake/ml/dataset/dataset_factory.py new file mode 100644 index 00000000..d7b0ecc7 --- /dev/null +++ b/snowflake/ml/dataset/dataset_factory.py @@ -0,0 +1,53 @@ +from typing import Any + +from snowflake import snowpark +from snowflake.ml._internal import telemetry +from snowflake.ml.dataset import dataset + +_PROJECT = "Dataset" + + +@telemetry.send_api_usage_telemetry(project=_PROJECT) +def create_from_dataframe( + session: snowpark.Session, + name: str, + version: str, + input_dataframe: snowpark.DataFrame, + **version_kwargs: Any, +) -> dataset.Dataset: + """ + Create a new versioned Dataset from a DataFrame and returns + a DatasetReader for the newly created Dataset version. + + Args: + session: The Snowpark Session instance to use. + name: The dataset name + version: The dataset version name + input_dataframe: DataFrame containing data to be saved to the created Dataset. + version_kwargs: Keyword arguments passed to dataset version creation. + See `Dataset.create_version()` documentation for supported arguments. + + Returns: + A Dataset object. + """ + ds: dataset.Dataset = dataset.Dataset.create(session, name, exist_ok=True) + ds.create_version(version, input_dataframe=input_dataframe, **version_kwargs) + ds = ds.select_version(version) # select_version returns a new copy + return ds + + +@telemetry.send_api_usage_telemetry(project=_PROJECT) +def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset: + """ + Load a versioned Dataset into a DatasetReader. + + Args: + session: The Snowpark Session instance to use. + name: The dataset name. + version: The dataset version name. + + Returns: + A DatasetReader object. 
+ """ + ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version) + return ds diff --git a/snowflake/ml/dataset/dataset_metadata.py b/snowflake/ml/dataset/dataset_metadata.py new file mode 100644 index 00000000..bcf1eea9 --- /dev/null +++ b/snowflake/ml/dataset/dataset_metadata.py @@ -0,0 +1,103 @@ +import dataclasses +import json +import typing +from typing import Any, Dict, List, Optional, Union + +_PROPERTY_TYPE_KEY = "$proptype$" +DATASET_SCHEMA_VERSION = "1" + + +@dataclasses.dataclass(frozen=True) +class FeatureStoreMetadata: + """ + Feature store metadata. + + Properties: + spine_query: The input query on source table which will be joined with features. + serialized_feature_views: A list of serialized feature objects in the feature store. + spine_timestamp_col: Timestamp column which was used for point-in-time correct feature lookup. + """ + + spine_query: str + serialized_feature_views: List[str] + spine_timestamp_col: Optional[str] = None + + def to_json(self) -> str: + return json.dumps(dataclasses.asdict(self)) + + @classmethod + def from_json(cls, input_json: Union[Dict[str, Any], str, bytes]) -> "FeatureStoreMetadata": + if isinstance(input_json, dict): + return cls(**input_json) + return cls(**json.loads(input_json)) + + +DatasetPropertiesType = Union[ + FeatureStoreMetadata, +] + +# Union[T] gets automatically squashed to T, so default to [T] if get_args() returns empty +_DatasetPropTypes = typing.get_args(DatasetPropertiesType) or [DatasetPropertiesType] +_DatasetPropTypeDict = {t.__name__: t for t in _DatasetPropTypes} + + +@dataclasses.dataclass(frozen=True) +class DatasetMetadata: + """ + Dataset metadata. + + Properties: + source_query: The query string used to produce the Dataset. + owner: The owner of the Dataset. + generation_timestamp: The timestamp when this dataset was generated. + exclude_cols: Name of column(s) in dataset to be excluded during training/testing. + These are typically columns for human inspection such as timestamp or other meta-information. + Columns included in `label_cols` do not need to be included here. + label_cols: Name of column(s) in dataset that contains labels. + properties: Additional metadata properties. 
+ """ + + source_query: str + owner: str + exclude_cols: Optional[List[str]] = None + label_cols: Optional[List[str]] = None + properties: Optional[DatasetPropertiesType] = None + schema_version: str = dataclasses.field(default=DATASET_SCHEMA_VERSION, init=False) + + def to_json(self) -> str: + state_dict = dataclasses.asdict(self) + if self.properties: + prop_type = type(self.properties).__name__ + if prop_type not in _DatasetPropTypeDict: + raise ValueError( + f"Unsupported `properties` type={prop_type} (supported={','.join(_DatasetPropTypeDict.keys())})" + ) + state_dict[_PROPERTY_TYPE_KEY] = prop_type + return json.dumps(state_dict) + + @classmethod + def from_json(cls, input_json: Union[Dict[str, Any], str, bytes]) -> "DatasetMetadata": + if not input_json: + raise ValueError("json_str was empty or None") + try: + state_dict: Dict[str, Any] = ( + input_json if isinstance(input_json, dict) else json.loads(input_json, strict=False) + ) + + # TODO: Validate schema version + _ = state_dict.pop("schema_version", DATASET_SCHEMA_VERSION) + + prop_type = state_dict.pop(_PROPERTY_TYPE_KEY, None) + prop_values = state_dict.get("properties", {}) + if prop_type: + prop_cls = _DatasetPropTypeDict.get(prop_type, None) + if prop_cls is None: + raise TypeError( + f"Unsupported `properties` type={prop_type} (supported={','.join(_DatasetPropTypeDict.keys())})" + ) + state_dict["properties"] = prop_cls(**prop_values) + elif prop_values: + raise TypeError(f"`properties` provided but missing `{_PROPERTY_TYPE_KEY}`") + return cls(**state_dict) + except TypeError as e: + raise ValueError("Invalid input schema") from e diff --git a/snowflake/ml/dataset/dataset_metadata_test.py b/snowflake/ml/dataset/dataset_metadata_test.py new file mode 100644 index 00000000..85e19e30 --- /dev/null +++ b/snowflake/ml/dataset/dataset_metadata_test.py @@ -0,0 +1,169 @@ +import json +from typing import Any, Dict, List, Optional + +import dataset_metadata +from absl.testing import absltest, parameterized + +from snowflake.ml.feature_store import entity, feature_view +from snowflake.snowpark import dataframe + + +class MockDataFrame(dataframe.DataFrame): + def __init__(self, query: str, columns: List[str]) -> None: + self._query = query + self._columns = [c.upper() for c in columns] + + @property + def queries(self) -> Dict[str, List[str]]: + return {"queries": [self._query]} + + @property + def columns(self) -> List[str]: + return self._columns + + +def _create_feature_view(name: str, columns: List[str]) -> feature_view.FeatureView: + df = MockDataFrame("test query", columns) + return feature_view.FeatureView( + name, + [entity.Entity("e", columns[:1])], + feature_df=df, + ) + + +def _create_metadata(props: Any) -> dataset_metadata.DatasetMetadata: + return dataset_metadata.DatasetMetadata( + source_query="test", + owner="test", + properties=props, + ) + + +class DatasetMetadataTest(parameterized.TestCase): + @parameterized.parameters( # type: ignore[misc] + {"input_props": None}, + {"input_props": dataset_metadata.FeatureStoreMetadata("query", {"conn": "test"}, ["feat1", "feat2"])}, + {"input_props": dataset_metadata.FeatureStoreMetadata("query", ["feat1", "feat2"], spine_timestamp_col="ts")}, + { + "input_props": dataset_metadata.FeatureStoreMetadata( + "query", + [ + _create_feature_view("fv1", ["col1", "col2"]).to_json(), + _create_feature_view("fv2", ["col1", "col3", "col4"]).slice(["col4"]).to_json(), + ], + spine_timestamp_col="ts", + ) + }, + ) + def test_json_convert(self, input_props: Any) -> None: + expected = 
input_props + + source = _create_metadata(input_props) + serialized = source.to_json() + actual = dataset_metadata.DatasetMetadata.from_json(serialized).properties + + self.assertEqual(expected, actual) + + @parameterized.parameters( # type: ignore[misc] + {"input_props": {"custom": "value"}}, + ) + def test_json_convert_negative(self, input_props: Any) -> None: + source = _create_metadata(input_props) + with self.assertRaises(ValueError): + source.to_json() + + @parameterized.parameters( # type: ignore[misc] + '{"source_query": "test_source", "owner": "test"}', + '{"source_query": "test_source", "owner": "test", "exclude_cols": ["col1"]}', + '{"source_query": "test_source", "owner": "test", "label_cols": ["col1"]}', + '{"source_query": "test_source", "owner": "test", "properties": null}', + '{"source_query": "test_source", "owner": "test", "properties": { } }', + '{"source_query": "test_source", "owner": "test", "properties": null}', + """ + { + "source_query": "test_source", + "owner": "test", + "$proptype$": "FeatureStoreMetadata", + "properties": { + "spine_query": "test query", + "serialized_feature_views": [] + } + } + """, + """ + { + "source_query": "test_source", + "owner": "test", + "$proptype$": "FeatureStoreMetadata", + "properties": { + "spine_query": "test query", + "serialized_feature_views": [ + "{\\\"_name\\\": \\\"fv1\\\"}" + ], + "spine_timestamp_col": "ts_col" + } + } + """, + ) + def test_deserialize(self, json_str: str) -> None: + actual = dataset_metadata.DatasetMetadata.from_json(json_str) + self.assertIsNotNone(actual) + + json_dict = json.loads(json_str) + actual2 = dataset_metadata.DatasetMetadata.from_json(json_dict) + self.assertIsNotNone(actual2) + + @parameterized.parameters( # type: ignore[misc] + None, + "", + "{}", + '{"unrelated": "value"}', + '{"source_query": "test_source", "owner": "test", "properties": {"prop1": "val1"} }', + """ + { + "source_query": "test_source", + "owner": "test", + "$proptype$": "unrecognized", + "properties": {"prop1": "val1"} + } + """, + """ + { + "source_query": "test_source", + "owner": "test", + "$proptype$": "dict", + "properties": {"prop1": "val1"} + } + """, + """ + { + "source_query": "test_source", + "owner": "test", + "$proptype$": "FeatureStoreMetadata", + "properties": {"prop1": "val1"} + } + """, + """ + { + "source_query": "test_source", + "owner": "test", + "$proptype$": "FeatureStoreMetadata" + } + """, + # FIXME: These test cases currently fail due to lack of type enforcement + # '{"source_query": "test_source", "owner": "test"}', + # '{"source_query": "test_source", "owner": "test", "exclude_cols": "col1"}', + # '{"source_query": "test_source", "owner": "test", "label_cols": "col1"}', + # '{"source_query": "test_source", "owner": "test", "properties": "value"}', + ) + def test_deserialize_negative(self, json_str: Optional[str]) -> None: + with self.assertRaises(ValueError): + dataset_metadata.DatasetMetadata.from_json(json_str) + + json_dict = json.loads(json_str) if json_str else None + with self.assertRaises(ValueError): + dataset_metadata.DatasetMetadata.from_json(json_dict) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/dataset/dataset_reader.py b/snowflake/ml/dataset/dataset_reader.py new file mode 100644 index 00000000..eb69c57f --- /dev/null +++ b/snowflake/ml/dataset/dataset_reader.py @@ -0,0 +1,202 @@ +from typing import Any, List + +import pandas as pd + +from snowflake import snowpark +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.lineage 
import data_source, dataset_dataframe +from snowflake.ml._internal.utils import import_utils +from snowflake.ml.fileset import snowfs + +_PROJECT = "Dataset" +_SUBPROJECT = "DatasetReader" +TARGET_FILE_SIZE = 32 * 2**20 # The max file size for data loading. + + +class DatasetReader: + """Snowflake Dataset abstraction which provides application integration connectors""" + + @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT) + def __init__( + self, + session: snowpark.Session, + sources: List[data_source.DataSource], + ) -> None: + """Initialize a DatasetVersion object. + + Args: + session: Snowpark Session to interact with Snowflake backend. + sources: Data sources to read from. + + Raises: + ValueError: `sources` arg was empty or null + """ + if not sources: + raise ValueError("Invalid input: empty `sources` list not allowed") + self._session = session + self._sources = sources + self._fs: snowfs.SnowFileSystem = snowfs.SnowFileSystem( + snowpark_session=self._session, + cache_type="bytes", + block_size=2 * TARGET_FILE_SIZE, + ) + + self._files: List[str] = [] + + def _list_files(self) -> List[str]: + """Private helper function that lists all files in this DatasetVersion and caches the results.""" + if self._files: + return self._files + + files: List[str] = [] + for source in self._sources: + # Sort within each source for consistent ordering + files.extend(sorted(self._fs.ls(source.url))) # type: ignore[arg-type] + files.sort() + + self._files = files + return self._files + + @property + def data_sources(self) -> List[data_source.DataSource]: + return self._sources + + @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT) + def files(self) -> List[str]: + """Get the list of remote file paths for the current DatasetVersion. + + The file paths follows the snow protocol. + + Returns: + A list of remote file paths + + Example: + >>> dsv.files() + ---- + ["snow://dataset/mydb.myschema.mydataset/versions/test/data_0_0_0.snappy.parquet", + "snow://dataset/mydb.myschema.mydataset/versions/test/data_0_0_1.snappy.parquet"] + """ + files = self._list_files() + return [self._fs.unstrip_protocol(f) for f in files] + + @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT) + def filesystem(self) -> snowfs.SnowFileSystem: + """Return an fsspec FileSystem which can be used to load the DatasetVersion's `files()`""" + return self._fs + + @telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + func_params_to_log=["batch_size", "shuffle", "drop_last_batch"], + ) + def to_torch_datapipe(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any: + """Transform the Snowflake data into a ready-to-use Pytorch datapipe. + + Return a Pytorch datapipe which iterates on rows of data. + + Args: + batch_size: It specifies the size of each data batch which will be + yield in the result datapipe + shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and + rows in each file will also be shuffled. + drop_last_batch: Whether the last batch of data should be dropped. If set to be true, + then the last batch will get dropped if its size is smaller than the given batch_size. + + Returns: + A Pytorch iterable datapipe that yield data. 
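`files()` and `filesystem()` together give fsspec-style access to the version's underlying Parquet files. A small sketch, assuming `ds` is a Dataset with a selected version and its reader is obtained via `ds.read`:

```python
reader = ds.read                 # DatasetReader for the selected version

print(reader.files())            # ['snow://dataset/.../data_0_0_0.snappy.parquet', ...]

fs = reader.filesystem()         # fsspec-compatible SnowFileSystem
first = reader.files()[0]
with fs.open(first) as fp:       # fsspec filesystems accept their own protocol URLs
    header = fp.read(4)          # Parquet files start with the magic bytes b"PAR1"
print(header)
```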
+ + Examples: + >>> dp = dataset.to_torch_datapipe(batch_size=1) + >>> for data in dp: + >>> print(data) + ---- + {'_COL_1':[10]} + """ + IterableWrapper, _ = import_utils.import_or_get_dummy("torchdata.datapipes.iter.IterableWrapper") + torch_datapipe_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.torch_datapipe") + + self._fs.optimize_read(self._list_files()) + + input_dp = IterableWrapper(self._list_files()) + return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch) + + @telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + func_params_to_log=["batch_size", "shuffle", "drop_last_batch"], + ) + def to_tf_dataset(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any: + """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset. + + Args: + batch_size: It specifies the size of each data batch which will be + yield in the result datapipe + shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and + rows in each file will also be shuffled. + drop_last_batch: Whether the last batch of data should be dropped. If set to be true, + then the last batch will get dropped if its size is smaller than the given batch_size. + + Returns: + A tf.data.Dataset that yields batched tf.Tensors. + + Examples: + >>> dp = dataset.to_tf_dataset(batch_size=1) + >>> for data in dp: + >>> print(data) + ---- + {'_COL_1': } + """ + tf_dataset_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.tf_dataset") + + self._fs.optimize_read(self._list_files()) + + return tf_dataset_module.read_and_parse_parquet( + self._list_files(), self._fs, batch_size, shuffle, drop_last_batch + ) + + @telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + func_params_to_log=["only_feature_cols"], + ) + def to_snowpark_dataframe(self, only_feature_cols: bool = False) -> snowpark.DataFrame: + """Convert the DatasetVersion to a Snowpark DataFrame. + + Args: + only_feature_cols: If True, drops exclude_cols and label_cols from returned DataFrame. + The original DatasetVersion is unaffected. + + Returns: + A Snowpark dataframe that contains the data of this DatasetVersion. + + Note: The dataframe generated by this method might not have the same schema as the original one. Specifically, + - NUMBER type with scale != 0 will become float. + - Unsupported types (see comments of :func:`Dataset.create_version`) will not have any guarantee. + For example, an OBJECT column may be scanned back as a STRING column. 
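A hedged sketch of the framework connectors above, assuming `ds` is a Dataset with a selected version and that `torchdata` / `tensorflow` are installed; column names in the comments are illustrative:

```python
# PyTorch: iterate dict batches of numpy arrays from the datapipe.
dp = ds.read.to_torch_datapipe(batch_size=32, shuffle=True, drop_last_batch=True)
for batch in dp:
    # e.g. {"FEATURE_1": array([...]), "LABEL": array([...])}
    print({k: v.shape for k, v in batch.items()})
    break

# TensorFlow: the same data as a tf.data.Dataset of batched tensors.
tf_ds = ds.read.to_tf_dataset(batch_size=32)
for batch in tf_ds.take(1):
    print({k: v.shape for k, v in batch.items()})
```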
+ """ + file_path_pattern = ".*data_.*[.]parquet" + dfs: List[snowpark.DataFrame] = [] + for source in self._sources: + df = self._session.read.option("pattern", file_path_pattern).parquet(source.url) + if only_feature_cols and source.exclude_cols: + df = df.drop(source.exclude_cols) + dfs.append(df) + + combined_df = dfs[0] + for df in dfs[1:]: + combined_df = combined_df.union_all_by_name(df) + return dataset_dataframe.DatasetDataFrame.from_dataframe(combined_df, data_sources=self._sources, inplace=True) + + @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT) + def to_pandas(self) -> pd.DataFrame: + """Retrieve the DatasetVersion contents as a Pandas Dataframe""" + files = self._list_files() + if not files: + return pd.DataFrame() # Return empty DataFrame + self._fs.optimize_read(files) + pd_dfs = [] + for file in files: + with self._fs.open(file) as fp: + pd_dfs.append(pd.read_parquet(fp)) + pd_df = pd_dfs[0] if len(pd_dfs) == 1 else pd.concat(pd_dfs, ignore_index=True, copy=False) + return pd_df diff --git a/snowflake/ml/dataset/notebooks/internal_demo/Dataset_Basic_Demo.ipynb b/snowflake/ml/dataset/notebooks/internal_demo/Dataset_Basic_Demo.ipynb new file mode 100644 index 00000000..89e60e5f --- /dev/null +++ b/snowflake/ml/dataset/notebooks/internal_demo/Dataset_Basic_Demo.ipynb @@ -0,0 +1,444 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e645315e-9a73-4cb0-b72e-a1ecb32abf1d", + "metadata": {}, + "source": [ + "# Setup Environment" + ] + }, + { + "cell_type": "markdown", + "id": "f5652801-1259-439e-8b70-df7d1995916b", + "metadata": {}, + "source": [ + "## Import Dependencies and Create Session" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79bb6a13-5b93-4eff-87c6-7e65cc8398ed", + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.snowpark import Session, functions as F\n", + "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", + "from snowflake.ml import dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5c797d0-f2cd-4b17-a3ac-8445e0f83ddf", + "metadata": {}, + "outputs": [], + "source": [ + "session = Session.builder.configs(SnowflakeLoginOptions()).create()\n", + "print(session)\n", + "\n", + "TEST_DATASET_DB = \"DATASET_DEMO_DB\"\n", + "TEST_DATASET_SCHEMA = \"DATASET_DEMO_SCHEMA\"\n", + "session.sql(f\"CREATE DATABASE IF NOT EXISTS {TEST_DATASET_DB}\").collect()\n", + "session.sql(f\"\"\"\n", + " CREATE SCHEMA IF NOT EXISTS \n", + " {TEST_DATASET_DB}.{TEST_DATASET_SCHEMA}\"\"\").collect()\n", + "session.use_database(TEST_DATASET_DB)\n", + "session.use_schema(TEST_DATASET_SCHEMA)" + ] + }, + { + "cell_type": "markdown", + "id": "dc7cdc84-5f2f-491d-97c6-9a0d22f294bc", + "metadata": {}, + "source": [ + "# Prepare test data\n", + "\n", + "We will use the [diamond price dataset](https://ggplot2.tidyverse.org/reference/diamonds.html) for this demo. 
The data can be downloaded from https://raw.githubusercontent.com/tidyverse/ggplot2/main/data-raw/diamonds.csv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "144abd6b-b56e-481f-aa07-0806c1ec32ab", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from sklearn.preprocessing import StandardScaler, LabelEncoder\n", + "\n", + "data_url = \"https://raw.githubusercontent.com/tidyverse/ggplot2/main/data-raw/diamonds.csv\"\n", + "data_pd = pd.read_csv(data_url)\n", + "\n", + "# Encode categorical variables: cut, color, clarity\n", + "label_encoder = LabelEncoder()\n", + "data_pd['cut'] = label_encoder.fit_transform(data_pd['cut'])\n", + "data_pd['color'] = label_encoder.fit_transform(data_pd['color'])\n", + "data_pd['clarity'] = label_encoder.fit_transform(data_pd['clarity'])\n", + "\n", + "# Scale numerical features: carat, x, y, z, depth, table\n", + "numerical_features = ['carat', 'x', 'y', 'z', 'depth', 'table']\n", + "scaler = StandardScaler()\n", + "data_pd[numerical_features] = scaler.fit_transform(data_pd[numerical_features])\n", + "\n", + "df = session.create_dataframe(data_pd)\n", + "df.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b931e01e-9483-44a2-b4bb-9b80719bae3a", + "metadata": {}, + "source": [ + "Let's create a Snowflake Dataset from the raw dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcc2d826-2381-40d3-9b27-b01fc6c5ec66", + "metadata": {}, + "outputs": [], + "source": [ + "ds_name = f\"{TEST_DATASET_DB}.{TEST_DATASET_SCHEMA}.wine_data\"\n", + "ds_version = \"v1\"\n", + "\n", + "session.sql(f\"DROP DATASET IF EXISTS {ds_name}\").collect()\n", + "ds = dataset.create_from_dataframe(\n", + " session,\n", + " name=ds_name,\n", + " version=ds_version,\n", + " input_dataframe=df,\n", + " label_cols=[\"price\"],\n", + ")\n", + "\n", + "print(f\"Dataset: {ds.fully_qualified_name}\")\n", + "print(f\"Selected version: {ds.selected_version.name} ({ds.selected_version})\")\n", + "print(f\"Available versions: {ds.list_versions()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "dcd07ee9-765b-440a-9ec4-2e9a4e6adde2", + "metadata": {}, + "source": [ + "The Dataset object includes various connectors under the `read` property which we can use to inspect or consume the Dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca80c3e6-c888-4071-b3da-f4c5393e1889", + "metadata": {}, + "outputs": [], + "source": [ + "print([f for f in dir(ds.read) if not f.startswith('_') and callable(getattr(ds.read, f))])\n", + "\n", + "print(ds.read.files())\n", + "print(ds.read.to_pandas().shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "919f9e7b-6428-41cc-88d3-70fd08b0d1f6", + "metadata": {}, + "outputs": [], + "source": [ + "type(ds.read)" + ] + }, + { + "cell_type": "markdown", + "id": "2a11261b-c7ea-4bb9-b0b0-7a711dd63cec", + "metadata": {}, + "source": [ + "We could use this dataset as-is and do any train/test split at runtime if needed. However, we might want to guarantee consistent splitting by saving the pre-split dataset as versions of our Snowflake Dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f25fe63-c2a4-4832-b9ac-6f4c301624e8", + "metadata": {}, + "outputs": [], + "source": [ + "test_ratio = 0.2\n", + "uniform_min, uniform_max = 1, 10\n", + "pivot = (uniform_max - uniform_min + 1) * test_ratio\n", + "df_aug = df.with_column(\"_UNIFORM\", F.uniform(uniform_min, uniform_max, F.random()))\n", + "ds.create_version(\n", + " version=\"train\",\n", + " input_dataframe=df_aug.where(df_aug.col(\"_UNIFORM\") > pivot).drop(df_aug.col(\"_UNIFORM\")),\n", + " label_cols=[\"price\"],\n", + ")\n", + "ds.create_version(\n", + " version=\"test\",\n", + " input_dataframe=df_aug.where(df_aug.col(\"_UNIFORM\") <= pivot).drop(df_aug.col(\"_UNIFORM\")),\n", + " label_cols=[\"price\"],\n", + ")\n", + "\n", + "print(ds.list_versions())\n", + "\n", + "train_ds = ds.select_version(\"train\")\n", + "test_ds = ds.select_version(\"test\")\n", + "\n", + "print(\"train rows:\", train_ds.read.to_snowpark_dataframe().count())\n", + "print(\"test rows:\", test_ds.read.to_snowpark_dataframe().count())" + ] + }, + { + "cell_type": "markdown", + "id": "6ca2ea4d-d8c1-4ef9-a7c5-78c8b3aa8779", + "metadata": {}, + "source": [ + "# Model Training\n", + "\n", + "Let's train and evaluate a basic PyTorch model using our newly created Snowflake Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e294e5ef-3d5c-467f-a044-f7710fb2b566", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestRegressor\n", + "from sklearn.metrics import mean_squared_error\n", + "\n", + "train_pd = train_ds.read.to_pandas()\n", + "X_train = train_pd.drop(columns=[\"price\"])\n", + "y_train = train_pd[\"price\"]\n", + "rf_regressor = RandomForestRegressor(n_estimators=100, random_state=42)\n", + "rf_regressor.fit(X_train, y_train)\n", + "\n", + "# Evaluate the Model\n", + "test_pd = test_ds.read.to_pandas()\n", + "X_test = test_pd.drop(columns=[\"price\"])\n", + "y_test = test_pd[\"price\"]\n", + "y_pred = rf_regressor.predict(X_test)\n", + "\n", + "# Calculate the Mean Squared Error\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "print(\"Mean Squared Error:\", mse)" + ] + }, + { + "cell_type": "markdown", + "id": "0847e2be-fbb8-49aa-8124-433c99a42149", + "metadata": {}, + "source": [ + "We can run this same model in a stored procedure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fee64ad7-af71-495f-8746-f2fbb3b41d12", + "metadata": {}, + "outputs": [], + "source": [ + "local_code_imports = [\n", + " (os.path.join(snowml_path, 'snowflake', 'ml', '_internal'), 'snowflake.ml._internal'),\n", + " (os.path.join(snowml_path, 'snowflake', 'ml', 'fileset'), 'snowflake.ml.fileset'),\n", + " (os.path.join(snowml_path, 'snowflake', 'ml', 'dataset'), 'snowflake.ml.dataset'),\n", + "]\n", + "for t in local_code_imports:\n", + " session.add_import(*t, whole_file_hash=True)\n", + " \n", + "deps = [\n", + " \"snowflake-snowpark-python\",\n", + " \"snowflake-ml-python\",\n", + " \"cryptography\",\n", + "]\n", + "\n", + "@F.sproc(session=session, packages=deps)\n", + "def ds_sproc(session: Session) -> float:\n", + " train_ds = dataset.load_dataset(session, ds_name, \"train\")\n", + " test_ds = dataset.load_dataset(session, ds_name, \"test\")\n", + "\n", + " train_pd = train_ds.read.to_pandas()\n", + " X_train = train_pd.drop(columns=[\"price\"])\n", + " y_train = train_pd[\"price\"]\n", + " rf_regressor = RandomForestRegressor(n_estimators=100, random_state=42)\n", + " 
rf_regressor.fit(X_train, y_train)\n", + "\n", + " # Evaluate the Model\n", + " test_pd = test_ds.read.to_pandas()\n", + " X_test = test_pd.drop(columns=[\"price\"])\n", + " y_test = test_pd[\"price\"]\n", + " y_pred = rf_regressor.predict(X_test)\n", + "\n", + " # Calculate the Mean Squared Error\n", + " return mean_squared_error(y_test, y_pred)\n", + "\n", + "print(\"Mean Squared Error:\", ds_sproc(session))\n", + "session.clear_imports()" + ] + }, + { + "cell_type": "markdown", + "id": "39bb0905-d544-4032-a8da-12673c202f6d", + "metadata": {}, + "source": [ + "We can also use Dataset's connector APIs to integrate with ML frameworks like PyTorch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42996aad-95c5-44b9-b320-218e5cdd66ee", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "from torch import nn, optim\n", + "\n", + "class DiamondPricePredictor(nn.Module):\n", + " def __init__(self):\n", + " super(DiamondPricePredictor, self).__init__()\n", + " self.fc1 = nn.Linear(9, 64)\n", + " self.fc2 = nn.Linear(64, 32)\n", + " self.fc3 = nn.Linear(32, 1)\n", + " self.relu = nn.ReLU()\n", + " \n", + " def forward(self, carat, cut, color, clarity, depth, table, x, y, z):\n", + " X = torch.cat((carat, cut, color, clarity, depth, table, x, y, z), axis=1)\n", + " X = self.relu(self.fc1(X))\n", + " X = self.relu(self.fc2(X))\n", + " X = self.fc3(X)\n", + " return X\n", + "\n", + "\n", + "def train_model(model: nn.Module, ds: dataset.Dataset, batch_size: int = 32, num_epochs: int = 10, learning_rate: float = 1e-3):\n", + " model.train()\n", + "\n", + " # Define loss function and optimizer\n", + " criterion = nn.MSELoss()\n", + " optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", + "\n", + " # Training loop\n", + " for epoch in range(num_epochs):\n", + " for batch in ds.read.to_torch_datapipe(batch_size=batch_size):\n", + " targets = torch.from_numpy(batch.pop(\"price\")).unsqueeze(1).to(torch.float32)\n", + " inputs = {k:torch.from_numpy(v).unsqueeze(1) for k,v in batch.items()}\n", + " \n", + " # Forward pass\n", + " outputs = model(**inputs)\n", + " loss = criterion(outputs, targets)\n", + " \n", + " # Backward and optimize\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " return model\n", + "\n", + "def eval_model(model: nn.Module, ds: dataset.Dataset, batch_size: int = 32) -> float:\n", + " model.eval()\n", + " mse = 0.0\n", + " with torch.no_grad():\n", + " for batch in ds.read.to_torch_datapipe(batch_size=batch_size):\n", + " targets = torch.from_numpy(batch.pop(\"price\")).unsqueeze(1).to(torch.float32)\n", + " inputs = {k:torch.from_numpy(v).unsqueeze(1) for k,v in batch.items()}\n", + "\n", + " outputs = model(**inputs)\n", + " mse += nn.functional.mse_loss(outputs, targets).item()\n", + " return mse\n", + "\n", + "model = DiamondPricePredictor()\n", + "train_model(model, train_ds)\n", + "eval_model(model, test_ds)" + ] + }, + { + "cell_type": "markdown", + "id": "d1009685-2c71-4db8-97a0-947f12d693d7", + "metadata": {}, + "source": [ + "(WIP) We can pass the Datasets into SnowML modeling APIs using either Snowpark DataFrame or Pandas DataFrame" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ff1a3bb-2747-43e4-80b1-4d847ebff347", + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.modeling.xgboost import XGBRegressor\n", + "\n", + "FEATURE_COLS = [\"carat\", \"cut\", \"color\", \"clarity\", \"depth\", \"table\", \"x\", 
\"y\", \"z\"]\n", + "LABEL_COLS = [\"price\"]\n", + "\n", + "# Train an XGBoost model on snowflake.\n", + "xgboost_model = XGBRegressor(\n", + " input_cols=FEATURE_COLS,\n", + " label_cols=LABEL_COLS,\n", + ")\n", + "\n", + "xgboost_model.fit(train_ds.read.to_snowpark_dataframe())\n", + "\n", + "# Use the model to make predictions.\n", + "predictions = xgboost_model.predict(test_ds.read.to_snowpark_dataframe())" + ] + }, + { + "cell_type": "markdown", + "id": "bcfad573-9e8b-429e-af38-c4d3316cbb5a", + "metadata": {}, + "source": [ + "# Future Work\n", + "\n", + "There are several features which are still on the horizon for the Dataset client API, such as:\n", + "1. Adding multi-version Dataset support\n", + "2. Adding exclude_cols handling to all connectors (`to_pandas()`, `to_torch_datapipe()`, etc)\n", + "3. Consolidating FileSet functionality (reading from internal stage) into dataset.DataReader" + ] + }, + { + "cell_type": "markdown", + "id": "74e6727b-db3c-4e3e-9201-001ad5e0e98e", + "metadata": {}, + "source": [ + "# Clean Up Resources" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "040fffb3-544d-4110-b941-b6230ec14604", + "metadata": {}, + "outputs": [], + "source": [ + "session.sql(f\"DROP SCHEMA IF EXISTS {TEST_DATASET_SCHEMA}\").collect()\n", + "session.sql(f\"DROP DATABASE IF EXISTS {TEST_DATASET_DB}\").collect()\n", + "session.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/snowflake/ml/feature_store/BUILD.bazel b/snowflake/ml/feature_store/BUILD.bazel index 0366ead5..b926f752 100644 --- a/snowflake/ml/feature_store/BUILD.bazel +++ b/snowflake/ml/feature_store/BUILD.bazel @@ -29,6 +29,7 @@ py_library( deps = [ ":init", "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/lineage:dataset_dataframe", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:query_result_checker", "//snowflake/ml/_internal/utils:sql_identifier", diff --git a/snowflake/ml/feature_store/feature_store.py b/snowflake/ml/feature_store/feature_store.py index a1b91b02..3b6fcc6a 100644 --- a/snowflake/ml/feature_store/feature_store.py +++ b/snowflake/ml/feature_store/feature_store.py @@ -8,13 +8,17 @@ import warnings from dataclasses import dataclass from enum import Enum -from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast +import packaging.version as pkg_version +import snowflake.ml.version as snowml_version from pytimeparse.timeparse import timeparse from typing_extensions import Concatenate, ParamSpec +from snowflake.ml import dataset from snowflake.ml._internal import telemetry from snowflake.ml._internal.exceptions import ( + dataset_errors, error_codes, exceptions as snowml_exceptions, ) @@ -23,12 +27,8 @@ SqlIdentifier, to_sql_identifiers, ) -from snowflake.ml.dataset.dataset import Dataset, FeatureStoreMetadata -from snowflake.ml.feature_store.entity import ( - _ENTITY_NAME_LENGTH_LIMIT, - _FEATURE_VIEW_ENTITY_TAG_DELIMITER, - Entity, -) +from snowflake.ml.dataset.dataset_metadata import FeatureStoreMetadata +from 
snowflake.ml.feature_store.entity import _ENTITY_NAME_LENGTH_LIMIT, Entity from snowflake.ml.feature_store.feature_view import ( _FEATURE_OBJ_TYPE, _FEATURE_VIEW_NAME_DELIMITER, @@ -37,11 +37,17 @@ FeatureViewSlice, FeatureViewStatus, FeatureViewVersion, + _FeatureViewMetadata, ) from snowflake.snowpark import DataFrame, Row, Session, functions as F -from snowflake.snowpark._internal import type_utils, utils as snowpark_utils from snowflake.snowpark.exceptions import SnowparkSQLException -from snowflake.snowpark.types import StructField +from snowflake.snowpark.types import ( + ArrayType, + StringType, + StructField, + StructType, + TimestampType, +) _Args = ParamSpec("_Args") _RT = TypeVar("_RT") @@ -49,38 +55,80 @@ logger = logging.getLogger(__name__) _ENTITY_TAG_PREFIX = "SNOWML_FEATURE_STORE_ENTITY_" -_FEATURE_VIEW_ENTITY_TAG = "SNOWML_FEATURE_STORE_FV_ENTITIES" -_FEATURE_VIEW_TS_COL_TAG = "SNOWML_FEATURE_STORE_FV_TS_COL" _FEATURE_STORE_OBJECT_TAG = "SNOWML_FEATURE_STORE_OBJECT" +_FEATURE_VIEW_METADATA_TAG = "SNOWML_FEATURE_VIEW_METADATA" + + +@dataclass(frozen=True) +class _FeatureStoreObjInfo: + type: _FeatureStoreObjTypes + pkg_version: str + + def to_json(self) -> str: + state_dict = self.__dict__.copy() + state_dict["type"] = state_dict["type"].value + return json.dumps(state_dict) + + @classmethod + def from_json(cls, json_str: str) -> _FeatureStoreObjInfo: + json_dict = json.loads(json_str) + # since we may introduce new fields in the json blob in the future, + # in order to guarantee compatibility, we need to select ones that can be + # decoded in the current version + state_dict = {} + state_dict["type"] = _FeatureStoreObjTypes.parse(json_dict["type"]) + state_dict["pkg_version"] = json_dict["pkg_version"] + return cls(**state_dict) # type: ignore[arg-type] # TODO: remove "" after dataset is updated class _FeatureStoreObjTypes(Enum): - FEATURE_VIEW = "FEATURE_VIEW" + UNKNOWN = "UNKNOWN" # for forward compatibility + MANAGED_FEATURE_VIEW = "MANAGED_FEATURE_VIEW" + EXTERNAL_FEATURE_VIEW = "EXTERNAL_FEATURE_VIEW" FEATURE_VIEW_REFRESH_TASK = "FEATURE_VIEW_REFRESH_TASK" TRAINING_DATA = "" + @classmethod + def parse(cls, val: str) -> _FeatureStoreObjTypes: + try: + return cls(val) + except ValueError: + return cls.UNKNOWN + _PROJECT = "FeatureStore" _DT_OR_VIEW_QUERY_PATTERN = re.compile( r"""CREATE\ (OR\ REPLACE\ )?(?P(DYNAMIC\ TABLE|VIEW))\ .* COMMENT\ =\ '(?P.*)'\s* - TAG.*?{entity_tag}\ =\ '(?P.*?)',\n - .*?{ts_col_tag}\ =\ '(?P.*?)',?.*? + TAG.*?{fv_metadata_tag}\ =\ '(?P.*?)',?.*? 
AS\ (?P.*) """.format( - entity_tag=_FEATURE_VIEW_ENTITY_TAG, ts_col_tag=_FEATURE_VIEW_TS_COL_TAG + fv_metadata_tag=_FEATURE_VIEW_METADATA_TAG, ), flags=re.DOTALL | re.IGNORECASE | re.X, ) +_LIST_FEATURE_VIEW_SCHEMA = StructType( + [ + StructField("name", StringType()), + StructField("version", StringType()), + StructField("database_name", StringType()), + StructField("schema_name", StringType()), + StructField("created_on", TimestampType()), + StructField("owner", StringType()), + StructField("desc", StringType()), + StructField("entities", ArrayType(StringType())), + ] +) + class CreationMode(Enum): FAIL_IF_NOT_EXIST = 1 CREATE_IF_NOT_EXIST = 2 -@dataclass +@dataclass(frozen=True) class _FeatureStoreConfig: database: SqlIdentifier schema: SqlIdentifier @@ -111,14 +159,14 @@ def wrapper(self: FeatureStore, /, *args: _Args.args, **kargs: _Args.kwargs) -> return wrapper -def dispatch_decorator( - prpr_version: str, -) -> Callable[[Callable[Concatenate[FeatureStore, _Args], _RT]], Callable[Concatenate[FeatureStore, _Args], _RT],]: +def dispatch_decorator() -> Callable[ + [Callable[Concatenate[FeatureStore, _Args], _RT]], + Callable[Concatenate[FeatureStore, _Args], _RT], +]: def decorator( f: Callable[Concatenate[FeatureStore, _Args], _RT] ) -> Callable[Concatenate[FeatureStore, _Args], _RT]: @telemetry.send_api_usage_telemetry(project=_PROJECT) - @snowpark_utils.private_preview(version=prpr_version) @switch_warehouse @functools.wraps(f) def wrap(self: FeatureStore, /, *args: _Args.args, **kargs: _Args.kwargs) -> _RT: @@ -135,7 +183,6 @@ class FeatureStore: """ @telemetry.send_api_usage_telemetry(project=_PROJECT) - @snowpark_utils.private_preview(version="1.0.8") def __init__( self, session: Session, @@ -178,7 +225,7 @@ def __init__( # search space used in query "SHOW LIKE IN " # object domain used in query "TAG_REFERENCE(, )" self._obj_search_spaces = { - "TABLES": (self._config.full_schema_path, "TABLE"), + "DATASETS": (self._config.full_schema_path, "DATASET"), "DYNAMIC TABLES": (self._config.full_schema_path, "TABLE"), "VIEWS": (self._config.full_schema_path, "TABLE"), "SCHEMAS": (f"DATABASE {self._config.database}", "SCHEMA"), @@ -200,8 +247,7 @@ def __init__( ) for tag in to_sql_identifiers( [ - _FEATURE_VIEW_ENTITY_TAG, - _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_VIEW_METADATA_TAG, ] ): self._session.sql(f"CREATE TAG IF NOT EXISTS {self._get_fully_qualified_name(tag)}").collect( @@ -209,8 +255,7 @@ def __init__( ) self._session.sql( - f"""CREATE TAG IF NOT EXISTS {self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} - ALLOWED_VALUES {','.join([f"'{v.value}'" for v in _FeatureStoreObjTypes])}""" + f"CREATE TAG IF NOT EXISTS {self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}" ).collect(statement_params=self._telemetry_stmp) except Exception as e: self.clear() @@ -219,10 +264,12 @@ def __init__( original_exception=RuntimeError(f"Failed to create feature store {name}: {e}."), ) + # TODO: remove this after tag_ref_internal rollout + self._use_optimized_tag_ref = self._tag_ref_internal_enabled() + self._check_feature_store_object_versions() logger.info(f"Successfully connected to feature store: {self._config.full_schema_path}.") @telemetry.send_api_usage_telemetry(project=_PROJECT) - @snowpark_utils.private_preview(version="1.0.12") def update_default_warehouse(self, warehouse_name: str) -> None: """Update default warehouse for feature store. 
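The new `_FEATURE_STORE_OBJECT_TAG` value is a small JSON blob rather than a bare enum string. The sketch below uses the private classes purely for illustration and shows both the round trip and the forward-compatibility fallback in `parse()`:

```python
from snowflake.ml.feature_store.feature_store import (
    _FeatureStoreObjInfo,
    _FeatureStoreObjTypes,
)

info = _FeatureStoreObjInfo(_FeatureStoreObjTypes.MANAGED_FEATURE_VIEW, "1.5.0")
blob = info.to_json()
# e.g. '{"type": "MANAGED_FEATURE_VIEW", "pkg_version": "1.5.0"}'

restored = _FeatureStoreObjInfo.from_json(blob)
assert restored.type is _FeatureStoreObjTypes.MANAGED_FEATURE_VIEW

# Type strings written by a newer client degrade to UNKNOWN instead of
# raising, which is the point of parse().
assert _FeatureStoreObjTypes.parse("SOME_FUTURE_TYPE") is _FeatureStoreObjTypes.UNKNOWN
```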
@@ -242,7 +289,7 @@ def update_default_warehouse(self, warehouse_name: str) -> None: self._default_warehouse = warehouse - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def register_entity(self, entity: Entity) -> Entity: """ Register Entity in the FeatureStore. @@ -268,13 +315,13 @@ def register_entity(self, entity: Entity) -> Entity: return entity # allowed_values will add double-quotes around each value, thus use resolved str here. - join_keys = [f"'{key.resolved()}'" for key in entity.join_keys] + join_keys = [f"{key.resolved()}" for key in entity.join_keys] join_keys_str = ",".join(join_keys) full_tag_name = self._get_fully_qualified_name(tag_name) try: self._session.sql( f"""CREATE TAG IF NOT EXISTS {full_tag_name} - ALLOWED_VALUES {join_keys_str} + ALLOWED_VALUES '{join_keys_str}' COMMENT = '{entity.desc}' """ ).collect(statement_params=self._telemetry_stmp) @@ -289,7 +336,7 @@ def register_entity(self, entity: Entity) -> Entity: return self.get_entity(entity.name) # TODO: add support to update column desc once SNOW-894249 is fixed - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def register_feature_view( self, feature_view: FeatureView, @@ -342,7 +389,6 @@ def register_feature_view( ), ) - # TODO: ideally we should move this to FeatureView creation time for e in feature_view.entities: if not self._validate_entity_exists(e.name): raise snowml_exceptions.SnowflakeMLException( @@ -358,12 +404,23 @@ def register_feature_view( pass fully_qualified_name = self._get_fully_qualified_name(feature_view_name) - entities = _FEATURE_VIEW_ENTITY_TAG_DELIMITER.join([e.name for e in feature_view.entities]) - timestamp_col = ( - feature_view.timestamp_col - if feature_view.timestamp_col is not None - else SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER) - ) + refresh_freq = feature_view.refresh_freq + + if refresh_freq is not None: + obj_info = _FeatureStoreObjInfo(_FeatureStoreObjTypes.MANAGED_FEATURE_VIEW, snowml_version.VERSION) + else: + obj_info = _FeatureStoreObjInfo(_FeatureStoreObjTypes.EXTERNAL_FEATURE_VIEW, snowml_version.VERSION) + + tagging_clause = [ + f"{self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} = '{obj_info.to_json()}'", + f"{self._get_fully_qualified_name(_FEATURE_VIEW_METADATA_TAG)} = '{feature_view._metadata().to_json()}'", + ] + for e in feature_view.entities: + join_keys = [f"{key.resolved()}" for key in e.join_keys] + tagging_clause.append( + f"{self._get_fully_qualified_name(self._get_entity_name(e.name))} = '{','.join(join_keys)}'" + ) + tagging_clause_str = ",\n".join(tagging_clause) def create_col_desc(col: StructField) -> str: desc = feature_view.feature_descs.get(SqlIdentifier(col.name), None) @@ -371,7 +428,6 @@ def create_col_desc(col: StructField) -> str: return f"{col.name} {desc}" column_descs = ", ".join([f"{create_col_desc(col)}" for col in feature_view.output_schema.fields]) - refresh_freq = feature_view.refresh_freq if refresh_freq is not None: schedule_task = refresh_freq != "DOWNSTREAM" and timeparse(refresh_freq) is None @@ -380,10 +436,9 @@ def create_col_desc(col: StructField) -> str: feature_view, fully_qualified_name, column_descs, - entities, + tagging_clause_str, schedule_task, self._default_warehouse, - timestamp_col, block, overwrite, ) @@ -393,9 +448,7 @@ def create_col_desc(col: StructField) -> str: query = f"""CREATE{overwrite_clause} VIEW {fully_qualified_name} ({column_descs}) COMMENT = '{feature_view.desc}' TAG ( - {_FEATURE_VIEW_ENTITY_TAG} = '{entities}', - {_FEATURE_VIEW_TS_COL_TAG} = 
'{timestamp_col}', - {_FEATURE_STORE_OBJECT_TAG} = '{_FeatureStoreObjTypes.FEATURE_VIEW.value}' + {tagging_clause_str} ) AS {feature_view.query} """ @@ -406,10 +459,10 @@ def create_col_desc(col: StructField) -> str: original_exception=RuntimeError(f"Create view {fully_qualified_name} [\n{query}\n] failed: {e}"), ) from e - logger.info(f"Registered FeatureView {feature_view.name}/{version}.") + logger.info(f"Registered FeatureView {feature_view.name}/{version} successfully.") return self.get_feature_view(feature_view.name, str(version)) - @dispatch_decorator(prpr_version="1.1.0") + @dispatch_decorator() def update_feature_view( self, name: str, version: str, refresh_freq: Optional[str] = None, warehouse: Optional[str] = None ) -> FeatureView: @@ -456,7 +509,7 @@ def update_feature_view( ) from e return self.get_feature_view(name=name, version=version) - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def read_feature_view(self, feature_view: FeatureView) -> DataFrame: """ Read FeatureView data. @@ -478,13 +531,12 @@ def read_feature_view(self, feature_view: FeatureView) -> DataFrame: return self._session.sql(f"SELECT * FROM {feature_view.fully_qualified_name()}") - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def list_feature_views( self, entity_name: Optional[str] = None, feature_view_name: Optional[str] = None, - as_dataframe: bool = True, - ) -> Union[Optional[DataFrame], List[FeatureView]]: + ) -> DataFrame: """ List FeatureViews in the FeatureStore. If entity_name is specified, FeatureViews associated with that Entity will be listed. @@ -493,34 +545,26 @@ def list_feature_views( Args: entity_name: Entity name. feature_view_name: FeatureView name. - as_dataframe: whether the return type should be a DataFrame. Returns: - List of FeatureViews or in a DataFrame representation. + FeatureViews information as a Snowpark DataFrame. """ - if entity_name is not None: - entity_name = SqlIdentifier(entity_name) if feature_view_name is not None: feature_view_name = SqlIdentifier(feature_view_name) if entity_name is not None: - fvs = self._find_feature_views(entity_name, feature_view_name) + entity_name = SqlIdentifier(entity_name) + if self._use_optimized_tag_ref: + return self._optimized_find_feature_views(entity_name, feature_view_name) + else: + return self._find_feature_views(entity_name, feature_view_name) else: - fvs = [] - entities = self.list_entities().collect() + output_values: List[List[Any]] = [] for row in self._get_fv_backend_representations(feature_view_name, prefix_match=True): - fvs.append(self._compose_feature_view(row, entities)) - - if as_dataframe: - result = None - for fv in fvs: - fv_df = fv.to_df(self._session) - result = fv_df if result is None else result.union(fv_df) # type: ignore[attr-defined] - return result - else: - return fvs + self._extract_feature_view_info(row, output_values) + return self._session.create_dataframe(output_values, schema=_LIST_FEATURE_VIEW_SCHEMA) - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def get_feature_view(self, name: str, version: str) -> FeatureView: """ Retrieve previously registered FeatureView. @@ -549,7 +593,7 @@ def get_feature_view(self, name: str, version: str) -> FeatureView: return self._compose_feature_view(results[0], self.list_entities().collect()) - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def resume_feature_view(self, feature_view: FeatureView) -> FeatureView: """ Resume a previously suspended FeatureView. 
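Since `as_dataframe` is removed, callers that previously requested `List[FeatureView]` now work off the returned Snowpark DataFrame. A sketch of the adjusted call pattern, assuming `fs` is a `FeatureStore` instance; entity and feature view names are illustrative:

```python
# Old (removed): fs.list_feature_views(entity_name="CUSTOMER", as_dataframe=False)

fv_df = fs.list_feature_views(entity_name="CUSTOMER")  # Snowpark DataFrame
fv_df.show()

# Materialize the rows and fetch full FeatureView objects only where needed.
for row in fv_df.collect():
    fv = fs.get_feature_view(row["NAME"], row["VERSION"])
    print(fv.fully_qualified_name())
```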
@@ -562,7 +606,7 @@ def resume_feature_view(self, feature_view: FeatureView) -> FeatureView: """ return self._update_feature_view_status(feature_view, "RESUME") - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def suspend_feature_view(self, feature_view: FeatureView) -> FeatureView: """ Suspend an active FeatureView. @@ -575,7 +619,7 @@ def suspend_feature_view(self, feature_view: FeatureView) -> FeatureView: """ return self._update_feature_view_status(feature_view, "SUSPEND") - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def delete_feature_view(self, feature_view: FeatureView) -> None: """ Delete a FeatureView. @@ -586,6 +630,8 @@ def delete_feature_view(self, feature_view: FeatureView) -> None: Raises: SnowflakeMLException: [ValueError] FeatureView is not registered. """ + # TODO: we should leverage lineage graph to check downstream deps, and block the deletion + # if there're other FVs depending on this if feature_view.status == FeatureViewStatus.DRAFT or feature_view.version is None: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.NOT_FOUND, @@ -608,7 +654,7 @@ def delete_feature_view(self, feature_view: FeatureView) -> None: logger.info(f"Deleted FeatureView {feature_view.name}/{feature_view.version}.") - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def list_entities(self) -> DataFrame: """ List all Entities in the FeatureStore. @@ -629,7 +675,7 @@ def list_entities(self) -> DataFrame: ), ) - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def get_entity(self, name: str) -> Entity: """ Retrieve previously registered Entity object. @@ -659,8 +705,7 @@ def get_entity(self, name: str) -> Entity: original_exception=ValueError(f"Cannot find Entity with name: {name}."), ) - raw_join_keys = result[0]["JOIN_KEYS"] - join_keys = raw_join_keys.strip("[]").split(",") + join_keys = self._recompose_join_keys(result[0]["JOIN_KEYS"]) return Entity._construct_entity( name=SqlIdentifier(result[0]["NAME"], case_sensitive=True).identifier(), @@ -669,7 +714,7 @@ def get_entity(self, name: str) -> Entity: owner=result[0]["OWNER"], ) - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def delete_entity(self, name: str) -> None: """ Delete a previously registered Entity. @@ -690,13 +735,13 @@ def delete_entity(self, name: str) -> None: original_exception=ValueError(f"Entity {name} does not exist."), ) - active_feature_views = cast(List[FeatureView], self.list_feature_views(entity_name=name, as_dataframe=False)) + active_feature_views = self.list_feature_views(entity_name=name).collect(statement_params=self._telemetry_stmp) + if len(active_feature_views) > 0: + active_fvs = [r["NAME"] for r in active_feature_views] raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.SNOWML_DELETE_FAILED, - original_exception=ValueError( - f"Cannot delete Entity {name} due to active FeatureViews: {[f.name for f in active_feature_views]}." 
- ), + original_exception=ValueError(f"Cannot delete Entity {name} due to active FeatureViews: {active_fvs}."), ) tag_name = self._get_fully_qualified_name(self._get_entity_name(name)) @@ -709,7 +754,7 @@ def delete_entity(self, name: str) -> None: ) from e logger.info(f"Deleted Entity {name}.") - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def retrieve_feature_values( self, spine_df: DataFrame, @@ -757,39 +802,35 @@ def retrieve_feature_values( return df - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def generate_dataset( self, + name: str, spine_df: DataFrame, features: List[Union[FeatureView, FeatureViewSlice]], - materialized_table: Optional[str] = None, + version: Optional[str] = None, spine_timestamp_col: Optional[str] = None, spine_label_cols: Optional[List[str]] = None, exclude_columns: Optional[List[str]] = None, - save_mode: str = "errorifexists", include_feature_view_timestamp_col: bool = False, desc: str = "", - ) -> Dataset: + ) -> dataset.Dataset: """ Generate dataset by given source table and feature views. Args: + name: The name of the Dataset to be generated. Datasets are uniquely identified within a schema + by their name and version. spine_df: The fact table contains the raw dataset. features: A list of FeatureView or FeatureViewSlice which contains features to be joined. - materialized_table: The destination table where produced result will be stored. If it's none, then result - won't be registered. If materialized_table is provided, then produced result will be written into - the provided table. Note result dataset will be a snowflake clone of registered table. - New data can append on same registered table and previously generated dataset won't be affected. - Default result table name will be a concatenation of materialized_table name and current timestamp. + version: The version of the Dataset to be generated. If none specified, the current timestamp + will be used instead. spine_timestamp_col: Name of timestamp column in spine_df that will be used to join time-series features. If spine_timestamp_col is not none, the input features also must have timestamp_col. spine_label_cols: Name of column(s) in spine_df that contains labels. exclude_columns: Column names to exclude from the result dataframe. The underlying storage will still contain the columns. - save_mode: How new data is saved. currently support: - errorifexists: Raise error if registered table already exists. - merge: Merge new data if registered table already exists. include_feature_view_timestamp_col: Generated dataset will include timestamp column of feature view (if feature view has timestamp column) if set true. Default to false. desc: A description about this dataset. @@ -798,10 +839,8 @@ def generate_dataset( A Dataset object. Raises: - SnowflakeMLException: [ValueError] save_mode is invalid. SnowflakeMLException: [ValueError] spine_df contains more than one query. - SnowflakeMLException: [ValueError] Materialized_table contains invalid char `.`. - SnowflakeMLException: [ValueError] Materialized_table already exists with save_mode `errorifexists`. + SnowflakeMLException: [ValueError] Dataset name/version already exists SnowflakeMLException: [ValueError] Snapshot creation failed. SnowflakeMLException: [RuntimeError] Failed to create clone from table. SnowflakeMLException: [RuntimeError] Failed to find resources. 
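Putting the new signature together, generating training data now produces a named, versioned Dataset instead of a materialized table. A minimal sketch, assuming `fs` is a configured `FeatureStore`, `spine_df` is a single-query Snowpark DataFrame, and `fv` is a registered FeatureView; all names are illustrative:

```python
ds = fs.generate_dataset(
    name="MY_TRAINING_DATA",          # Dataset name, optionally schema-qualified
    spine_df=spine_df,
    features=[fv],
    version="v1",                     # defaults to a timestamp when omitted
    spine_timestamp_col="TS",
    spine_label_cols=["LABEL"],
    desc="illustrative training set",
)

# Consume it like any other Dataset.
train_df = ds.read.to_snowpark_dataframe()
```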
@@ -811,15 +850,6 @@ def generate_dataset( if spine_label_cols is not None: spine_label_cols = to_sql_identifiers(spine_label_cols) # type: ignore[assignment] - allowed_save_mode = {"errorifexists", "merge"} - if save_mode.lower() not in allowed_save_mode: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - f"'{save_mode}' is not supported. Current supported save modes: {','.join(allowed_save_mode)}" - ), - ) - if len(spine_df.queries["queries"]) != 1: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_ARGUMENT, @@ -832,70 +862,55 @@ def generate_dataset( spine_df, features, spine_timestamp_col, include_feature_view_timestamp_col ) - snapshot_table = None - if materialized_table is not None: - if "." in materialized_table: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError(f"materialized_table {materialized_table} contains invalid char `.`"), - ) - - # TODO (wezhou) change materialized_table to SqlIdentifier - found_rows = self._find_object("TABLES", SqlIdentifier(materialized_table)) - if save_mode.lower() == "errorifexists" and len(found_rows) > 0: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.OBJECT_ALREADY_EXISTS, - original_exception=ValueError(f"Dataset table {materialized_table} already exists."), - ) - - self._dump_dataset(result_df, materialized_table, join_keys, spine_timestamp_col) - - snapshot_table = f"{materialized_table}_{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}" - snapshot_table = self._get_fully_qualified_name(snapshot_table) - materialized_table = self._get_fully_qualified_name(materialized_table) - - try: - self._session.sql(f"CREATE TABLE {snapshot_table} CLONE {materialized_table}").collect( - statement_params=self._telemetry_stmp - ) - except Exception as e: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWPARK_ERROR, - original_exception=RuntimeError( - f"Failed to create clone {materialized_table} from table {snapshot_table}: {e}." 
- ), - ) from e - - result_df = self._session.sql(f"SELECT * FROM {snapshot_table}") + # Convert name to fully qualified name if not already fully qualified + db_name, schema_name, object_name, _ = identifier.parse_schema_level_object_identifier(name) + name = "{}.{}.{}".format( + db_name or self._config.database, + schema_name or self._config.schema, + object_name, + ) + version = version or datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") if exclude_columns is not None: result_df = self._exclude_columns(result_df, exclude_columns) fs_meta = FeatureStoreMetadata( spine_query=spine_df.queries["queries"][0], - connection_params=vars(self._config), - features=[fv.to_json() for fv in features], + serialized_feature_views=[fv.to_json() for fv in features], + spine_timestamp_col=spine_timestamp_col, ) - dataset = Dataset( - self._session, - df=result_df, - materialized_table=materialized_table, - snapshot_table=snapshot_table, - timestamp_col=spine_timestamp_col, - label_cols=spine_label_cols, - feature_store_metadata=fs_meta, - desc=desc, - ) - return dataset + try: + ds: dataset.Dataset = dataset.create_from_dataframe( + self._session, + name, + version, + input_dataframe=result_df, + exclude_cols=[spine_timestamp_col], + label_cols=spine_label_cols, + properties=fs_meta, + comment=desc, + ) + return ds - @dispatch_decorator(prpr_version="1.0.8") - def load_feature_views_from_dataset(self, dataset: Dataset) -> List[Union[FeatureView, FeatureViewSlice]]: + except dataset_errors.DatasetExistError as e: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.OBJECT_ALREADY_EXISTS, + original_exception=ValueError(str(e)), + ) from e + except SnowparkSQLException as e: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INTERNAL_SNOWPARK_ERROR, + original_exception=RuntimeError(f"An error occurred during Dataset generation: {e}."), + ) from e + + @dispatch_decorator() + def load_feature_views_from_dataset(self, ds: dataset.Dataset) -> List[Union[FeatureView, FeatureViewSlice]]: """ Retrieve FeatureViews used during Dataset construction. Args: - dataset: Dataset object created from feature store. + ds: Dataset object created from feature store. Returns: List of FeatureViews used during Dataset construction. @@ -903,13 +918,18 @@ def load_feature_views_from_dataset(self, dataset: Dataset) -> List[Union[Featur Raises: ValueError: if dataset object is not generated from feature store. """ - serialized_objs = dataset.load_features() - if serialized_objs is None: - raise ValueError(f"Dataset {dataset} does not contain valid feature view information.") + assert ds.selected_version is not None + source_meta = ds.selected_version._get_metadata() + if ( + source_meta is None + or not isinstance(source_meta.properties, FeatureStoreMetadata) + or source_meta.properties.serialized_feature_views is None + ): + raise ValueError(f"Dataset {ds} does not contain valid feature view information.") - return self._load_serialized_feature_objects(serialized_objs) + return self._load_serialized_feature_objects(source_meta.properties.serialized_feature_views) - @dispatch_decorator(prpr_version="1.0.8") + @dispatch_decorator() def clear(self) -> None: """ Clear all feature store internal objects including feature views, entities etc. 
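With the metadata now stored on the Dataset version itself, recovering the originating FeatureViews is a single call. A sketch assuming `ds` was produced by `generate_dataset` as above; the `_get_metadata()` call is private and shown only to illustrate where the information lives:

```python
# The FeatureViews used to build the Dataset come back as objects...
feature_objs = fs.load_feature_views_from_dataset(ds)
print([type(o).__name__ for o in feature_objs])   # e.g. ['FeatureView']

# ...because generate_dataset attached a FeatureStoreMetadata blob to the version.
meta = ds.selected_version._get_metadata()          # private; for illustration only
print(meta.properties.spine_query)
```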
Note feature store @@ -929,7 +949,11 @@ def clear(self) -> None: if len(result) == 0: return - object_types = ["DYNAMIC TABLES", "TABLES", "VIEWS", "TASKS"] + fs_obj_tag = self._find_object("TAGS", SqlIdentifier(_FEATURE_STORE_OBJECT_TAG)) + if len(fs_obj_tag) == 0: + return + + object_types = ["DYNAMIC TABLES", "DATASETS", "VIEWS", "TASKS"] for obj_type in object_types: all_object_rows = self._find_object(obj_type, None) for row in all_object_rows: @@ -939,9 +963,8 @@ def clear(self) -> None: entity_tags = self._find_object("TAGS", SqlIdentifier(_ENTITY_TAG_PREFIX), prefix_match=True) all_tags = [ - _FEATURE_VIEW_ENTITY_TAG, - _FEATURE_VIEW_TS_COL_TAG, _FEATURE_STORE_OBJECT_TAG, + _FEATURE_VIEW_METADATA_TAG, ] + [SqlIdentifier(row["name"], case_sensitive=True) for row in entity_tags] for tag_name in all_tags: obj_name = self._get_fully_qualified_name(tag_name) @@ -965,37 +988,47 @@ def _get_feature_view_if_exists(self, name: str, version: str) -> FeatureView: ) return existing_fv + def _recompose_join_keys(self, join_key: str) -> List[str]: + # ALLOWED_VALUES in TAG will follow format ["key_1,key2,..."] + # since keys are already resolved following the SQL identifier rule on the write path, + # we simply parse the keys back and wrap them with quotes to preserve cases + # Example join_key repr from TAG value: "[key1,key2,key3]" + join_keys = join_key[2:-2].split(",") + res = [] + for k in join_keys: + res.append(f'"{k}"') + return res + def _create_dynamic_table( self, feature_view_name: SqlIdentifier, feature_view: FeatureView, fully_qualified_name: str, column_descs: str, - entities: str, + tagging_clause: str, schedule_task: bool, warehouse: SqlIdentifier, - timestamp_col: SqlIdentifier, block: bool, override: bool, ) -> None: # TODO: cluster by join keys once DT supports that - override_clause = " OR REPLACE" if override else "" - query = f"""CREATE{override_clause} DYNAMIC TABLE {fully_qualified_name} ({column_descs}) - TARGET_LAG = '{'DOWNSTREAM' if schedule_task else feature_view.refresh_freq}' - COMMENT = '{feature_view.desc}' - TAG ( - {self._get_fully_qualified_name(_FEATURE_VIEW_ENTITY_TAG)} = '{entities}', - {self._get_fully_qualified_name(_FEATURE_VIEW_TS_COL_TAG)} = '{timestamp_col}', - {self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} = - '{_FeatureStoreObjTypes.FEATURE_VIEW.value}' - ) - WAREHOUSE = {warehouse} - AS {feature_view.query} - """ try: + override_clause = " OR REPLACE" if override else "" + query = f"""CREATE{override_clause} DYNAMIC TABLE {fully_qualified_name} ({column_descs}) + TARGET_LAG = '{'DOWNSTREAM' if schedule_task else feature_view.refresh_freq}' + COMMENT = '{feature_view.desc}' + TAG ( + {tagging_clause} + ) + WAREHOUSE = {warehouse} + AS {feature_view.query} + """ self._session.sql(query).collect(block=block, statement_params=self._telemetry_stmp) if schedule_task: + task_obj_info = _FeatureStoreObjInfo( + _FeatureStoreObjTypes.FEATURE_VIEW_REFRESH_TASK, snowml_version.VERSION + ) try: self._session.sql( f"""CREATE{override_clause} TASK {fully_qualified_name} @@ -1007,8 +1040,7 @@ def _create_dynamic_table( self._session.sql( f""" ALTER TASK {fully_qualified_name} - SET TAG {self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} - ='{_FeatureStoreObjTypes.FEATURE_VIEW_REFRESH_TASK.value}' + SET TAG {self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}='{task_obj_info.to_json()}' """ ).collect(statement_params=self._telemetry_stmp) self._session.sql(f"ALTER TASK {fully_qualified_name} RESUME").collect( @@ -1049,57 +1081,6 @@ def 
_check_dynamic_table_refresh_mode(self, feature_view_name: SqlIdentifier) -> category=UserWarning, ) - def _dump_dataset( - self, - df: DataFrame, - table_name: str, - join_keys: List[SqlIdentifier], - spine_timestamp_col: Optional[SqlIdentifier] = None, - ) -> None: - if len(df.queries["queries"]) != 1: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError(f"Dataset df must contain only one query. Got: {df.queries['queries']}"), - ) - schema = ", ".join([f"{c.name} {type_utils.convert_sp_to_sf_type(c.datatype)}" for c in df.schema.fields]) - fully_qualified_name = self._get_fully_qualified_name(table_name) - - try: - self._session.sql( - f"""CREATE TABLE IF NOT EXISTS {fully_qualified_name} ({schema}) - CLUSTER BY ({', '.join(join_keys)}) - TAG ({self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)} = '') - """ - ).collect(block=True, statement_params=self._telemetry_stmp) - except Exception as e: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWPARK_ERROR, - original_exception=RuntimeError(f"Failed to create table {fully_qualified_name}: {e}."), - ) from e - - source_query = df.queries["queries"][0] - - if spine_timestamp_col is not None: - join_keys.append(spine_timestamp_col) - - _, _, dest_alias, _ = identifier.parse_schema_level_object_identifier(fully_qualified_name) - source_alias = f"{dest_alias}_source" - join_cond = " AND ".join([f"{dest_alias}.{k} = {source_alias}.{k}" for k in join_keys]) - update_clause = ", ".join([f"{dest_alias}.{c} = {source_alias}.{c}" for c in df.columns]) - insert_clause = ", ".join([f"{source_alias}.{c}" for c in df.columns]) - query = f""" - MERGE INTO {fully_qualified_name} USING ({source_query}) {source_alias} ON {join_cond} - WHEN MATCHED THEN UPDATE SET {update_clause} - WHEN NOT MATCHED THEN INSERT ({', '.join(df.columns)}) VALUES ({insert_clause}) - """ - try: - self._session.sql(query).collect(block=True, statement_params=self._telemetry_stmp) - except Exception as e: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INTERNAL_SNOWPARK_ERROR, - original_exception=RuntimeError(f"Failed to create dataset {fully_qualified_name} with merge: {e}."), - ) from e - def _validate_entity_exists(self, name: SqlIdentifier) -> bool: full_entity_tag_name = self._get_entity_name(name) found_rows = self._find_object("TAGS", full_entity_tag_name) @@ -1150,7 +1131,7 @@ def _join_features( else: cols = f.feature_names - join_keys = [k for e in f.entities for k in e.join_keys] + join_keys = list({k for e in f.entities for k in e.join_keys}) join_keys_str = ", ".join(join_keys) assert f.version is not None join_table_name = f.fully_qualified_name() @@ -1227,8 +1208,7 @@ def _check_internal_objects_exist_or_throw(self) -> None: for tag_name in to_sql_identifiers( [ _FEATURE_STORE_OBJECT_TAG, - _FEATURE_VIEW_ENTITY_TAG, - _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_VIEW_METADATA_TAG, ] ): tag_result = self._find_object("TAGS", tag_name) @@ -1340,7 +1320,8 @@ def join_cols(cols: List[SqlIdentifier], end_comma: bool, rename: bool, prefix: # Part 4: join original spine table with window table prefix_f_only_cols = to_sql_identifiers( - [f"{temp_prefix}{name.resolved()}" for name in f_only_cols], case_sensitive=True + [f"{temp_prefix}{name.resolved()}" for name in f_only_cols], + case_sensitive=True, ) last_select = f""" SELECT @@ -1373,7 +1354,10 @@ def _get_fv_backend_representations( return dynamic_table_results + view_results def 
_update_feature_view_status(self, feature_view: FeatureView, operation: str) -> FeatureView: - assert operation in ["RESUME", "SUSPEND"], f"Operation: {operation} not supported" + assert operation in [ + "RESUME", + "SUSPEND", + ], f"Operation: {operation} not supported" if feature_view.status == FeatureViewStatus.DRAFT or feature_view.version is None: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.NOT_FOUND, @@ -1397,17 +1381,76 @@ def _update_feature_view_status(self, feature_view: FeatureView, operation: str) logger.info(f"Successfully {operation} FeatureView {feature_view.name}/{feature_view.version}.") return self.get_feature_view(feature_view.name, feature_view.version) - def _find_feature_views( + def _optimized_find_feature_views( self, entity_name: SqlIdentifier, feature_view_name: Optional[SqlIdentifier] - ) -> List[FeatureView]: + ) -> DataFrame: if not self._validate_entity_exists(entity_name): - return [] + return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA) + # TODO: this can be optimized further by directly getting all possible FVs and filter by tag + # it's easier to rewrite the code once we can remove the tag_reference path all_fvs = self._get_fv_backend_representations(object_name=None) fv_maps = {SqlIdentifier(r["name"], case_sensitive=True): r for r in all_fvs} if len(fv_maps.keys()) == 0: - return [] + return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA) + + filter_clause = f"WHERE OBJECT_NAME LIKE '{feature_view_name.resolved()}%'" if feature_view_name else "" + try: + res = self._session.sql( + f""" + SELECT + OBJECT_NAME + FROM TABLE( + {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL( + TAG_NAME => '{self._get_fully_qualified_name(self._get_entity_name(entity_name))}' + ) + ) {filter_clause}""" + ).collect(statement_params=self._telemetry_stmp) + except Exception as e: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INTERNAL_SNOWPARK_ERROR, + original_exception=RuntimeError(f"Failed to find feature views' by entity {entity_name}: {e}"), + ) from e + + output_values: List[List[Any]] = [] + for r in res: + row = fv_maps[SqlIdentifier(r["OBJECT_NAME"], case_sensitive=True)] + self._extract_feature_view_info(row, output_values) + + return self._session.create_dataframe(output_values, schema=_LIST_FEATURE_VIEW_SCHEMA) + + def _extract_feature_view_info(self, row: Row, output_values: List[List[Any]]) -> None: + name, version = row["name"].split(_FEATURE_VIEW_NAME_DELIMITER) + m = re.match(_DT_OR_VIEW_QUERY_PATTERN, row["text"]) + if m is None: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INTERNAL_SNOWML_ERROR, + original_exception=RuntimeError(f"Failed to parse query text for FeatureView {name}/{version}: {row}."), + ) + + fv_metadata = _FeatureViewMetadata.from_json(m.group("fv_metadata")) + + values: List[Any] = [] + values.append(name) + values.append(version) + values.append(row["database_name"]) + values.append(row["schema_name"]) + values.append(row["created_on"]) + values.append(row["owner"]) + values.append(row["comment"]) + values.append(fv_metadata.entities) + output_values.append(values) + + def _find_feature_views(self, entity_name: SqlIdentifier, feature_view_name: Optional[SqlIdentifier]) -> DataFrame: + if not self._validate_entity_exists(entity_name): + return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA) + + all_fvs = self._get_fv_backend_representations(object_name=None) + fv_maps = 
{SqlIdentifier(r["name"], case_sensitive=True): r for r in all_fvs} + + if len(fv_maps.keys()) == 0: + return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA) # NOTE: querying INFORMATION_SCHEMA for Entity lineage can be expensive depending on how many active # FeatureViews there are. If this ever become an issue, consider exploring improvements. @@ -1424,7 +1467,7 @@ def _find_feature_views( ) ) WHERE LEVEL = 'TABLE' - AND TAG_NAME = '{_FEATURE_VIEW_ENTITY_TAG}' + AND TAG_NAME = '{_FEATURE_VIEW_METADATA_TAG}' """ for fv_name in fv_maps.keys() ] @@ -1436,21 +1479,22 @@ def _find_feature_views( original_exception=RuntimeError(f"Failed to retrieve feature views' information: {e}"), ) from e - entities = self.list_entities().collect() - outputs = [] + output_values: List[List[Any]] = [] for r in results: - if entity_name == SqlIdentifier(r["TAG_VALUE"], case_sensitive=True): - fv_name, _ = r["OBJECT_NAME"].split(_FEATURE_VIEW_NAME_DELIMITER) - fv_name = SqlIdentifier(fv_name, case_sensitive=True) - obj_name = SqlIdentifier(r["OBJECT_NAME"], case_sensitive=True) - if feature_view_name is not None: - if fv_name == feature_view_name: - outputs.append(self._compose_feature_view(fv_maps[obj_name], entities)) + fv_metadata = _FeatureViewMetadata.from_json(r["TAG_VALUE"]) + for retrieved_entity in fv_metadata.entities: + if entity_name == SqlIdentifier(retrieved_entity, case_sensitive=True): + fv_name, _ = r["OBJECT_NAME"].split(_FEATURE_VIEW_NAME_DELIMITER) + fv_name = SqlIdentifier(fv_name, case_sensitive=True) + obj_name = SqlIdentifier(r["OBJECT_NAME"], case_sensitive=True) + if feature_view_name is not None: + if fv_name == feature_view_name: + self._extract_feature_view_info(fv_maps[obj_name], output_values) + else: + continue else: - continue - else: - outputs.append(self._compose_feature_view(fv_maps[obj_name], entities)) - return outputs + self._extract_feature_view_info(fv_maps[obj_name], output_values) + return self._session.create_dataframe(output_values, schema=_LIST_FEATURE_VIEW_SCHEMA) def _compose_feature_view(self, row: Row, entity_list: List[Row]) -> FeatureView: def find_and_compose_entity(name: str) -> Entity: @@ -1459,7 +1503,7 @@ def find_and_compose_entity(name: str) -> Entity: if e["NAME"] == name: return Entity( name=SqlIdentifier(e["NAME"], case_sensitive=True).identifier(), - join_keys=e["JOIN_KEYS"].strip("[]").split(","), + join_keys=self._recompose_join_keys(e["JOIN_KEYS"]), desc=e["DESC"], ) raise RuntimeError(f"Cannot find entity {name} from retrieved entity list: {entity_list}") @@ -1477,9 +1521,9 @@ def find_and_compose_entity(name: str) -> Entity: query = m.group("query") df = self._session.sql(query) desc = m.group("comment") - entity_names = m.group("entities") - entities = [find_and_compose_entity(n) for n in entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] - ts_col = m.group("ts_col") + fv_metadata = _FeatureViewMetadata.from_json(m.group("fv_metadata")) + entities = [find_and_compose_entity(n) for n in fv_metadata.entities] + ts_col = fv_metadata.timestamp_col timestamp_col = ts_col if ts_col != _TIMESTAMP_COL_PLACEHOLDER else None fv = FeatureView._construct_feature_view( @@ -1506,9 +1550,9 @@ def find_and_compose_entity(name: str) -> Entity: query = m.group("query") df = self._session.sql(query) desc = m.group("comment") - entity_names = m.group("entities") - entities = [find_and_compose_entity(n) for n in entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] - ts_col = m.group("ts_col") + fv_metadata = 
_FeatureViewMetadata.from_json(m.group("fv_metadata")) + entities = [find_and_compose_entity(n) for n in fv_metadata.entities] + ts_col = fv_metadata.timestamp_col timestamp_col = ts_col if ts_col != _TIMESTAMP_COL_PLACEHOLDER else None fv = FeatureView._construct_feature_view( @@ -1542,7 +1586,10 @@ def _fetch_column_descs(self, obj_type: str, obj_name: SqlIdentifier) -> Dict[st return descs def _find_object( - self, object_type: str, object_name: Optional[SqlIdentifier], prefix_match: bool = False + self, + object_type: str, + object_name: Optional[SqlIdentifier], + prefix_match: bool = False, ) -> List[Row]: """Try to find an object by given type and name pattern. @@ -1569,7 +1616,7 @@ def _find_object( search_space, obj_domain = self._obj_search_spaces[object_type] all_rows = [] fs_tag_objects = [] - tag_free_object_types = ["TAGS", "SCHEMAS", "WAREHOUSES"] + tag_free_object_types = ["TAGS", "SCHEMAS", "WAREHOUSES", "DATASETS"] try: search_scope = f"IN {search_space}" if search_space is not None else "" all_rows = self._session.sql(f"SHOW {object_type} LIKE '{match_name}' {search_scope}").collect( @@ -1577,25 +1624,41 @@ def _find_object( ) # There could be none-FS objects under FS schema, thus filter on objects with FS special tag. if object_type not in tag_free_object_types and len(all_rows) > 0: - # Note: in TAG_REFERENCES() is case insensitive, - # use double quotes to make it case-sensitive. - queries = [ - f""" - SELECT OBJECT_NAME - FROM TABLE( - {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES( - '{self._get_fully_qualified_name(SqlIdentifier(row['name'], case_sensitive=True))}', - '{obj_domain}' + if self._use_optimized_tag_ref: + fs_obj_rows = self._session.sql( + f""" + SELECT + OBJECT_NAME + FROM TABLE( + {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL( + TAG_NAME => '{self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}' + ) ) - ) - WHERE TAG_NAME = '{_FEATURE_STORE_OBJECT_TAG}' - AND TAG_SCHEMA = '{self._config.schema.resolved()}' - """ - for row in all_rows - ] - fs_obj_rows = self._session.sql("\nUNION\n".join(queries)).collect( - statement_params=self._telemetry_stmp - ) + WHERE DOMAIN='{obj_domain}' + """ + ).collect(statement_params=self._telemetry_stmp) + else: + # TODO: remove this after tag_ref_internal rollout + # Note: in TAG_REFERENCES() is case insensitive, + # use double quotes to make it case-sensitive. 
+ queries = [ + f""" + SELECT OBJECT_NAME + FROM TABLE( + {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES( + '{self._get_fully_qualified_name(SqlIdentifier(row['name'], case_sensitive=True))}', + '{obj_domain}' + ) + ) + WHERE TAG_NAME = '{_FEATURE_STORE_OBJECT_TAG}' + AND TAG_SCHEMA = '{self._config.schema.resolved()}' + """ + for row in all_rows + ] + fs_obj_rows = self._session.sql("\nUNION\n".join(queries)).collect( + statement_params=self._telemetry_stmp + ) + fs_tag_objects = [row["OBJECT_NAME"] for row in fs_obj_rows] except Exception as e: raise snowml_exceptions.SnowflakeMLException( @@ -1641,3 +1704,66 @@ def _exclude_columns(self, df: DataFrame, exclude_columns: List[str]) -> DataFra ), ) return cast(DataFrame, df.drop(exclude_columns)) + + def _tag_ref_internal_enabled(self) -> bool: + try: + self._session.sql( + f""" + SELECT * FROM TABLE( + INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL( + TAG_NAME => '{_FEATURE_STORE_OBJECT_TAG}' + ) + ) LIMIT 1; + """ + ).collect() + return True + except Exception: + return False + + def _check_feature_store_object_versions(self) -> None: + versions = self._collapse_object_versions() + if len(versions) > 0 and pkg_version.parse(snowml_version.VERSION) < versions[0]: + warnings.warn( + "The current snowflake-ml-python version out of date, package upgrade recommended " + + f"(current={snowml_version.VERSION}, recommended>={str(versions[0])})", + stacklevel=2, + category=UserWarning, + ) + + def _collapse_object_versions(self) -> List[pkg_version.Version]: + if not self._use_optimized_tag_ref: + return [] + + query = f""" + SELECT + TAG_VALUE + FROM TABLE( + {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL( + TAG_NAME => '{self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}' + ) + ) + """ + try: + res = self._session.sql(query).collect(statement_params=self._telemetry_stmp) + except Exception: + # since this is a best effort user warning to upgrade pkg versions + # we are treating failures as benign error + return [] + versions = set() + compatibility_breakage_detected = False + for r in res: + info = _FeatureStoreObjInfo.from_json(r["TAG_VALUE"]) + if info.type == _FeatureStoreObjTypes.UNKNOWN: + compatibility_breakage_detected = True + versions.add(pkg_version.parse(info.pkg_version)) + + sorted_versions = sorted(versions, reverse=True) + if compatibility_breakage_detected: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.SNOWML_PACKAGE_OUTDATED, + original_exception=RuntimeError( + f"The current snowflake-ml-python version {snowml_version.VERSION} is out of date, " + + f"please upgrade to at least {sorted_versions[0]}." + ), + ) + return sorted_versions diff --git a/snowflake/ml/feature_store/feature_view.py b/snowflake/ml/feature_store/feature_view.py index 44f2618a..79635a64 100644 --- a/snowflake/ml/feature_store/feature_view.py +++ b/snowflake/ml/feature_store/feature_view.py @@ -1,8 +1,9 @@ from __future__ import annotations import json +import re from collections import OrderedDict -from dataclasses import dataclass +from dataclasses import asdict, dataclass from enum import Enum from typing import Dict, List, Optional @@ -28,19 +29,42 @@ _FEATURE_VIEW_NAME_DELIMITER = "$" _TIMESTAMP_COL_PLACEHOLDER = "FS_TIMESTAMP_COL_PLACEHOLDER_VAL" _FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE" +# Feature view version rule is aligned with dataset version rule in SQL. 
+_FEATURE_VIEW_VERSION_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-]*$") +_FEATURE_VIEW_VERSION_MAX_LENGTH = 128 -class FeatureViewVersion(SqlIdentifier): +@dataclass(frozen=True) +class _FeatureViewMetadata: + """Represent metadata tracked on top of FV backend object""" + + entities: List[str] + timestamp_col: str + + def to_json(self) -> str: + return json.dumps(asdict(self)) + + @classmethod + def from_json(cls, json_str: str) -> _FeatureViewMetadata: + state_dict = json.loads(json_str) + return cls(**state_dict) + + +class FeatureViewVersion(str): def __new__(cls, version: str) -> FeatureViewVersion: - if _FEATURE_VIEW_NAME_DELIMITER in version: + if not _FEATURE_VIEW_VERSION_RE.match(version) or len(version) > _FEATURE_VIEW_VERSION_MAX_LENGTH: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError(f"{_FEATURE_VIEW_NAME_DELIMITER} is not allowed in version: {version}."), + original_exception=ValueError( + f"`{version}` is not a valid feature view version. " + "It must start with letter or digit, and followed by letter, digit, '_', '-' or '.'. " + f"The length limit is {_FEATURE_VIEW_VERSION_MAX_LENGTH}." + ), ) - return super().__new__(cls, version) # type: ignore[return-value] + return super().__new__(cls, version) def __init__(self, version: str) -> None: - super().__init__(version) + super().__init__() class FeatureViewStatus(Enum): @@ -285,6 +309,11 @@ def refresh_mode_reason(self) -> Optional[str]: def owner(self) -> Optional[str]: return self._owner + def _metadata(self) -> _FeatureViewMetadata: + entity_names = [e.name.identifier() for e in self.entities] + ts_col = self.timestamp_col.identifier() if self.timestamp_col is not None else _TIMESTAMP_COL_PLACEHOLDER + return _FeatureViewMetadata(entity_names, ts_col) + def _get_query(self) -> str: if len(self._feature_df.queries["queries"]) != 1: raise ValueError( @@ -436,8 +465,8 @@ def _construct_feature_view( status: FeatureViewStatus, feature_descs: Dict[str, str], refresh_freq: Optional[str], - database: Optional[str], - schema: Optional[str], + database: str, + schema: str, warehouse: Optional[str], refresh_mode: Optional[str], refresh_mode_reason: Optional[str], diff --git a/snowflake/ml/fileset/BUILD.bazel b/snowflake/ml/fileset/BUILD.bazel index f9e8c70f..f8ddbb5e 100644 --- a/snowflake/ml/fileset/BUILD.bazel +++ b/snowflake/ml/fileset/BUILD.bazel @@ -67,6 +67,7 @@ py_library( deps = [ ":embedded_stage_fs", ":sfcfs", + "//snowflake/ml/_internal/utils:snowflake_env", # FIXME(dhung) temporary workaround for SnowURL bug in GS 8.17 ], ) diff --git a/snowflake/ml/fileset/embedded_stage_fs.py b/snowflake/ml/fileset/embedded_stage_fs.py index 3fe52370..0c03fb89 100644 --- a/snowflake/ml/fileset/embedded_stage_fs.py +++ b/snowflake/ml/fileset/embedded_stage_fs.py @@ -1,11 +1,22 @@ -from typing import Any, Optional +import re +from collections import defaultdict +from typing import Any, List, Optional, Tuple from snowflake import snowpark from snowflake.connector import connection +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.exceptions import ( + error_codes, + exceptions as snowml_exceptions, + fileset_errors, +) from snowflake.ml._internal.utils import identifier +from snowflake.snowpark import exceptions as snowpark_exceptions from . 
import stage_fs +_SNOWURL_PATH_RE = re.compile(r"versions/(?P[^/]+)(?:/+(?P.*))?") + class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem): def __init__( @@ -56,3 +67,80 @@ def _stage_path_to_relative_path(self, stage_path: str) -> str: A string of the relative stage path. """ return stage_path + + def _fetch_presigned_urls( + self, files: List[str], url_lifetime: float = stage_fs._PRESIGNED_URL_LIFETIME_SEC + ) -> List[Tuple[str, str]]: + """Fetch presigned urls for the given files.""" + # SnowURL requires full snow:////versions/ as the stage path arg to get_presigned_Url + versions_dict = defaultdict(list) + for file in files: + match = _SNOWURL_PATH_RE.fullmatch(file) + assert match is not None and match.group("filepath") is not None + versions_dict[match.group("version")].append(match.group("filepath")) + presigned_urls: List[Tuple[str, str]] = [] + try: + for version, version_files in versions_dict.items(): + for file in version_files: + stage_loc = f"{self.stage_name}/versions/{version}" + presigned_urls.extend( + self._session.sql( + f"select '{version}/{file}' as name," + f" get_presigned_url('{stage_loc}', '{file}', {url_lifetime}) as url" + ).collect( + statement_params=telemetry.get_function_usage_statement_params( + project=stage_fs._PROJECT, + api_calls=[snowpark.DataFrame.collect], + ), + ) + ) + except snowpark_exceptions.SnowparkClientException as e: + if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST) or e.message.startswith( + fileset_errors.ERRNO_STAGE_NOT_EXIST + ): + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.SNOWML_NOT_FOUND, + original_exception=fileset_errors.StageNotFoundError( + f"Stage {self.stage_name} does not exist or is not authorized." + ), + ) + else: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.INTERNAL_SNOWML_ERROR, + original_exception=fileset_errors.FileSetError(str(e)), + ) + return presigned_urls + + @classmethod + def _parent(cls, path: str) -> str: + """Get parent of specified path up to minimally valid root path. 
+ + For SnowURL, the minimum valid path is snow:////versions/ + + Args: + path: File or directory path + + Returns: + Parent path + + Examples: + ---- + >>> fs._parent("snow://dataset/my_ds/versions/my_version/file.ext") + "snow://dataset/my_ds/versions/my_version/" + >>> fs._parent("snow://dataset/my_ds/versions/my_version/subdir/file.ext") + "snow://dataset/my_ds/versions/my_version/subdir/" + >>> fs._parent("snow://dataset/my_ds/versions/my_version/") + "snow://dataset/my_ds/versions/my_version/" + >>> fs._parent("snow://dataset/my_ds/versions/my_version") + "snow://dataset/my_ds/versions/my_version" + """ + path_match = _SNOWURL_PATH_RE.fullmatch(path) + if not path_match: + return super()._parent(path) # type: ignore[no-any-return] + filepath: str = path_match.group("filepath") or "" + root: str = path[: path_match.start("filepath")] if filepath else path + if "/" in filepath: + parent = filepath.rsplit("/", 1)[0] + return root + parent + else: + return root diff --git a/snowflake/ml/fileset/embedded_stage_fs_test.py b/snowflake/ml/fileset/embedded_stage_fs_test.py index faa82e51..cdc45916 100644 --- a/snowflake/ml/fileset/embedded_stage_fs_test.py +++ b/snowflake/ml/fileset/embedded_stage_fs_test.py @@ -90,7 +90,7 @@ def _mock_collect_res(self, prefix: str) -> mock_data_frame.MockDataFrame: def _add_mock_test_case(self, prefix: str) -> None: self.session.add_mock_sql( - query=f"LIST snow://{self.domain}/{self.name}/{prefix}", + query=f"LIST 'snow://{self.domain}/{self.name}/{prefix}'", result=self._mock_collect_res(prefix), ) diff --git a/snowflake/ml/fileset/sfcfs.py b/snowflake/ml/fileset/sfcfs.py index c3c592ec..d242a6d3 100644 --- a/snowflake/ml/fileset/sfcfs.py +++ b/snowflake/ml/fileset/sfcfs.py @@ -185,7 +185,6 @@ def _get_stage_fs(self, sf_file_path: _SFFilePath) -> stage_fs.SFStageFileSystem func_params_to_log=["detail"], conn_attr_name="_conn", ) - @snowpark._internal.utils.private_preview(version="0.2.0") def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[List[str], List[Dict[str, Any]]]: """Override fsspec `ls` method. List single "directory" with or without details. @@ -216,7 +215,6 @@ def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[List[str], project=_PROJECT, conn_attr_name="_conn", ) - @snowpark._internal.utils.private_preview(version="0.2.0") def optimize_read(self, files: Optional[List[str]] = None) -> None: """Prefetch and cache the presigned urls for all the given files to speed up the file opening. @@ -242,7 +240,6 @@ def optimize_read(self, files: Optional[List[str]] = None) -> None: project=_PROJECT, conn_attr_name="_conn", ) - @snowpark._internal.utils.private_preview(version="0.2.0") def _open(self, path: str, **kwargs: Any) -> fsspec.spec.AbstractBufferedFile: """Override fsspec `_open` method. Open a file for reading in 'rb' mode. @@ -268,7 +265,6 @@ def _open(self, path: str, **kwargs: Any) -> fsspec.spec.AbstractBufferedFile: project=_PROJECT, conn_attr_name="_conn", ) - @snowpark._internal.utils.private_preview(version="0.2.0") def info(self, path: str, **kwargs: Any) -> Dict[str, Any]: """Override fsspec `info` method. 
Give details of entry at path.""" file_path = self._parse_file_path(path) diff --git a/snowflake/ml/fileset/sfcfs_test.py b/snowflake/ml/fileset/sfcfs_test.py index 66157db6..2e0b218e 100644 --- a/snowflake/ml/fileset/sfcfs_test.py +++ b/snowflake/ml/fileset/sfcfs_test.py @@ -1,5 +1,4 @@ import pickle -from typing import List import fsspec from absl.testing import absltest, parameterized @@ -43,7 +42,7 @@ def test_init_sf_file_system(self) -> None: "nytrain/", ), ) - def test_parse_sfc_file_path(self, *test_case: List[str]) -> None: + def test_parse_sfc_file_path(self, *test_case: str) -> None: """Test if the FS could parse the input stage location correctly""" with absltest.mock.patch( "snowflake.ml.fileset.stage_fs.SFStageFileSystem", autospec=True diff --git a/snowflake/ml/fileset/snowfs.py b/snowflake/ml/fileset/snowfs.py index 8158aaef..dacab563 100644 --- a/snowflake/ml/fileset/snowfs.py +++ b/snowflake/ml/fileset/snowfs.py @@ -1,9 +1,10 @@ import collections import logging import re -from typing import Any, Optional +from typing import Any, Dict, Optional import fsspec +import packaging.version as pkg_version from snowflake import snowpark from snowflake.connector import connection @@ -11,7 +12,7 @@ error_codes, exceptions as snowml_exceptions, ) -from snowflake.ml._internal.utils import identifier +from snowflake.ml._internal.utils import identifier, snowflake_env from snowflake.ml.fileset import embedded_stage_fs, sfcfs PROTOCOL_NAME = "snow" @@ -24,9 +25,13 @@ f"({PROTOCOL_NAME}://)?" r"(?\w+)/" rf"(?P(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/" - r"(?Pversions(?:/(?:(?P[^/]+)(?:/(?P.*))?)?)?)" + r"(?Pversions/(?:(?P[^/]+)(?:/(?P.*))?)?)" ) +# FIXME(dhung): Temporary fix for bug in GS version 8.17 +_BUG_VERSION_MIN = pkg_version.Version("8.17") # Inclusive minimum version with bugged behavior +_BUG_VERSION_MAX = pkg_version.Version("8.18") # Exclusive maximum version with bugged behavior + class SnowFileSystem(sfcfs.SFFileSystem): """A filesystem that allows user to access Snowflake embedded stage files with valid Snowflake locations. 
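As an aside on the `snow://` URL handling above: parsing relies on a regex with named groups for the domain, entity name, version, and relative path (the pattern in the patch is composed from internal identifier regexes, so it is not reproduced here). The sketch below is a simplified, illustrative pattern with assumed group names, not the exact `_SNOWURL_PATTERN`, just to show how such a URL decomposes.

```python
import re

# Illustrative only: a simplified stand-in for the snow:// URL pattern.
# The real _SNOWURL_PATTERN builds on internal identifier regexes and is stricter.
_DEMO_SNOWURL = re.compile(
    r"^(?:snow://)?"
    r"(?P<domain>\w+)/"             # e.g. "dataset"
    r"(?P<name>[^/]+)/"             # optionally qualified entity name
    r"versions/(?P<version>[^/]+)"  # version label
    r"(?:/(?P<relpath>.*))?$"       # optional path inside the version
)


def parse_snow_url(url: str) -> dict:
    """Split a snow:// URL into domain, name, version and relative path."""
    m = _DEMO_SNOWURL.fullmatch(url)
    if m is None:
        raise ValueError(f"not a recognizable snow:// URL: {url}")
    return {k: (v or "") for k, v in m.groupdict().items()}


print(parse_snow_url("snow://dataset/my_db.my_schema.my_ds/versions/v1/data/part-0.parquet"))
# {'domain': 'dataset', 'name': 'my_db.my_schema.my_ds', 'version': 'v1', 'relpath': 'data/part-0.parquet'}
```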
@@ -39,8 +44,8 @@ class SnowFileSystem(sfcfs.SFFileSystem): """ protocol = PROTOCOL_NAME + _IS_BUGGED_VERSION = None - @snowpark._internal.utils.private_preview(version="1.5.0") def __init__( self, sf_connection: Optional[connection.SnowflakeConnection] = None, @@ -49,6 +54,21 @@ def __init__( ) -> None: super().__init__(sf_connection=sf_connection, snowpark_session=snowpark_session, **kwargs) + # FIXME(dhung): Temporary fix for bug in GS version 8.17 + if SnowFileSystem._IS_BUGGED_VERSION is None: + try: + sf_version = snowflake_env.get_current_snowflake_version(self._session) + SnowFileSystem._IS_BUGGED_VERSION = _BUG_VERSION_MIN <= sf_version < _BUG_VERSION_MAX + except Exception: + SnowFileSystem._IS_BUGGED_VERSION = False + + def info(self, path: str, **kwargs: Any) -> Dict[str, Any]: + # FIXME(dhung): Temporary fix for bug in GS version 8.17 + res: Dict[str, Any] = super().info(path, **kwargs) + if res.get("type") == "directory" and not res["name"].endswith("/"): + res["name"] += "/" + return res + def _get_stage_fs( self, sf_file_path: _SFFileEntityPath # type: ignore[override] ) -> embedded_stage_fs.SFEmbeddedStageFileSystem: @@ -79,7 +99,13 @@ def _stage_path_to_absolute_path(self, stage_fs: embedded_stage_fs.SFEmbeddedSta protocol = f"{PROTOCOL_NAME}://" if stage_name.startswith(protocol): stage_name = stage_name[len(protocol) :] - return stage_name + "/" + path + abs_path = stage_name + "/" + path + # FIXME(dhung): Temporary fix for bug in GS version 8.17 + if self._IS_BUGGED_VERSION: + match = _SNOWURL_PATTERN.fullmatch(abs_path) + assert match is not None + abs_path = abs_path.replace(match.group("relpath"), match.group("relpath").lstrip("/")) + return abs_path @classmethod def _parse_file_path(cls, path: str) -> _SFFileEntityPath: # type: ignore[override] @@ -117,6 +143,9 @@ def _parse_file_path(cls, path: str) -> _SFFileEntityPath: # type: ignore[overr version = snowurl_match.group("version") relative_path = snowurl_match.group("relpath") or "" logging.debug(f"Parsed snow URL: {snowurl_match.groups()}") + # FIXME(dhung): Temporary fix for bug in GS version 8.17 + if cls._IS_BUGGED_VERSION: + filepath = filepath.replace(f"{version}/", f"{version}//") return _SFFileEntityPath( domain=domain, name=name, version=version, relative_path=relative_path, filepath=filepath ) diff --git a/snowflake/ml/fileset/stage_fs.py b/snowflake/ml/fileset/stage_fs.py index 1eb23d4d..f71a7903 100644 --- a/snowflake/ml/fileset/stage_fs.py +++ b/snowflake/ml/fileset/stage_fs.py @@ -144,7 +144,6 @@ def stage_name(self) -> str: project=_PROJECT, func_params_to_log=["detail"], ) - @snowpark._internal.utils.private_preview(version="0.2.0") def ls(self, path: str, detail: bool = False) -> Union[List[str], List[Dict[str, Any]]]: """Override fsspec `ls` method. List single "directory" with or without details. 
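The `_IS_BUGGED_VERSION` handling above follows a check-once, cache-on-the-class pattern: probe the server version the first time a filesystem is constructed, remember the answer, and branch on it in the affected code paths. The sketch below illustrates only that pattern; `get_server_version` is a stand-in for the internal `snowflake_env.get_current_snowflake_version` helper, the class and metadata are invented for the example, and the real patch applies the trailing-slash fix in `info()` unconditionally.

```python
from typing import Optional

import packaging.version as pkg_version

# Affected server window taken from the patch: 8.17 inclusive to 8.18 exclusive.
_BUG_MIN = pkg_version.Version("8.17")
_BUG_MAX = pkg_version.Version("8.18")


def get_server_version() -> str:
    """Stand-in for the internal helper that asks the server for its version."""
    return "8.17.2"  # hypothetical value for the sketch


class VersionGatedFS:
    # Cached once per process; None means "not probed yet".
    _IS_BUGGED_VERSION: Optional[bool] = None

    def __init__(self) -> None:
        if VersionGatedFS._IS_BUGGED_VERSION is None:
            try:
                current = pkg_version.Version(get_server_version())
                VersionGatedFS._IS_BUGGED_VERSION = _BUG_MIN <= current < _BUG_MAX
            except Exception:
                # If the probe fails, assume the fixed behavior rather than erroring out.
                VersionGatedFS._IS_BUGGED_VERSION = False

    def info(self, path: str) -> dict:
        res = {"name": path.rstrip("/"), "type": "directory"}  # placeholder metadata
        # Directory entries need a trailing slash for downstream fsspec path handling;
        # gating on the cached flag here is purely to show the branching pattern.
        if self._IS_BUGGED_VERSION and res["type"] == "directory" and not res["name"].endswith("/"):
            res["name"] += "/"
        return res


print(VersionGatedFS().info("snow://dataset/my_ds/versions/v1/subdir"))
```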
@@ -168,7 +167,7 @@ def ls(self, path: str, detail: bool = False) -> Union[List[str], List[Dict[str, try: loc = self.stage_name path = path.lstrip("/") - objects = self._session.sql(f"LIST {loc}/{path}").collect() + objects = self._session.sql(f"LIST '{loc}/{path}'").collect() except snowpark_exceptions.SnowparkClientException as e: if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST): raise snowml_exceptions.SnowflakeMLException( @@ -191,7 +190,6 @@ def ls(self, path: str, detail: bool = False) -> Union[List[str], List[Dict[str, @telemetry.send_api_usage_telemetry( project=_PROJECT, ) - @snowpark._internal.utils.private_preview(version="0.2.0") def optimize_read(self, files: Optional[List[str]] = None) -> None: """Prefetch and cache the presigned urls for all the given files to speed up the read performance. @@ -218,7 +216,6 @@ def optimize_read(self, files: Optional[List[str]] = None) -> None: @telemetry.send_api_usage_telemetry( project=_PROJECT, ) - @snowpark._internal.utils.private_preview(version="0.2.0") def _open(self, path: str, mode: str = "rb", **kwargs: Any) -> fsspec.spec.AbstractBufferedFile: """Override fsspec `_open` method. Open a file for reading. diff --git a/snowflake/ml/fileset/stage_fs_test.py b/snowflake/ml/fileset/stage_fs_test.py index 4f9a4655..72c4ce03 100644 --- a/snowflake/ml/fileset/stage_fs_test.py +++ b/snowflake/ml/fileset/stage_fs_test.py @@ -96,7 +96,7 @@ def _mock_collect_res(self, prefix: str) -> mock_data_frame.MockDataFrame: def _add_mock_test_case(self, prefix: str) -> None: self.session.add_mock_sql( - query=f"LIST @{self.db}.{self.schema}.{self.stage}/{prefix}", + query=f"LIST '@{self.db}.{self.schema}.{self.stage}/{prefix}'", result=self._mock_collect_res(prefix), ) diff --git a/snowflake/ml/model/__init__.py b/snowflake/ml/model/__init__.py index bcebb67d..d14485f5 100644 --- a/snowflake/ml/model/__init__.py +++ b/snowflake/ml/model/__init__.py @@ -1,6 +1,6 @@ from snowflake.ml.model._client.model.model_impl import Model -from snowflake.ml.model._client.model.model_version_impl import ModelVersion +from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel from snowflake.ml.model.models.llm import LLM, LLMOptions -__all__ = ["Model", "ModelVersion", "HuggingFacePipelineModel", "LLM", "LLMOptions"] +__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "LLM", "LLMOptions"] diff --git a/snowflake/ml/model/_api.py b/snowflake/ml/model/_api.py index 1ff0edc4..f9f316b4 100644 --- a/snowflake/ml/model/_api.py +++ b/snowflake/ml/model/_api.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast, overload import pandas as pd +from typing_extensions import deprecated from snowflake.ml._internal.exceptions import ( error_codes, @@ -23,6 +24,7 @@ from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session, functions as F +@deprecated("Only used by PrPr model registry.") @overload def save_model( *, @@ -61,6 +63,7 @@ def save_model( ... +@deprecated("Only used by PrPr model registry.") @overload def save_model( *, @@ -101,6 +104,7 @@ def save_model( ... +@deprecated("Only used by PrPr model registry.") @overload def save_model( *, @@ -142,6 +146,7 @@ def save_model( ... 
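The `_api.py` changes below mark the legacy PrPr entry points with `typing_extensions.deprecated` (PEP 702), which both annotates the symbols for type checkers and emits a `DeprecationWarning` when the decorated callable actually runs. A minimal sketch of that usage follows; the `save_model` signatures here are invented placeholders, and only the decorator usage mirrors the patch.

```python
import warnings
from typing import overload

from typing_extensions import deprecated


@deprecated("Only used by PrPr model registry.")
@overload
def save_model(*, name: str, model_dir_path: str) -> str: ...


@deprecated("Only used by PrPr model registry.")
@overload
def save_model(*, name: str, model: object) -> str: ...


@deprecated("Only used by PrPr model registry.")
def save_model(*, name: str, model: object = None, model_dir_path: str = "") -> str:
    """Placeholder body; the real overloads and implementation live in snowflake/ml/model/_api.py."""
    return f"saved {name}"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    save_model(name="demo", model=object())

# typing_extensions.deprecated emits a DeprecationWarning when the wrapped callable runs.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```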
+@deprecated("Only used by PrPr model registry.") def save_model( *, name: str, @@ -208,6 +213,7 @@ def save_model( return m +@deprecated("Only used by PrPr model registry.") @overload def load_model(*, session: Session, stage_path: str) -> model_composer.ModelComposer: """Load the model into memory from a zip file in the stage. @@ -219,6 +225,7 @@ def load_model(*, session: Session, stage_path: str) -> model_composer.ModelComp ... +@deprecated("Only used by PrPr model registry.") @overload def load_model(*, session: Session, stage_path: str, meta_only: Literal[False]) -> model_composer.ModelComposer: """Load the model into memory from a zip file in the stage. @@ -231,6 +238,7 @@ def load_model(*, session: Session, stage_path: str, meta_only: Literal[False]) ... +@deprecated("Only used by PrPr model registry.") @overload def load_model(*, session: Session, stage_path: str, meta_only: Literal[True]) -> model_composer.ModelComposer: """Load the model into memory from a zip file in the stage with metadata only. @@ -243,6 +251,7 @@ def load_model(*, session: Session, stage_path: str, meta_only: Literal[True]) - ... +@deprecated("Only used by PrPr model registry.") def load_model( *, session: Session, @@ -261,10 +270,11 @@ def load_model( Loaded model. """ m = model_composer.ModelComposer(session=session, stage_path=stage_path) - m.load(meta_only=meta_only) + m.legacy_load(meta_only=meta_only) return m +@deprecated("Only used by PrPr model registry.") @overload def deploy( session: Session, @@ -290,6 +300,7 @@ def deploy( ... +@deprecated("Only used by PrPr model registry.") @overload def deploy( session: Session, @@ -319,6 +330,7 @@ def deploy( ... +@deprecated("Only used by PrPr model registry.") def deploy( session: Session, *, @@ -423,6 +435,7 @@ def deploy( return info +@deprecated("Only used by PrPr model registry.") @overload def predict( session: Session, @@ -443,6 +456,7 @@ def predict( ... +@deprecated("Only used by PrPr model registry.") @overload def predict( session: Session, @@ -462,6 +476,7 @@ def predict( ... +@deprecated("Only used by PrPr model registry.") def predict( session: Session, *, diff --git a/snowflake/ml/model/_client/model/BUILD.bazel b/snowflake/ml/model/_client/model/BUILD.bazel index ae70596e..db4b23f8 100644 --- a/snowflake/ml/model/_client/model/BUILD.bazel +++ b/snowflake/ml/model/_client/model/BUILD.bazel @@ -31,6 +31,7 @@ py_library( deps = [ "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_client/ops:model_ops", "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", ], @@ -43,6 +44,7 @@ py_test( ":model_version_impl", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:model_signature", + "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_client/ops:metadata_ops", "//snowflake/ml/model/_client/ops:model_ops", "//snowflake/ml/test_utils:mock_data_frame", diff --git a/snowflake/ml/model/_client/model/model_impl.py b/snowflake/ml/model/_client/model/model_impl.py index d70629f9..b34c2772 100644 --- a/snowflake/ml/model/_client/model/model_impl.py +++ b/snowflake/ml/model/_client/model/model_impl.py @@ -350,3 +350,30 @@ def unset_tag(self, tag_name: str) -> None: tag_name=tag_name_id, statement_params=statement_params, ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def rename(self, model_name: str) -> None: + """Rename a model. 
Can be used to move a model when a fully qualified name is provided. + + Args: + model_name: The new model name. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + db, schema, model, _ = identifier.parse_schema_level_object_identifier(model_name) + new_model_db = sql_identifier.SqlIdentifier(db) if db else None + new_model_schema = sql_identifier.SqlIdentifier(schema) if schema else None + new_model_id = sql_identifier.SqlIdentifier(model) + self._model_ops.rename( + model_name=self._model_name, + new_model_db=new_model_db, + new_model_schema=new_model_schema, + new_model_name=new_model_id, + statement_params=statement_params, + ) + self._model_name = new_model_id diff --git a/snowflake/ml/model/_client/model/model_impl_test.py b/snowflake/ml/model/_client/model/model_impl_test.py index 1df9658f..0ca13129 100644 --- a/snowflake/ml/model/_client/model/model_impl_test.py +++ b/snowflake/ml/model/_client/model/model_impl_test.py @@ -318,6 +318,28 @@ def test_unset_tag_3(self) -> None: statement_params=mock.ANY, ) + def test_rename(self) -> None: + with mock.patch.object(self.m_model._model_ops, "rename") as mock_rename: + self.m_model.rename(model_name="MODEL2") + mock_rename.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=None, + new_model_schema=None, + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), + statement_params=mock.ANY, + ) + + def test_rename_fully_qualified_name(self) -> None: + with mock.patch.object(self.m_model._model_ops, "rename") as mock_rename: + self.m_model.rename(model_name='TEMP."test".MODEL2') + mock_rename.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=sql_identifier.SqlIdentifier("TEMP"), + new_model_schema=sql_identifier.SqlIdentifier("test", case_sensitive=True), + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), + statement_params=mock.ANY, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index e5e5277f..78b94f12 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -1,17 +1,29 @@ +import enum +import pathlib +import tempfile +import warnings from typing import Any, Callable, Dict, List, Optional, Union import pandas as pd from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import type_hints as model_types from snowflake.ml.model._client.ops import metadata_ops, model_ops +from snowflake.ml.model._model_composer import model_composer from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.ml.model._packager.model_handlers import snowmlmodel from snowflake.snowpark import dataframe _TELEMETRY_PROJECT = "MLOps" _TELEMETRY_SUBPROJECT = "ModelManagement" +class ExportMode(enum.Enum): + MODEL = "model" + FULL = "full" + + class ModelVersion: """Model Version Object representing a specific version of the model that could be run.""" @@ -240,6 +252,7 @@ def run( X: Union[pd.DataFrame, dataframe.DataFrame], *, function_name: Optional[str] = None, + partition_column: Optional[str] = None, strict_input_validation: bool = False, ) -> Union[pd.DataFrame, dataframe.DataFrame]: """Invoke a method in a model version object. 
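For orientation, the sketch below strings together the new public surface touched in this part of the patch: `Model.rename` (including a fully qualified name to move a model across database or schema), `ModelVersion.run` with the new `partition_column` argument, and `ModelVersion.export`. The registry lookup, the model and version names, the `predict_table` method, the partition column, and the export path are all assumed for illustration and are not taken from the patch.

```python
# Illustrative usage only: names and setup are assumed; the rename/run/export calls
# are the APIs introduced in this patch.
from snowflake.ml.model import ExportMode
from snowflake.ml.registry import Registry
from snowflake.snowpark import DataFrame, Session


def demo(session: Session, input_df: DataFrame) -> DataFrame:
    reg = Registry(session=session)
    model = reg.get_model("MY_MODEL")  # hypothetical existing model

    # Rename in place, or move across database/schema by passing a fully qualified name.
    model.rename("MY_MODEL_V2")
    model.rename('OTHER_DB."other_schema".MY_MODEL_V2')

    mv = model.version("v1")

    # Table-function methods can now be partitioned by a column of the input frame.
    result = mv.run(input_df, function_name="predict_table", partition_column="CUSTOMER_ID")

    # Export the model files locally; ExportMode.FULL additionally includes the files
    # needed to run the model in a warehouse.
    mv.export("/tmp/my_model_export", export_mode=ExportMode.MODEL)
    return result
```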
@@ -248,12 +261,14 @@ def run( X: The input data, which could be a pandas DataFrame or Snowpark DataFrame. function_name: The function name to run. It is the name used to call a function in SQL. Defaults to None. It can only be None if there is only 1 method. + partition_column: The partition column name to partition by. strict_input_validation: Enable stricter validation for the input data. This will result value range based type validation to make sure your input data won't overflow when providing to the model. Raises: ValueError: When no method with the corresponding name is available. ValueError: When there are more than 1 target methods available in the model but no function name specified. + ValueError: When the partition column is not a valid Snowflake identifier. Returns: The prediction data. It would be the same type dataframe as your input. @@ -263,6 +278,10 @@ def run( subproject=_TELEMETRY_SUBPROJECT, ) + if partition_column is not None: + # Partition column must be a valid identifier + partition_column = sql_identifier.SqlIdentifier(partition_column) + functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions if function_name: req_method_name = sql_identifier.SqlIdentifier(function_name).identifier() @@ -287,10 +306,126 @@ def run( target_function_info = functions[0] return self._model_ops.invoke_method( method_name=sql_identifier.SqlIdentifier(target_function_info["name"]), + method_function_type=target_function_info["target_method_function_type"], signature=target_function_info["signature"], X=X, model_name=self._model_name, version_name=self._version_name, strict_input_validation=strict_input_validation, + partition_column=partition_column, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"] + ) + def export(self, target_path: str, *, export_mode: ExportMode = ExportMode.MODEL) -> None: + """Export model files to a local directory. + + Args: + target_path: Path to a local directory to export files to. A directory will be created if does not exist. + export_mode: The mode to export the model. Defaults to ExportMode.MODEL. + ExportMode.MODEL: All model files including environment to load the model and model weights. + ExportMode.FULL: Additional files to run the model in Warehouse, besides all files in MODEL mode, + + Raises: + ValueError: Raised when the target path is a file or an non-empty folder. + """ + target_local_path = pathlib.Path(target_path) + if target_local_path.is_file() or any(target_local_path.iterdir()): + raise ValueError(f"Target path {target_local_path} is a file or an non-empty folder.") + + target_local_path.mkdir(parents=False, exist_ok=True) + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + self._model_ops.download_files( + model_name=self._model_name, + version_name=self._version_name, + target_path=target_local_path, + mode=export_mode.value, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["force", "options"] + ) + def load( + self, + *, + force: bool = False, + options: Optional[model_types.ModelLoadOption] = None, + ) -> model_types.SupportedModelType: + """Load the underlying original Python object back from a model. 
+ This operation requires to have the exact the same environment as the one when logging the model, otherwise, + the model might be not functional or some other problems might occur. + + Args: + force: Bypass the best-effort environment validation. Defaults to False. + options: Options to specify when loading the model, check `snowflake.ml.model.type_hints` for available + options. Defaults to None. + + Raises: + ValueError: Raised when the best-effort environment validation fails. + + Returns: + The original Python object loaded from the model object. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + if not force: + with tempfile.TemporaryDirectory() as tmp_workspace_for_validation: + ws_path_for_validation = pathlib.Path(tmp_workspace_for_validation) + self._model_ops.download_files( + model_name=self._model_name, + version_name=self._version_name, + target_path=ws_path_for_validation, + mode="minimal", + statement_params=statement_params, + ) + pk_for_validation = model_composer.ModelComposer.load( + ws_path_for_validation, meta_only=True, options=options + ) + assert pk_for_validation.meta, ( + "Unable to load model metadata for validation. " + f"model_name={self._model_name}, version_name={self._version_name}" + ) + + validation_errors = pk_for_validation.meta.env.validate_with_local_env( + check_snowpark_ml_version=( + pk_for_validation.meta.model_type == snowmlmodel.SnowMLModelHandler.HANDLER_TYPE + ) + ) + if validation_errors: + raise ValueError( + f"Unable to load this model due to following validation errors: {validation_errors}. " + "Make sure your local environment is the same as that when you logged the model, " + "or if you believe it should work, specify `force=True` to bypass this check." + ) + + warnings.warn( + "Loading model requires to have the exact the same environment as the one when " + "logging the model, otherwise, the model might be not functional or " + "some other problems might occur.", + category=RuntimeWarning, + stacklevel=2, + ) + + # We need the folder to be existed. + workspace = pathlib.Path(tempfile.mkdtemp()) + self._model_ops.download_files( + model_name=self._model_name, + version_name=self._version_name, + target_path=workspace, + mode="model", statement_params=statement_params, ) + pk = model_composer.ModelComposer.load(workspace, meta_only=False, options=options) + assert pk.model, ( + "Unable to load model. 
" + f"model_name={self._model_name}, version_name={self._version_name}, metadata={pk.meta}" + ) + return pk.model diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index 94208228..a278a47a 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -1,12 +1,16 @@ +import os +import pathlib +import tempfile from typing import cast from unittest import mock from absl.testing import absltest from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import model_signature +from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._client.model import model_version_impl from snowflake.ml.model._client.ops import metadata_ops, model_ops +from snowflake.ml.model._model_composer import model_composer from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.test_utils import mock_data_frame, mock_session from snowflake.snowpark import Session @@ -17,7 +21,13 @@ model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input"), ], outputs=[model_signature.FeatureSpec(name="output", dtype=model_signature.DataType.FLOAT)], - ) + ), + "predict_table": model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input"), + ], + outputs=[model_signature.FeatureSpec(name="output", dtype=model_signature.DataType.FLOAT)], + ), } @@ -202,11 +212,13 @@ def test_run(self) -> None: self.m_mv.run(m_df, function_name='"predict"') mock_invoke_method.assert_called_once_with( method_name='"predict"', + method_function_type="FUNCTION", signature=_DUMMY_SIG["predict"], X=m_df, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), strict_input_validation=False, + partition_column=None, statement_params=mock.ANY, ) @@ -214,11 +226,13 @@ def test_run(self) -> None: self.m_mv.run(m_df, function_name="__call__") mock_invoke_method.assert_called_once_with( method_name="__CALL__", + method_function_type="FUNCTION", signature=_DUMMY_SIG["predict"], X=m_df, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), strict_input_validation=False, + partition_column=None, statement_params=mock.ANY, ) @@ -241,11 +255,13 @@ def test_run_without_method_name(self) -> None: self.m_mv.run(m_df) mock_invoke_method.assert_called_once_with( method_name='"predict"', + method_function_type="FUNCTION", signature=_DUMMY_SIG["predict"], X=m_df, model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), strict_input_validation=False, + partition_column=None, statement_params=mock.ANY, ) @@ -268,11 +284,63 @@ def test_run_strict(self) -> None: self.m_mv.run(m_df, strict_input_validation=True) mock_invoke_method.assert_called_once_with( method_name='"predict"', + method_function_type="FUNCTION", signature=_DUMMY_SIG["predict"], X=m_df, model_name=sql_identifier.SqlIdentifier("MODEL"), strict_input_validation=True, version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + partition_column=None, + statement_params=mock.ANY, + ) + + def test_run_table_function_method(self) -> None: + m_df = mock_data_frame.MockDataFrame() + m_methods = [ + model_manifest_schema.ModelFunctionInfo( + { + "name": '"predict_table"', 
+ "target_method": "predict_table", + "target_method_function_type": "TABLE_FUNCTION", + "signature": _DUMMY_SIG["predict_table"], + } + ), + model_manifest_schema.ModelFunctionInfo( + { + "name": "__CALL__", + "target_method": "__call__", + "target_method_function_type": "TABLE_FUNCTION", + "signature": _DUMMY_SIG["predict_table"], + } + ), + ] + self.m_mv._functions = m_methods + + with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: + self.m_mv.run(m_df, function_name='"predict_table"') + mock_invoke_method.assert_called_once_with( + method_name='"predict_table"', + method_function_type="TABLE_FUNCTION", + signature=_DUMMY_SIG["predict_table"], + X=m_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + strict_input_validation=False, + partition_column=None, + statement_params=mock.ANY, + ) + + with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: + self.m_mv.run(m_df, function_name='"predict_table"', partition_column="PARTITION_COLUMN") + mock_invoke_method.assert_called_once_with( + method_name='"predict_table"', + method_function_type="TABLE_FUNCTION", + signature=_DUMMY_SIG["predict_table"], + X=m_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + strict_input_validation=False, + partition_column="PARTITION_COLUMN", statement_params=mock.ANY, ) @@ -318,6 +386,148 @@ def test_comment_setter(self) -> None: statement_params=mock.ANY, ) + def test_export_invalid_path(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "dummy"), mode="w") as f: + f.write("hello") + with self.assertRaisesRegex(ValueError, "is a file or an non-empty folder"): + self.m_mv.export(tmpdir) + + def test_export_model(self) -> None: + with mock.patch.object( + self.m_mv._model_ops, "download_files" + ) as mock_download_files, tempfile.TemporaryDirectory() as tmpdir: + self.m_mv.export(tmpdir) + mock_download_files.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + target_path=pathlib.Path(tmpdir), + mode="model", + statement_params=mock.ANY, + ) + + def test_export_full(self) -> None: + with mock.patch.object( + self.m_mv._model_ops, "download_files" + ) as mock_download_files, tempfile.TemporaryDirectory() as tmpdir: + self.m_mv.export(tmpdir, export_mode=model_version_impl.ExportMode.FULL) + mock_download_files.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + target_path=pathlib.Path(tmpdir), + mode="full", + statement_params=mock.ANY, + ) + + def test_load(self) -> None: + m_pk_for_validation = mock.MagicMock() + m_pk_for_validation.meta = mock.MagicMock() + m_pk_for_validation.meta.model_type = "foo" + m_pk_for_validation.meta.env = mock.MagicMock() + + m_model = mock.MagicMock() + m_pk = mock.MagicMock() + m_pk.meta = mock.MagicMock() + m_pk.model = m_model + + m_options = model_types.ModelLoadOption(use_gpu=False) + with mock.patch.object(self.m_mv._model_ops, "download_files") as mock_download_files, mock.patch.object( + model_composer.ModelComposer, "load", side_effect=[m_pk_for_validation, m_pk] + ) as mock_load, mock.patch.object( + m_pk_for_validation.meta.env, "validate_with_local_env", return_value=[] + ) as 
mock_validate_with_local_env: + self.assertEqual(self.m_mv.load(options=m_options), m_model) + mock_download_files.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + target_path=mock.ANY, + mode="minimal", + statement_params=mock.ANY, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + target_path=mock.ANY, + mode="model", + statement_params=mock.ANY, + ), + ] + ) + mock_load.assert_has_calls( + [ + mock.call(mock.ANY, meta_only=True, options=m_options), + mock.call(mock.ANY, meta_only=False, options=m_options), + ] + ) + mock_validate_with_local_env.assert_called_once_with(check_snowpark_ml_version=False) + + def test_load_error(self) -> None: + m_pk_for_validation = mock.MagicMock() + m_pk_for_validation.meta = mock.MagicMock() + m_pk_for_validation.meta.model_type = "snowml" + m_pk_for_validation.meta.env = mock.MagicMock() + + m_model = mock.MagicMock() + m_pk = mock.MagicMock() + m_pk.meta = mock.MagicMock() + m_pk.model = m_model + + m_options = model_types.ModelLoadOption(use_gpu=False) + with mock.patch.object(self.m_mv._model_ops, "download_files") as mock_download_files, mock.patch.object( + model_composer.ModelComposer, "load", side_effect=[m_pk_for_validation, m_pk] + ) as mock_load, mock.patch.object( + m_pk_for_validation.meta.env, "validate_with_local_env", return_value=["error"] + ) as mock_validate_with_local_env: + with self.assertRaisesRegex(ValueError, "Unable to load this model due to following validation errors"): + self.assertEqual(self.m_mv.load(options=m_options), m_model) + mock_download_files.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + target_path=mock.ANY, + mode="minimal", + statement_params=mock.ANY, + ), + ] + ) + mock_load.assert_has_calls( + [ + mock.call(mock.ANY, meta_only=True, options=m_options), + ] + ) + mock_validate_with_local_env.assert_called_once_with(check_snowpark_ml_version=True) + + def test_load_force(self) -> None: + m_model = mock.MagicMock() + m_pk = mock.MagicMock() + m_pk.meta = mock.MagicMock() + m_pk.model = m_model + + m_options = model_types.ModelLoadOption(use_gpu=False) + with mock.patch.object(self.m_mv._model_ops, "download_files") as mock_download_files, mock.patch.object( + model_composer.ModelComposer, "load", side_effect=[m_pk] + ) as mock_load: + self.assertEqual(self.m_mv.load(force=True, options=m_options), m_model) + mock_download_files.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + target_path=mock.ANY, + mode="model", + statement_params=mock.ANY, + ), + ] + ) + mock_load.assert_has_calls( + [ + mock.call(mock.ANY, meta_only=False, options=m_options), + ] + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/ops/BUILD.bazel b/snowflake/ml/model/_client/ops/BUILD.bazel index 74b915ac..b007d67a 100644 --- a/snowflake/ml/model/_client/ops/BUILD.bazel +++ b/snowflake/ml/model/_client/ops/BUILD.bazel @@ -22,6 +22,7 @@ py_library( "//snowflake/ml/model/_model_composer:model_composer", "//snowflake/ml/model/_model_composer/model_manifest", "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", + "//snowflake/ml/model/_packager/model_env", 
"//snowflake/ml/model/_packager/model_meta", "//snowflake/ml/model/_packager/model_meta:model_meta_schema", "//snowflake/ml/model/_signatures:snowpark_handler", diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index 38fea24f..e2019c27 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -1,7 +1,7 @@ +import os import pathlib import tempfile -from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Union, cast +from typing import Any, Dict, List, Literal, Optional, Union, cast import yaml @@ -19,7 +19,9 @@ model_manifest, model_manifest_schema, ) +from snowflake.ml.model._packager.model_env import model_env from snowflake.ml.model._packager.model_meta import model_meta +from snowflake.ml.model._packager.model_runtime import model_runtime from snowflake.ml.model._signatures import snowpark_handler from snowflake.snowpark import dataframe, row, session from snowflake.snowpark._internal import utils as snowpark_utils @@ -337,16 +339,6 @@ def get_model_version_manifest( mm = model_manifest.ModelManifest(pathlib.Path(tmpdir)) return mm.load() - @contextmanager - def _enable_model_details( - self, - *, - statement_params: Optional[Dict[str, Any]] = None, - ) -> Generator[None, None, None]: - self._model_client.config_model_details(enable=True, statement_params=statement_params) - yield - self._model_client.config_model_details(enable=False, statement_params=statement_params) - @staticmethod def _match_model_spec_with_sql_functions( sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str] @@ -374,64 +366,63 @@ def get_functions( version_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> List[model_manifest_schema.ModelFunctionInfo]: - with self._enable_model_details(statement_params=statement_params): - raw_model_spec_res = self._model_client.show_versions( - model_name=model_name, - version_name=version_name, - check_model_details=True, - statement_params=statement_params, - )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME] - model_spec_dict = yaml.safe_load(raw_model_spec_res) - model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict) - show_functions_res = self._model_version_client.show_functions( - model_name=model_name, - version_name=version_name, - statement_params=statement_params, + raw_model_spec_res = self._model_client.show_versions( + model_name=model_name, + version_name=version_name, + check_model_details=True, + statement_params={**(statement_params or {}), "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True}, + )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME] + model_spec_dict = yaml.safe_load(raw_model_spec_res) + model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict) + show_functions_res = self._model_version_client.show_functions( + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + function_names_and_types = [] + for r in show_functions_res: + function_name = sql_identifier.SqlIdentifier( + r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True ) - function_names_and_types = [] - for r in show_functions_res: - function_name = sql_identifier.SqlIdentifier( - r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True - ) - function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value - try: - return_type = 
r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME] - except KeyError: - pass - else: - if "TABLE" in return_type: - function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value - - function_names_and_types.append((function_name, function_type)) - - signatures = model_spec["signatures"] - function_names = [name for name, _ in function_names_and_types] - function_name_mapping = ModelOperator._match_model_spec_with_sql_functions( - function_names, list(signatures.keys()) - ) + function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value + try: + return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME] + except KeyError: + pass + else: + if "TABLE" in return_type: + function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value - return [ - model_manifest_schema.ModelFunctionInfo( - name=function_name.identifier(), - target_method=function_name_mapping[function_name], - target_method_function_type=function_type, - signature=model_signature.ModelSignature.from_dict( - signatures[function_name_mapping[function_name]] - ), - ) - for function_name, function_type in function_names_and_types - ] + function_names_and_types.append((function_name, function_type)) + + signatures = model_spec["signatures"] + function_names = [name for name, _ in function_names_and_types] + function_name_mapping = ModelOperator._match_model_spec_with_sql_functions( + function_names, list(signatures.keys()) + ) + + return [ + model_manifest_schema.ModelFunctionInfo( + name=function_name.identifier(), + target_method=function_name_mapping[function_name], + target_method_function_type=function_type, + signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]), + ) + for function_name, function_type in function_names_and_types + ] def invoke_method( self, *, method_name: sql_identifier.SqlIdentifier, + method_function_type: str, signature: model_signature.ModelSignature, X: Union[type_hints.SupportedDataType, dataframe.DataFrame], model_name: sql_identifier.SqlIdentifier, version_name: sql_identifier.SqlIdentifier, strict_input_validation: bool = False, + partition_column: Optional[sql_identifier.SqlIdentifier] = None, statement_params: Optional[Dict[str, str]] = None, ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]: identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED @@ -469,15 +460,27 @@ def invoke_method( if output_name in original_cols: original_cols.remove(output_name) - df_res = self._model_version_client.invoke_method( - method_name=method_name, - input_df=s_df, - input_args=input_args, - returns=returns, - model_name=model_name, - version_name=version_name, - statement_params=statement_params, - ) + if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value: + df_res = self._model_version_client.invoke_function_method( + method_name=method_name, + input_df=s_df, + input_args=input_args, + returns=returns, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value: + df_res = self._model_version_client.invoke_table_function_method( + method_name=method_name, + input_df=s_df, + input_args=input_args, + partition_column=partition_column, + returns=returns, + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) if keep_order: df_res = df_res.sort( @@ -486,7 +489,11 
@@ def invoke_method(
             )
 
         if not output_with_input_features:
-            df_res = df_res.drop(*original_cols)
+            cols_to_drop = original_cols
+            if partition_column is not None:
+                # don't drop partition column
+                cols_to_drop.remove(partition_column.identifier())
+            df_res = df_res.drop(*cols_to_drop)
 
         # Get final result
         if not isinstance(X, dataframe.DataFrame):
@@ -512,3 +519,66 @@ def delete_model_or_version(
             model_name=model_name,
             statement_params=statement_params,
         )
+
+    def rename(
+        self,
+        *,
+        model_name: sql_identifier.SqlIdentifier,
+        new_model_db: Optional[sql_identifier.SqlIdentifier],
+        new_model_schema: Optional[sql_identifier.SqlIdentifier],
+        new_model_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        self._model_client.rename(
+            model_name=model_name,
+            new_model_db=new_model_db,
+            new_model_schema=new_model_schema,
+            new_model_name=new_model_name,
+            statement_params=statement_params,
+        )
+
+    # Maps each download mode to the stage paths to list and download.
+    # The boolean value indicates whether the path is a directory.
+    MODEL_FILE_DOWNLOAD_PATTERN = {
+        "minimal": {
+            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
+            / model_meta.MODEL_METADATA_FILE: False,
+            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH) / model_env._DEFAULT_ENV_DIR: True,
+            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
+            / model_runtime.ModelRuntime.RUNTIME_DIR_REL_PATH: True,
+        },
+        "model": {pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH): True},
+        "full": {pathlib.PurePosixPath(os.curdir): True},
+    }
+
+    def download_files(
+        self,
+        *,
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        target_path: pathlib.Path,
+        mode: Literal["full", "model", "minimal"] = "model",
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
+            list_file_res = self._model_version_client.list_file(
+                model_name=model_name,
+                version_name=version_name,
+                file_path=remote_rel_path,
+                is_dir=is_dir,
+                statement_params=statement_params,
+            )
+            file_list = [
+                pathlib.PurePosixPath(*pathlib.PurePosixPath(row.name).parts[2:])  # versions/<version_name>/...
+ for row in list_file_res + ] + for stage_file_path in file_list: + local_file_dir = target_path / stage_file_path.parent + local_file_dir.mkdir(parents=True, exist_ok=True) + self._model_version_client.get_file( + model_name=model_name, + version_name=version_name, + file_path=stage_file_path, + target_path=local_file_dir, + statement_params=statement_params, + ) diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index 1bd8cb88..694683f3 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -1,3 +1,4 @@ +import pathlib from typing import List, cast from unittest import mock @@ -9,6 +10,7 @@ from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import model_signature from snowflake.ml.model._client.ops import model_ops +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema from snowflake.ml.model._signatures import snowpark_handler from snowflake.ml.test_utils import mock_data_frame, mock_session @@ -503,12 +505,13 @@ def test_invoke_method_1(self) -> None: with mock.patch.object( snowpark_handler.SnowparkDataFrameHandler, "convert_from_df", return_value=m_df ) as mock_convert_from_df, mock.patch.object( - self.m_ops._model_version_client, "invoke_method", return_value=m_df + self.m_ops._model_version_client, "invoke_function_method", return_value=m_df ) as mock_invoke_method, mock.patch.object( snowpark_handler.SnowparkDataFrameHandler, "convert_to_df", return_value=pd_df ) as mock_convert_to_df: self.m_ops.invoke_method( method_name=sql_identifier.SqlIdentifier("PREDICT"), + method_function_type=model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value, signature=m_sig, X=pd_df, model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -539,12 +542,13 @@ def test_invoke_method_1_no_drop(self) -> None: with mock.patch.object( snowpark_handler.SnowparkDataFrameHandler, "convert_from_df", return_value=m_df ) as mock_convert_from_df, mock.patch.object( - self.m_ops._model_version_client, "invoke_method", return_value=m_df + self.m_ops._model_version_client, "invoke_function_method", return_value=m_df ) as mock_invoke_method, mock.patch.object( snowpark_handler.SnowparkDataFrameHandler, "convert_to_df", return_value=pd_df ) as mock_convert_to_df: self.m_ops.invoke_method( method_name=sql_identifier.SqlIdentifier("PREDICT"), + method_function_type=model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value, signature=m_sig, X=pd_df, model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -574,12 +578,13 @@ def test_invoke_method_2(self) -> None: ) as mock_convert_from_df, mock.patch.object( model_signature, "_validate_snowpark_data", return_value=model_signature.SnowparkIdentifierRule.NORMALIZED ) as mock_validate_snowpark_data, mock.patch.object( - self.m_ops._model_version_client, "invoke_method", return_value=m_df + self.m_ops._model_version_client, "invoke_function_method", return_value=m_df ) as mock_invoke_method, mock.patch.object( snowpark_handler.SnowparkDataFrameHandler, "convert_to_df" ) as mock_convert_to_df: self.m_ops.invoke_method( method_name=sql_identifier.SqlIdentifier("PREDICT"), + method_function_type=model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value, signature=m_sig, X=cast(DataFrame, m_df), model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -609,12 +614,13 @@ def test_invoke_method_3(self) -> None: 
) as mock_convert_from_df, mock.patch.object( model_signature, "_validate_snowpark_data", return_value=model_signature.SnowparkIdentifierRule.NORMALIZED ) as mock_validate_snowpark_data, mock.patch.object( - self.m_ops._model_version_client, "invoke_method", return_value=m_df + self.m_ops._model_version_client, "invoke_function_method", return_value=m_df ) as mock_invoke_method, mock.patch.object( snowpark_handler.SnowparkDataFrameHandler, "convert_to_df" ) as mock_convert_to_df: self.m_ops.invoke_method( method_name=sql_identifier.SqlIdentifier("PREDICT"), + method_function_type=model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value, signature=m_sig, X=cast(DataFrame, m_df), model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -636,6 +642,84 @@ def test_invoke_method_3(self) -> None: ) mock_convert_to_df.assert_not_called() + def test_invoke_method_table_function(self) -> None: + pd_df = pd.DataFrame([["1.0"]], columns=["input"], dtype=np.float32) + m_sig = _DUMMY_SIG["predict_table"] + m_df = mock_data_frame.MockDataFrame() + m_df.__setattr__("_statement_params", None) + m_df.__setattr__("columns", ["COL1", "COL2"]) + m_df.add_mock_sort("_ID", ascending=True).add_mock_drop("COL1", "COL2") + with mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_from_df", return_value=m_df + ) as mock_convert_from_df, mock.patch.object( + self.m_ops._model_version_client, "invoke_table_function_method", return_value=m_df + ) as mock_invoke_method, mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_to_df", return_value=pd_df + ) as mock_convert_to_df: + self.m_ops.invoke_method( + method_name=sql_identifier.SqlIdentifier("PREDICT_TABLE"), + method_function_type=model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value, + signature=m_sig, + X=pd_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_from_df.assert_called_once_with( + self.c_session, mock.ANY, keep_order=True, features=m_sig.inputs + ) + mock_invoke_method.assert_called_once_with( + method_name=sql_identifier.SqlIdentifier("PREDICT_TABLE"), + input_df=m_df, + input_args=['"input"'], + partition_column=None, + returns=[("output", spt.FloatType(), '"output"')], + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_to_df.assert_called_once_with(m_df, features=m_sig.outputs) + + def test_invoke_method_table_function_partition_column(self) -> None: + pd_df = pd.DataFrame([["1.0"]], columns=["input"], dtype=np.float32) + m_sig = _DUMMY_SIG["predict_table"] + m_df = mock_data_frame.MockDataFrame() + m_df.__setattr__("_statement_params", None) + m_df.__setattr__("columns", ["COL1", "COL2", "PARTITION_COLUMN"]) + m_df.add_mock_sort("_ID", ascending=True).add_mock_drop("COL1", "COL2") + partition_column = sql_identifier.SqlIdentifier("PARTITION_COLUMN") + with mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_from_df", return_value=m_df + ) as mock_convert_from_df, mock.patch.object( + self.m_ops._model_version_client, "invoke_table_function_method", return_value=m_df + ) as mock_invoke_method, mock.patch.object( + snowpark_handler.SnowparkDataFrameHandler, "convert_to_df", return_value=pd_df + ) as mock_convert_to_df: + self.m_ops.invoke_method( + method_name=sql_identifier.SqlIdentifier("PREDICT_TABLE"), + 
method_function_type=model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value, + signature=m_sig, + X=pd_df, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + partition_column=partition_column, + statement_params=self.m_statement_params, + ) + mock_convert_from_df.assert_called_once_with( + self.c_session, mock.ANY, keep_order=True, features=m_sig.inputs + ) + mock_invoke_method.assert_called_once_with( + method_name=sql_identifier.SqlIdentifier("PREDICT_TABLE"), + input_df=m_df, + input_args=['"input"'], + partition_column=partition_column, + returns=[("output", spt.FloatType(), '"output"')], + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_convert_to_df.assert_called_once_with(m_df, features=m_sig.outputs) + def test_get_comment_1(self) -> None: m_list_res = [ Row( @@ -819,18 +903,43 @@ def test_delete_model_or_version_2(self) -> None: statement_params=self.m_statement_params, ) - def test_enable_model_details(self) -> None: + def test_rename(self) -> None: with mock.patch.object( self.m_ops._model_client, - "config_model_details", - ) as mock_config_model_details: - with self.m_ops._enable_model_details(statement_params=self.m_statement_params): - mock_config_model_details.assert_called_with( - enable=True, - statement_params=self.m_statement_params, - ) - mock_config_model_details.assert_called_with( - enable=False, + "rename", + ) as mock_rename: + self.m_ops.rename( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=None, + new_model_schema=None, + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), + statement_params=self.m_statement_params, + ) + mock_rename.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=None, + new_model_schema=None, + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), + statement_params=self.m_statement_params, + ) + + def test_rename_fully_qualified_name(self) -> None: + with mock.patch.object( + self.m_ops._model_client, + "rename", + ) as mock_rename: + self.m_ops.rename( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=sql_identifier.SqlIdentifier("TEMP"), + new_model_schema=sql_identifier.SqlIdentifier("test", case_sensitive=True), + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), + statement_params=self.m_statement_params, + ) + mock_rename.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=sql_identifier.SqlIdentifier("TEMP"), + new_model_schema=sql_identifier.SqlIdentifier("test", case_sensitive=True), + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), statement_params=self.m_statement_params, ) @@ -874,7 +983,7 @@ def test_get_functions(self) -> None: Row(name="predict", return_type="NUMBER"), Row(name="predict_table", return_type="TABLE (RESULTS VARCHAR)"), ] - with mock.patch.object(self.m_ops, "_enable_model_details",) as mock_enable_model_details, mock.patch.object( + with mock.patch.object( self.m_ops._model_client, "show_versions", return_value=m_show_versions_result, @@ -890,12 +999,11 @@ def test_get_functions(self) -> None: version_name=sql_identifier.SqlIdentifier('"v1"'), statement_params=self.m_statement_params, ) - mock_enable_model_details.assert_called_once_with(statement_params=self.m_statement_params) mock_show_versions.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("MODEL"), 
version_name=sql_identifier.SqlIdentifier('"v1"'), check_model_details=True, - statement_params=self.m_statement_params, + statement_params={**self.m_statement_params, "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True}, ) mock_show_functions.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -904,6 +1012,232 @@ def test_get_functions(self) -> None: ) mock_validate_model_metadata.assert_called_once_with(m_spec) + def test_download_files_minimal(self) -> None: + m_list_files_res = [ + [Row(name="versions/v1/model/model.yaml", size=419, md5="1234", last_modified="")], + [ + Row(name="versions/v1/model/env/conda.yml", size=419, md5="1234", last_modified=""), + Row(name="versions/v1/model/env/requirements.txt", size=419, md5="1234", last_modified=""), + ], + [ + Row(name="versions/v1/model/runtimes/cpu/env/conda.yml", size=419, md5="1234", last_modified=""), + Row(name="versions/v1/model/runtimes/cpu/env/requirements.txt", size=419, md5="1234", last_modified=""), + ], + ] + m_local_path = pathlib.Path("/tmp") + with mock.patch.object( + self.m_ops._model_version_client, + "list_file", + side_effect=m_list_files_res, + ) as mock_list_file, mock.patch.object( + self.m_ops._model_version_client, "get_file" + ) as mock_get_file, mock.patch.object( + pathlib.Path, "mkdir" + ): + self.m_ops.download_files( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + target_path=m_local_path, + mode="minimal", + statement_params=self.m_statement_params, + ) + mock_list_file.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/model.yaml"), + is_dir=False, + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/env"), + is_dir=True, + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/runtimes"), + is_dir=True, + statement_params=self.m_statement_params, + ), + ] + ) + mock_get_file.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/model.yaml"), + target_path=m_local_path / "model", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/env/conda.yml"), + target_path=m_local_path / "model" / "env", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/env/requirements.txt"), + target_path=m_local_path / "model" / "env", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/runtimes/cpu/env/conda.yml"), + target_path=m_local_path / "model" / "runtimes" / "cpu" / "env", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + 
version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/runtimes/cpu/env/requirements.txt"), + target_path=m_local_path / "model" / "runtimes" / "cpu" / "env", + statement_params=self.m_statement_params, + ), + ] + ) + + def test_download_files_model(self) -> None: + m_list_files_res = [ + [ + Row(name="versions/v1/model/model.yaml", size=419, md5="1234", last_modified=""), + Row(name="versions/v1/model/env/conda.yml", size=419, md5="1234", last_modified=""), + Row(name="versions/v1/model/env/requirements.txt", size=419, md5="1234", last_modified=""), + ], + ] + m_local_path = pathlib.Path("/tmp") + with mock.patch.object( + self.m_ops._model_version_client, + "list_file", + side_effect=m_list_files_res, + ) as mock_list_file, mock.patch.object( + self.m_ops._model_version_client, "get_file" + ) as mock_get_file, mock.patch.object( + pathlib.Path, "mkdir" + ): + self.m_ops.download_files( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + target_path=m_local_path, + mode="model", + statement_params=self.m_statement_params, + ) + mock_list_file.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model"), + is_dir=True, + statement_params=self.m_statement_params, + ), + ] + ) + mock_get_file.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/model.yaml"), + target_path=m_local_path / "model", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/env/conda.yml"), + target_path=m_local_path / "model" / "env", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/env/requirements.txt"), + target_path=m_local_path / "model" / "env", + statement_params=self.m_statement_params, + ), + ] + ) + + def test_download_files_full(self) -> None: + m_list_files_res = [ + [ + Row(name="versions/v1/MANIFEST.yml", size=419, md5="1234", last_modified=""), + Row(name="versions/v1/model/model.yaml", size=419, md5="1234", last_modified=""), + Row(name="versions/v1/model/env/conda.yml", size=419, md5="1234", last_modified=""), + Row(name="versions/v1/model/env/requirements.txt", size=419, md5="1234", last_modified=""), + ], + ] + m_local_path = pathlib.Path("/tmp") + with mock.patch.object( + self.m_ops._model_version_client, + "list_file", + side_effect=m_list_files_res, + ) as mock_list_file, mock.patch.object( + self.m_ops._model_version_client, "get_file" + ) as mock_get_file, mock.patch.object( + pathlib.Path, "mkdir" + ): + self.m_ops.download_files( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + target_path=m_local_path, + mode="full", + statement_params=self.m_statement_params, + ) + mock_list_file.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("."), + is_dir=True, + statement_params=self.m_statement_params, + ), + ] + ) + mock_get_file.assert_has_calls( + [ + mock.call( + 
model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("MANIFEST.yml"), + target_path=m_local_path, + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/model.yaml"), + target_path=m_local_path / "model", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/env/conda.yml"), + target_path=m_local_path / "model" / "env", + statement_params=self.m_statement_params, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + file_path=pathlib.PurePosixPath("model/env/requirements.txt"), + target_path=m_local_path / "model" / "env", + statement_params=self.m_statement_params, + ), + ] + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/sql/BUILD.bazel b/snowflake/ml/model/_client/sql/BUILD.bazel index 567ee9a7..affe2595 100644 --- a/snowflake/ml/model/_client/sql/BUILD.bazel +++ b/snowflake/ml/model/_client/sql/BUILD.bazel @@ -11,7 +11,6 @@ py_library( deps = [ "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", ], diff --git a/snowflake/ml/model/_client/sql/model.py b/snowflake/ml/model/_client/sql/model.py index 0b4e444e..eb10b44b 100644 --- a/snowflake/ml/model/_client/sql/model.py +++ b/snowflake/ml/model/_client/sql/model.py @@ -121,21 +121,23 @@ def drop_model( statement_params=statement_params, ).has_dimensions(expected_rows=1, expected_cols=1).validate() - def config_model_details( + def rename( self, *, - enable: bool, + model_name: sql_identifier.SqlIdentifier, + new_model_db: Optional[sql_identifier.SqlIdentifier], + new_model_schema: Optional[sql_identifier.SqlIdentifier], + new_model_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - if enable: - query_result_checker.SqlResultValidator( - self._session, - "ALTER SESSION SET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL=true", - statement_params=statement_params, - ).has_dimensions(expected_rows=1, expected_cols=1).validate() - else: - query_result_checker.SqlResultValidator( - self._session, - "ALTER SESSION UNSET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL", - statement_params=statement_params, - ).has_dimensions(expected_rows=1, expected_cols=1).validate() + # Use registry's database and schema if a non fully qualified new model name is provided. 
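+        # The rename is issued as a single statement, e.g.
+        #   ALTER MODEL TEMP."test".MODEL RENAME TO TEMP2."test2".MODEL2
+        # so supplying a database/schema also moves the model across schemas in one call.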
+ new_fully_qualified_name = identifier.get_schema_level_object_identifier( + new_model_db.identifier() if new_model_db else self._database_name.identifier(), + new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(), + new_model_name.identifier(), + ) + query_result_checker.SqlResultValidator( + self._session, + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} RENAME TO {new_fully_qualified_name}", + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() diff --git a/snowflake/ml/model/_client/sql/model_test.py b/snowflake/ml/model/_client/sql/model_test.py index 75131a21..ccb77293 100644 --- a/snowflake/ml/model/_client/sql/model_test.py +++ b/snowflake/ml/model/_client/sql/model_test.py @@ -198,35 +198,41 @@ def test_drop_model(self) -> None: statement_params=m_statement_params, ) - def test_config_model_details_enable(self) -> None: + def test_rename(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame( - collect_result=[Row("Session successfully altered.")], collect_statement_params=m_statement_params + collect_result=[Row("Model MODEL successfully dropped.")], collect_statement_params=m_statement_params ) - self.m_session.add_mock_sql("""ALTER SESSION SET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL=true""", m_df) + self.m_session.add_mock_sql("""ALTER MODEL TEMP."test".MODEL RENAME TO TEMP."test".MODEL2""", m_df) c_session = cast(Session, self.m_session) model_sql.ModelSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - ).config_model_details( - enable=True, + ).rename( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=None, + new_model_schema=None, + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), statement_params=m_statement_params, ) - def test_config_model_details_false(self) -> None: + def test_rename_fully_qualified_name(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame( - collect_result=[Row("Session successfully altered.")], collect_statement_params=m_statement_params + collect_result=[Row("Model MODEL successfully dropped.")], collect_statement_params=m_statement_params ) - self.m_session.add_mock_sql("""ALTER SESSION UNSET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL""", m_df) + self.m_session.add_mock_sql("""ALTER MODEL TEMP."test".MODEL RENAME TO TEMP2."test2".MODEL2""", m_df) c_session = cast(Session, self.m_session) model_sql.ModelSQLClient( c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - ).config_model_details( - enable=False, + ).rename( + model_name=sql_identifier.SqlIdentifier("MODEL"), + new_model_db=sql_identifier.SqlIdentifier("TEMP2"), + new_model_schema=sql_identifier.SqlIdentifier("test2", case_sensitive=True), + new_model_name=sql_identifier.SqlIdentifier("MODEL2"), statement_params=m_statement_params, ) diff --git a/snowflake/ml/model/_client/sql/model_version.py b/snowflake/ml/model/_client/sql/model_version.py index 62aae054..0930db9c 100644 --- a/snowflake/ml/model/_client/sql/model_version.py +++ b/snowflake/ml/model/_client/sql/model_version.py @@ -96,6 +96,38 @@ def set_default_version( statement_params=statement_params, ).has_dimensions(expected_rows=1, expected_cols=1).validate() + def list_file( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: 
sql_identifier.SqlIdentifier, + file_path: pathlib.PurePosixPath, + is_dir: bool = False, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[row.Row]: + # Workaround for snowURL bug. + trailing_slash = "/" if is_dir else "" + + stage_location = ( + pathlib.PurePosixPath( + self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path + ).as_posix() + + trailing_slash + ) + stage_location_url = ParseResult( + scheme="snow", netloc="model", path=stage_location, params="", query="", fragment="" + ).geturl() + + return ( + query_result_checker.SqlResultValidator( + self._session, + f"List {_normalize_url_for_sql(stage_location_url)}", + statement_params=statement_params, + ) + .has_column("name") + .validate() + ) + def get_file( self, *, @@ -162,7 +194,7 @@ def set_comment( statement_params=statement_params, ).has_dimensions(expected_rows=1, expected_cols=1).validate() - def invoke_method( + def invoke_function_method( self, *, model_name: sql_identifier.SqlIdentifier, @@ -232,6 +264,82 @@ def invoke_method( return output_df + def invoke_table_function_method( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + method_name: sql_identifier.SqlIdentifier, + input_df: dataframe.DataFrame, + input_args: List[sql_identifier.SqlIdentifier], + returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]], + partition_column: Optional[sql_identifier.SqlIdentifier], + statement_params: Optional[Dict[str, Any]] = None, + ) -> dataframe.DataFrame: + with_statements = [] + if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0: + INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT" + with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})") + else: + tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE) + INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier( + self._database_name.identifier(), + self._schema_name.identifier(), + tmp_table_name, + ) + input_df.write.save_as_table( # type: ignore[call-overload] + table_name=INTERMEDIATE_TABLE_NAME, + mode="errorifexists", + table_type="temporary", + statement_params=statement_params, + ) + + module_version_alias = "MODEL_VERSION_ALIAS" + with_statements.append( + f"{module_version_alias} AS " + f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}" + ) + + partition_by = partition_column.identifier() if partition_column is not None else "1" + + args_sql_list = [] + for input_arg_value in input_args: + args_sql_list.append(input_arg_value) + + args_sql = ", ".join(args_sql_list) + + sql = textwrap.dedent( + f"""WITH {','.join(with_statements)} + SELECT *, + FROM {INTERMEDIATE_TABLE_NAME}, + TABLE({module_version_alias}!{method_name.identifier()}({args_sql}) + OVER (PARTITION BY {partition_by}))""" + ) + + output_df = self._session.sql(sql) + + # Prepare the output + output_cols = [] + output_names = [] + + for output_name, output_type, output_col_name in returns: + output_cols.append(F.col(output_name).astype(output_type)) + output_names.append(output_col_name) + + if partition_column is not None: + output_cols.append(F.col(partition_column.identifier())) + output_names.append(partition_column) + + output_df = output_df.with_columns( + col_names=output_names, + values=output_cols, + ) + + if statement_params: + output_df._statement_params = statement_params # type: ignore[assignment] 
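+        # The returned DataFrame is lazy: it carries the WITH ... TABLE(...) query built above and
+        # exposes the columns declared in `returns` (cast to their types) plus the partition column,
+        # if one was provided.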
+ + return output_df + def set_metadata( self, metadata_dict: Dict[str, Any], diff --git a/snowflake/ml/model/_client/sql/model_version_test.py b/snowflake/ml/model/_client/sql/model_version_test.py index bceb3801..a95cc5ce 100644 --- a/snowflake/ml/model/_client/sql/model_version_test.py +++ b/snowflake/ml/model/_client/sql/model_version_test.py @@ -1,3 +1,4 @@ +import os import pathlib from typing import cast from unittest import mock @@ -89,6 +90,49 @@ def test_set_comment(self) -> None: statement_params=m_statement_params, ) + def test_list_file(self) -> None: + m_statement_params = {"test": "1"} + m_res = [Row(name="versions/v1/model.yaml", size=419, md5="1234", last_modified="")] + m_df = mock_data_frame.MockDataFrame( + collect_result=m_res, + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql("""LIST 'snow://model/TEMP."test".MODEL/versions/v1/model.yaml'""", m_df) + c_session = cast(Session, self.m_session) + res = model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).list_file( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + file_path=pathlib.PurePosixPath("model.yaml"), + statement_params=m_statement_params, + ) + self.assertEqual(res, m_res) + + def test_list_file_curdir_dir(self) -> None: + m_statement_params = {"test": "1"} + m_res = [Row(name="versions/v1/MANIFEST.yml", size=419, md5="1234", last_modified="")] + m_df = mock_data_frame.MockDataFrame( + collect_result=m_res, + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql("""LIST 'snow://model/TEMP."test".MODEL/versions/v1/'""", m_df) + c_session = cast(Session, self.m_session) + res = model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).list_file( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + file_path=pathlib.PurePosixPath(os.curdir), + is_dir=True, + statement_params=m_statement_params, + ) + self.assertEqual(res, m_res) + def test_get_file(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame( @@ -113,7 +157,7 @@ def test_get_file(self) -> None: ) self.assertEqual(res, pathlib.Path("/tmp/model.yaml")) - def test_invoke_method(self) -> None: + def test_invoke_function_method(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( @@ -135,7 +179,7 @@ def test_invoke_method(self) -> None: c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - ).invoke_method( + ).invoke_function_method( model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("V1"), method_name=sql_identifier.SqlIdentifier("PREDICT"), @@ -152,7 +196,7 @@ def test_invoke_method(self) -> None: statement_params=m_statement_params, ) - def test_invoke_method_1(self) -> None: + def test_invoke_function_method_1(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( @@ -174,7 +218,7 @@ def test_invoke_method_1(self) -> None: c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), 
schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - ).invoke_method( + ).invoke_function_method( model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("V1"), method_name=sql_identifier.SqlIdentifier("PREDICT"), @@ -191,7 +235,7 @@ def test_invoke_method_1(self) -> None: statement_params=m_statement_params, ) - def test_invoke_method_2(self) -> None: + def test_invoke_function_method_2(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame() self.m_session.add_mock_sql( @@ -209,7 +253,7 @@ def test_invoke_method_2(self) -> None: c_session, database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - ).invoke_method( + ).invoke_function_method( model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("V1"), method_name=sql_identifier.SqlIdentifier("PREDICT"), @@ -219,6 +263,48 @@ def test_invoke_method_2(self) -> None: statement_params=m_statement_params, ) + def test_invoke_table_function_method_partition_col(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame() + partition_column = "partition_col" + self.m_session.add_mock_sql( + f"""WITH MODEL_VERSION_ALIAS AS MODEL TEMP."test".MODEL VERSION V1 + SELECT *, + FROM TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123, + TABLE(MODEL_VERSION_ALIAS!PREDICT_TABLE(COL1, COL2) OVER (PARTITION BY {partition_column})) + """, + m_df, + ) + m_df.add_mock_with_columns(["OUTPUT_1"], [F.col("OUTPUT_1")]) + c_session = cast(Session, self.m_session) + mock_writer = mock.MagicMock() + m_df.__setattr__("write", mock_writer) + m_df.__setattr__("queries", {"queries": ["query_1", "query_2"], "post_actions": []}) + with mock.patch.object(mock_writer, "save_as_table") as mock_save_as_table, mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_TABLE_ABCDEF0123" + ) as mock_random_name_for_temp_object: + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).invoke_table_function_method( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + method_name=sql_identifier.SqlIdentifier("PREDICT_TABLE"), + input_df=cast(DataFrame, m_df), + input_args=[sql_identifier.SqlIdentifier("COL1"), sql_identifier.SqlIdentifier("COL2")], + returns=[("output_1", spt.IntegerType(), sql_identifier.SqlIdentifier("OUTPUT_1"))], + partition_column=sql_identifier.SqlIdentifier(partition_column), + statement_params=m_statement_params, + ) + mock_random_name_for_temp_object.assert_called_once_with(snowpark_utils.TempObjectType.TABLE) + mock_save_as_table.assert_called_once_with( + table_name='TEMP."test".SNOWPARK_TEMP_TABLE_ABCDEF0123', + mode="errorifexists", + table_type="temporary", + statement_params=m_statement_params, + ) + def test_set_metadata(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame(collect_result=[Row("")], collect_statement_params=m_statement_params) diff --git a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py index 60a0bbdb..ec0fae69 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +++ b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py @@ -37,6 
+37,7 @@ def __init__( session: snowpark.Session, artifact_stage_location: str, compute_pool: str, + job_name: str, external_access_integrations: List[str], ) -> None: """Initialization @@ -49,6 +50,7 @@ def __init__( artifact_stage_location: Spec file and future deployment related artifacts will be stored under {stage}/models/{model_id} compute_pool: The compute pool used to run docker image build workload. + job_name: job_name to use. external_access_integrations: EAIs for network connection. """ self.context_dir = context_dir @@ -58,6 +60,7 @@ def __init__( self.artifact_stage_location = artifact_stage_location self.compute_pool = compute_pool self.external_access_integrations = external_access_integrations + self.job_name = job_name self.client = snowservice_client.SnowServiceClient(session) assert artifact_stage_location.startswith( @@ -203,8 +206,9 @@ def _construct_and_upload_job_spec(self, base_image: str, kaniko_shell_script_st ) def _launch_kaniko_job(self, spec_stage_location: str) -> None: - logger.debug("Submitting job for building docker image with kaniko") + logger.debug(f"Submitting job {self.job_name} for building docker image with kaniko") self.client.create_job( + job_name=self.job_name, compute_pool=self.compute_pool, spec_stage_location=spec_stage_location, external_access_integrations=self.external_access_integrations, diff --git a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder_test.py b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder_test.py index e553903f..68270e98 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder_test.py +++ b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder_test.py @@ -14,6 +14,7 @@ def setUp(self) -> None: self.image_repo = "mock_image_repo" self.artifact_stage_location = "@stage/models/id" self.compute_pool = "test_pool" + self.job_name = "abcd" self.context_tarball_stage_location = f"{self.artifact_stage_location}/context.tar.gz" self.full_image_name = "org-account.registry.snowflakecomputing.com/db/schema/repo/image:latest" self.eais = ["eai_1"] @@ -34,6 +35,7 @@ def test_construct_and_upload_docker_entrypoint_script(self, m_session_class: mo session=m_session, artifact_stage_location=self.artifact_stage_location, compute_pool=self.compute_pool, + job_name=self.job_name, external_access_integrations=self.eais, ) diff --git a/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template b/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template index 32c7ba64..7d6494ec 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +++ b/snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template @@ -30,6 +30,7 @@ USER mambauser # Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time. 
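+# MAMBA_NO_LOW_SPEED_LIMIT=1 disables micromamba's low-speed download abort so that large
+# package downloads are not cancelled on slow connections during the image build.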
ARG MAMBA_DOCKERFILE_ACTIVATE=1 +ARG MAMBA_NO_LOW_SPEED_LIMIT=1 # Bitsandbytes uses this ENVVAR to determine CUDA library location ENV CONDA_PREFIX=/opt/conda diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture index ab6a9ecd..1ae0ec8e 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture +++ b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture @@ -27,6 +27,7 @@ RUN chmod +x ./gunicorn_run.sh USER mambauser ARG MAMBA_DOCKERFILE_ACTIVATE=1 +ARG MAMBA_NO_LOW_SPEED_LIMIT=1 ENV CONDA_PREFIX=/opt/conda RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="" \ micromamba install -y -n base -f conda.yml && \ diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA index 7eb9358f..0a0dff19 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA +++ b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_CUDA @@ -27,6 +27,7 @@ RUN chmod +x ./gunicorn_run.sh USER mambauser ARG MAMBA_DOCKERFILE_ACTIVATE=1 +ARG MAMBA_NO_LOW_SPEED_LIMIT=1 ENV CONDA_PREFIX=/opt/conda RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ micromamba install -y -n base -f conda.yml && \ diff --git a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model index a73c5173..8619b7b0 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model +++ b/snowflake/ml/model/_deploy_client/image_builds/test_fixtures/dockerfile_test_fixture_with_model @@ -27,6 +27,7 @@ RUN chmod +x ./gunicorn_run.sh USER mambauser ARG MAMBA_DOCKERFILE_ACTIVATE=1 +ARG MAMBA_NO_LOW_SPEED_LIMIT=1 ENV CONDA_PREFIX=/opt/conda RUN --mount=type=cache,target=/opt/conda/pkgs CONDA_OVERRIDE_CUDA="11.7" \ micromamba install -y -n base -f conda.yml && \ diff --git a/snowflake/ml/model/_deploy_client/snowservice/deploy.py b/snowflake/ml/model/_deploy_client/snowservice/deploy.py index 5415be67..5c6ecd0a 100644 --- a/snowflake/ml/model/_deploy_client/snowservice/deploy.py +++ b/snowflake/ml/model/_deploy_client/snowservice/deploy.py @@ -346,6 +346,7 @@ def __init__( (db, schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name) self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}") + self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}") # Spec file and future deployment related artifacts will be stored under {stage}/models/{model_id} self._model_artifact_stage_location = posixpath.join(deployment_stage_path, "models", self.id) self.debug_dir: Optional[str] = None @@ -468,6 +469,7 @@ def _build_and_upload_image(self, context_dir: str, image_repo: str, full_image_ session=self.session, artifact_stage_location=self._model_artifact_stage_location, compute_pool=self.options.compute_pool, + job_name=self._job_name, external_access_integrations=self.options.external_access_integrations, ) else: diff --git a/snowflake/ml/model/_deploy_client/utils/constants.py 
b/snowflake/ml/model/_deploy_client/utils/constants.py index edc441df..f32d2e19 100644 --- a/snowflake/ml/model/_deploy_client/utils/constants.py +++ b/snowflake/ml/model/_deploy_client/utils/constants.py @@ -17,11 +17,6 @@ class ResourceStatus(Enum): INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error. -RESOURCE_TO_STATUS_FUNCTION_MAPPING = { - ResourceType.SERVICE: "SYSTEM$GET_SERVICE_STATUS", - ResourceType.JOB: "SYSTEM$GET_JOB_STATUS", -} - PREDICT = "predict" STAGE = "stage" COMPUTE_POOL = "compute_pool" diff --git a/snowflake/ml/model/_deploy_client/utils/snowservice_client.py b/snowflake/ml/model/_deploy_client/utils/snowservice_client.py index 5f4f2275..43a1f33f 100644 --- a/snowflake/ml/model/_deploy_client/utils/snowservice_client.py +++ b/snowflake/ml/model/_deploy_client/utils/snowservice_client.py @@ -70,13 +70,16 @@ def create_or_replace_service( logger.debug(f"Create service with SQL: \n {sql}") self.session.sql(sql).collect() - def create_job(self, compute_pool: str, spec_stage_location: str, external_access_integrations: List[str]) -> None: + def create_job( + self, job_name: str, compute_pool: str, spec_stage_location: str, external_access_integrations: List[str] + ) -> None: """Execute the job creation SQL command. Note that the job creation is synchronous, hence we execute it in a async way so that we can query the log in the meantime. Upon job failure, full job container log will be logged. Args: + job_name: name of the job compute_pool: name of the compute pool spec_stage_location: path to the stage location where the spec is located at. external_access_integrations: EAIs for network connection. @@ -84,19 +87,18 @@ def create_job(self, compute_pool: str, spec_stage_location: str, external_acces stage, path = uri.get_stage_and_path(spec_stage_location) sql = textwrap.dedent( f""" - EXECUTE SERVICE + EXECUTE JOB SERVICE IN COMPUTE POOL {compute_pool} FROM {stage} - SPEC = '{path}' + SPECIFICATION_FILE = '{path}' + NAME = {job_name} EXTERNAL_ACCESS_INTEGRATIONS = ({', '.join(external_access_integrations)}) """ ) logger.debug(f"Create job with SQL: \n {sql}") - cur = self.session._conn._conn.cursor() - cur.execute_async(sql) - job_id = cur._sfqid + self.session.sql(sql).collect_nowait() self.block_until_resource_is_ready( - resource_name=str(job_id), + resource_name=job_name, resource_type=constants.ResourceType.JOB, container_name=constants.KANIKO_CONTAINER_NAME, max_retries=240, @@ -182,10 +184,7 @@ def block_until_resource_is_ready( """ assert resource_type == constants.ResourceType.SERVICE or resource_type == constants.ResourceType.JOB query_command = "" - if resource_type == constants.ResourceType.SERVICE: - query_command = f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')" - elif resource_type == constants.ResourceType.JOB: - query_command = f"CALL SYSTEM$GET_JOB_LOGS('{resource_name}', '{container_name}')" + query_command = f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')" logger.warning( f"Best-effort log streaming from SPCS will be enabled when python logging level is set to INFO." 
f"Alternatively, you can also query the logs by running the query '{query_command}'" @@ -201,7 +200,7 @@ def block_until_resource_is_ready( ) lsp.process_new_logs(resource_log, log_level=logging.INFO) - status = self.get_resource_status(resource_name=resource_name, resource_type=resource_type) + status = self.get_resource_status(resource_name=resource_name) if resource_type == constants.ResourceType.JOB and status == constants.ResourceStatus.DONE: return @@ -246,52 +245,24 @@ def block_until_resource_is_ready( def get_resource_log( self, resource_name: str, resource_type: constants.ResourceType, container_name: str ) -> Optional[str]: - if resource_type == constants.ResourceType.SERVICE: - try: - row = self.session.sql( - f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')" - ).collect() - return str(row[0]["SYSTEM$GET_SERVICE_LOGS"]) - except Exception: - return None - elif resource_type == constants.ResourceType.JOB: - try: - row = self.session.sql(f"CALL SYSTEM$GET_JOB_LOGS('{resource_name}', '{container_name}')").collect() - return str(row[0]["SYSTEM$GET_JOB_LOGS"]) - except Exception: - return None - else: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.NOT_IMPLEMENTED, - original_exception=NotImplementedError( - f"{resource_type.name} is not yet supported in get_resource_log function" - ), - ) - - def get_resource_status( - self, resource_name: str, resource_type: constants.ResourceType - ) -> Optional[constants.ResourceStatus]: + try: + row = self.session.sql( + f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')" + ).collect() + return str(row[0]["SYSTEM$GET_SERVICE_LOGS"]) + except Exception: + return None + + def get_resource_status(self, resource_name: str) -> Optional[constants.ResourceStatus]: """Get resource status. Args: resource_name: Name of the resource. - resource_type: Type of the resource. - - Raises: - SnowflakeMLException: If resource type does not have a corresponding system function for querying status. - SnowflakeMLException: If corresponding status call failed. Returns: Optional[constants.ResourceStatus]: The status of the resource, or None if the resource status is empty. """ - if resource_type not in constants.RESOURCE_TO_STATUS_FUNCTION_MAPPING: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - f"Status querying is not supported for resources of type '{resource_type}'." 
- ), - ) - status_func = constants.RESOURCE_TO_STATUS_FUNCTION_MAPPING[resource_type] + status_func = "SYSTEM$GET_SERVICE_STATUS" try: row = self.session.sql(f"CALL {status_func}('{resource_name}');").collect() except Exception: diff --git a/snowflake/ml/model/_deploy_client/utils/snowservice_client_test.py b/snowflake/ml/model/_deploy_client/utils/snowservice_client_test.py index b2477e30..e160bc73 100644 --- a/snowflake/ml/model/_deploy_client/utils/snowservice_client_test.py +++ b/snowflake/ml/model/_deploy_client/utils/snowservice_client_test.py @@ -1,5 +1,5 @@ import json -from typing import Optional, cast +from typing import cast from absl.testing import absltest from absl.testing.absltest import mock @@ -52,29 +52,32 @@ def test_create_or_replace_service(self) -> None: external_access_integrations=["eai_a", "eai_b"], ) - def _add_mock_cursor_to_session(self, *, expected_job_id: Optional[str] = None) -> None: - mock_cursor = mock.Mock() - mock_cursor.execute_async.return_value = None - mock_cursor._sfqid = expected_job_id - - # Replace the cursor in the m_session with the mock_cursor - self.m_session._conn = mock.Mock() - self.m_session._conn._conn = mock.Mock() - self.m_session._conn._conn.cursor.return_value = mock_cursor - def test_create_job_successfully(self) -> None: - with mock.patch.object(self.client, "get_resource_status", return_value=constants.ResourceStatus.DONE): + with mock.patch.object( + self.client, "get_resource_status", return_value=constants.ResourceStatus.DONE + ) as mock_get_resource_status: m_compute_pool = "mock_compute_pool" m_stage = "@mock_spec_stage" m_stage_path = "a/hello.yaml" m_spec_storgae_location = f"{m_stage}/{m_stage_path}" - expected_job_id = "abcd" - self._add_mock_cursor_to_session(expected_job_id=expected_job_id) + m_job_name = "abcd" + self.m_session.add_mock_sql( + query=f""" + EXECUTE JOB SERVICE + IN COMPUTE POOL {m_compute_pool} + FROM {m_stage} + SPECIFICATION_FILE = '{m_stage_path}' + NAME = {m_job_name} + EXTERNAL_ACCESS_INTEGRATIONS = (eai_a, eai_b)""", + result=mock_data_frame.MockDataFrame(collect_result=[]), + ) self.client.create_job( + job_name=m_job_name, compute_pool=m_compute_pool, spec_stage_location=m_spec_storgae_location, external_access_integrations=["eai_a", "eai_b"], ) + mock_get_resource_status.assert_called_once_with(resource_name=m_job_name) def test_create_job_failed(self) -> None: with self.assertLogs(level="INFO") as cm: @@ -85,18 +88,28 @@ def test_create_job_failed(self) -> None: m_stage = "@mock_spec_stage" m_stage_path = "a/hello.yaml" m_spec_storgae_location = f"{m_stage}/{m_stage_path}" - expected_job_id = "abcd" + m_job_name = "abcd" self.m_session.add_mock_sql( - query=f"CALL SYSTEM$GET_JOB_LOGS('{expected_job_id}', '{constants.KANIKO_CONTAINER_NAME}')", + query=f""" + EXECUTE JOB SERVICE + IN COMPUTE POOL {m_compute_pool} + FROM {m_stage} + SPECIFICATION_FILE = '{m_stage_path}' + NAME = {m_job_name} + EXTERNAL_ACCESS_INTEGRATIONS = (eai_a, eai_b)""", + result=mock_data_frame.MockDataFrame(collect_result=[]), + ) + + self.m_session.add_mock_sql( + query=f"CALL SYSTEM$GET_SERVICE_LOGS('{m_job_name}', '0', '{constants.KANIKO_CONTAINER_NAME}')", result=mock_data_frame.MockDataFrame( - collect_result=[snowpark.Row(**{"SYSTEM$GET_JOB_LOGS": test_log})] + collect_result=[snowpark.Row(**{"SYSTEM$GET_SERVICE_LOGS": test_log})] ), ) - self._add_mock_cursor_to_session(expected_job_id=expected_job_id) - self.client.create_job( + job_name=m_job_name, compute_pool=m_compute_pool, 
spec_stage_location=m_spec_storgae_location, external_access_integrations=["eai_a", "eai_b"], @@ -183,7 +196,7 @@ def test_get_service_status(self) -> None: ) self.assertEqual( - self.client.get_resource_status(self.m_service_name, constants.ResourceType.SERVICE), + self.client.get_resource_status(self.m_service_name), constants.ResourceStatus("READY"), ) @@ -210,7 +223,7 @@ def test_get_service_status(self) -> None: ) self.assertEqual( - self.client.get_resource_status(self.m_service_name, constants.ResourceType.SERVICE), + self.client.get_resource_status(self.m_service_name), constants.ResourceStatus("FAILED"), ) @@ -235,7 +248,7 @@ def test_get_service_status(self) -> None: query="call system$GET_SERVICE_STATUS('mock_service_name');", result=mock_data_frame.MockDataFrame(collect_result=[row]), ) - self.assertEqual(self.client.get_resource_status(self.m_service_name, constants.ResourceType.SERVICE), None) + self.assertEqual(self.client.get_resource_status(self.m_service_name), None) def test_block_until_service_is_ready_happy_path(self) -> None: with mock.patch.object(self.client, "get_resource_status", return_value=constants.ResourceStatus("READY")): diff --git a/snowflake/ml/model/_model_composer/model_composer.py b/snowflake/ml/model/_model_composer/model_composer.py index 3365f428..e7b89b3a 100644 --- a/snowflake/ml/model/_model_composer/model_composer.py +++ b/snowflake/ml/model/_model_composer/model_composer.py @@ -8,8 +8,10 @@ from absl import logging from packaging import requirements +from typing_extensions import deprecated from snowflake.ml._internal import env as snowml_env, env_utils, file_utils +from snowflake.ml._internal.lineage import data_source from snowflake.ml.model import model_signature, type_hints as model_types from snowflake.ml.model._model_composer.model_manifest import model_manifest from snowflake.ml.model._packager import model_packager @@ -134,6 +136,7 @@ def save( model_meta=self.packager.meta, model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path), options=options, + data_sources=self._get_data_sources(model), ) file_utils.upload_directory_to_stage( @@ -143,7 +146,8 @@ def save( statement_params=self._statement_params, ) - def load( + @deprecated("Only used by PrPr model registry. 
Use static method version of load instead.") + def legacy_load( self, *, meta_only: bool = False, @@ -163,3 +167,20 @@ def load( with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf: zf.extractall(path=self._packager_workspace_path) self.packager.load(meta_only=meta_only, options=options) + + @staticmethod + def load( + workspace_path: pathlib.Path, + *, + meta_only: bool = False, + options: Optional[model_types.ModelLoadOption] = None, + ) -> model_packager.ModelPackager: + mp = model_packager.ModelPackager(str(workspace_path / ModelComposer.MODEL_DIR_REL_PATH)) + mp.load(meta_only=meta_only, options=options) + return mp + + def _get_data_sources(self, model: model_types.SupportedModelType) -> Optional[List[data_source.DataSource]]: + data_sources = getattr(model, "_data_sources", None) + if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources): + return data_sources + return None diff --git a/snowflake/ml/model/_model_composer/model_composer_test.py b/snowflake/ml/model/_model_composer/model_composer_test.py index 7c10e609..192b34fd 100644 --- a/snowflake/ml/model/_model_composer/model_composer_test.py +++ b/snowflake/ml/model/_model_composer/model_composer_test.py @@ -9,7 +9,9 @@ from sklearn import linear_model from snowflake.ml._internal import env_utils, file_utils +from snowflake.ml.model import type_hints as model_types from snowflake.ml.model._model_composer import model_composer +from snowflake.ml.model._packager import model_packager from snowflake.ml.modeling.linear_model import ( # type:ignore[attr-defined] LinearRegression, ) @@ -79,6 +81,12 @@ def test_save_interface(self) -> None: c_session, local_path=mock.ANY, stage_path=pathlib.PurePosixPath(stage_path), statement_params=None ) + def test_load(self) -> None: + m_options = model_types.ModelLoadOption(use_gpu=False) + with mock.patch.object(model_packager.ModelPackager, "load") as mock_load: + model_composer.ModelComposer.load(pathlib.Path("workspace"), meta_only=True, options=m_options) + mock_load.assert_called_once_with(meta_only=True, options=m_options) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py index b15bc142..6dad7024 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py @@ -5,6 +5,7 @@ import yaml +from snowflake.ml._internal.lineage import data_source from snowflake.ml.model import type_hints from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._model_composer.model_method import ( @@ -36,6 +37,7 @@ def save( model_meta: model_meta_api.ModelMetadata, model_file_rel_path: pathlib.PurePosixPath, options: Optional[type_hints.ModelSaveOption] = None, + data_sources: Optional[List[data_source.DataSource]] = None, ) -> None: if options is None: options = {} @@ -90,6 +92,10 @@ def save( ], ) + lineage_sources = self._extract_lineage_info(data_sources) + if lineage_sources: + manifest_dict["lineage_sources"] = lineage_sources + with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f: # Anchors are not supported in the server, avoid that. 
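Relating back to the model_composer.py hunk above: the instance-level `load` is kept only as the deprecated `legacy_load`, and loading now goes through the static `ModelComposer.load`, which returns the underlying `ModelPackager`. A minimal usage sketch, assuming a workspace directory previously populated by `ModelComposer.save()` (the path is hypothetical):

```python
import pathlib

from snowflake.ml.model._model_composer import model_composer

# Hypothetical path to a workspace produced earlier by ModelComposer.save().
workspace = pathlib.Path("/tmp/model_workspace")

# meta_only=True reads the model metadata without deserializing the model object.
packager = model_composer.ModelComposer.load(workspace, meta_only=True)
print(packager.meta.name)  # the returned ModelPackager exposes the parsed metadata
```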
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign] @@ -108,3 +114,19 @@ def load(self) -> model_manifest_schema.ModelManifestDict: res = cast(model_manifest_schema.ModelManifestDict, raw_input) return res + + def _extract_lineage_info( + self, data_sources: Optional[List[data_source.DataSource]] + ) -> List[model_manifest_schema.LineageSourceDict]: + result = [] + if data_sources: + for source in data_sources: + result.append( + model_manifest_schema.LineageSourceDict( + # Currently, we only support lineage from Dataset. + type=model_manifest_schema.LineageSourceTypes.DATASET.value, + entity=source.fully_qualified_name, + version=source.version, + ) + ) + return result diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py index b22a8533..ebafde09 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py @@ -75,8 +75,19 @@ class SnowparkMLDataDict(TypedDict): functions: Required[List[ModelFunctionInfoDict]] +class LineageSourceTypes(enum.Enum): + DATASET = "DATASET" + + +class LineageSourceDict(TypedDict): + type: Required[str] + entity: Required[str] + version: NotRequired[str] + + class ModelManifestDict(TypedDict): manifest_version: Required[str] runtimes: Required[Dict[str, ModelRuntimeDict]] methods: Required[List[ModelMethodDict]] user_data: NotRequired[Dict[str, Any]] + lineage_sources: NotRequired[List[LineageSourceDict]] diff --git a/snowflake/ml/model/_packager/BUILD.bazel b/snowflake/ml/model/_packager/BUILD.bazel index 444c7742..1be32a3a 100644 --- a/snowflake/ml/model/_packager/BUILD.bazel +++ b/snowflake/ml/model/_packager/BUILD.bazel @@ -15,7 +15,6 @@ py_library( name = "model_packager", srcs = ["model_packager.py"], deps = [ - "//snowflake/ml/_internal:env_utils", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/model:custom_model", "//snowflake/ml/model:model_signature", diff --git a/snowflake/ml/model/_packager/model_env/model_env.py b/snowflake/ml/model/_packager/model_env/model_env.py index 40f83d4f..68b0d7b8 100644 --- a/snowflake/ml/model/_packager/model_env/model_env.py +++ b/snowflake/ml/model/_packager/model_env/model_env.py @@ -284,6 +284,7 @@ def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None: " This may prevent model deploying to Snowflake Warehouse." ), category=UserWarning, + stacklevel=2, ) if len(channel_dependencies) == 0 and channel not in self._conda_dependencies: warnings.warn( @@ -292,6 +293,7 @@ def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None: " This may prevent model deploying to Snowflake Warehouse." ), category=UserWarning, + stacklevel=2, ) self._conda_dependencies[channel] = [] @@ -307,6 +309,7 @@ def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None: " This may be unintentional." ), category=UserWarning, + stacklevel=2, ) if pip_requirements_list: @@ -316,6 +319,7 @@ def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None: " This may prevent model deploying to Snowflake Warehouse." ), category=UserWarning, + stacklevel=2, ) for pip_dependency in pip_requirements_list: if any( @@ -338,6 +342,7 @@ def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None: " This may prevent model deploying to Snowflake Warehouse." 
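To make the model_manifest_schema.py addition above concrete: when a model is logged from a Dataset-backed DataFrame, the manifest gains a `lineage_sources` list whose entries follow `LineageSourceDict`. An illustrative, hand-written entry (the values are hypothetical; only the keys and the `"DATASET"` type come from the schema above):

```python
# Shape of one entry under `lineage_sources` in MANIFEST.yml.
lineage_source = {
    "type": "DATASET",                       # LineageSourceTypes.DATASET.value
    "entity": "MY_DB.MY_SCHEMA.MY_DATASET",  # fully qualified Dataset name
    "version": "V1",                         # optional Dataset version
}
```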
), category=UserWarning, + stacklevel=2, ) for pip_dependency in pip_requirements_list: if any( @@ -372,3 +377,39 @@ def save_as_dict(self, base_dir: pathlib.Path) -> model_meta_schema.ModelEnvDict "cuda_version": self.cuda_version, "snowpark_ml_version": self.snowpark_ml_version, } + + def validate_with_local_env( + self, check_snowpark_ml_version: bool = False + ) -> List[env_utils.IncorrectLocalEnvironmentError]: + errors = [] + try: + env_utils.validate_py_runtime_version(str(self._python_version)) + except env_utils.IncorrectLocalEnvironmentError as e: + errors.append(e) + + for conda_reqs in self._conda_dependencies.values(): + for conda_req in conda_reqs: + try: + env_utils.validate_local_installed_version_of_pip_package( + env_utils.try_convert_conda_requirement_to_pip(conda_req) + ) + except env_utils.IncorrectLocalEnvironmentError as e: + errors.append(e) + + for pip_req in self._pip_requirements: + try: + env_utils.validate_local_installed_version_of_pip_package(pip_req) + except env_utils.IncorrectLocalEnvironmentError as e: + errors.append(e) + + if check_snowpark_ml_version: + # For Modeling model + if self._snowpark_ml_version.base_version != snowml_env.VERSION: + errors.append( + env_utils.IncorrectLocalEnvironmentError( + f"The local installed version of Snowpark ML library is {snowml_env.VERSION} " + f"which differs from required version {self.snowpark_ml_version}." + ) + ) + + return errors diff --git a/snowflake/ml/model/_packager/model_env/model_env_test.py b/snowflake/ml/model/_packager/model_env/model_env_test.py index e5db33dd..f21e456c 100644 --- a/snowflake/ml/model/_packager/model_env/model_env_test.py +++ b/snowflake/ml/model/_packager/model_env/model_env_test.py @@ -3,6 +3,7 @@ import pathlib import tempfile import warnings +from unittest import mock import yaml from absl.testing import absltest @@ -953,6 +954,99 @@ def check_env_equality(this: model_env.ModelEnv, that: model_env.ModelEnv) -> bo loaded_env.load_from_dict(tmpdir_path, saved_dict) self.assertTrue(check_env_equality(env, loaded_env), "Loaded env object is different.") + def test_validate_with_local_env(self) -> None: + with mock.patch.object( + env_utils, "validate_py_runtime_version" + ) as mock_validate_py_runtime_version, mock.patch.object( + env_utils, "validate_local_installed_version_of_pip_package" + ) as mock_validate_local_installed_version_of_pip_package: + env = model_env.ModelEnv() + env.conda_dependencies = ["pytorch==1.3", "channel::some_package<1.2,>=1.0.1"] + env.pip_requirements = ["pip-package<1.2,>=1.0.1"] + env.python_version = "3.10.2" + env.cuda_version = "11.7.1" + env.snowpark_ml_version = "1.1.0" + + self.assertListEqual(env.validate_with_local_env(), []) + mock_validate_py_runtime_version.assert_called_once_with("3.10.2") + mock_validate_local_installed_version_of_pip_package.assert_has_calls( + [ + mock.call(requirements.Requirement("torch==1.3")), + mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), + ] + ) + + with mock.patch.object( + env_utils, "validate_py_runtime_version", side_effect=env_utils.IncorrectLocalEnvironmentError() + ) as mock_validate_py_runtime_version, mock.patch.object( + env_utils, + "validate_local_installed_version_of_pip_package", + side_effect=env_utils.IncorrectLocalEnvironmentError(), + ) as mock_validate_local_installed_version_of_pip_package: + env = model_env.ModelEnv() + env.conda_dependencies = ["pytorch==1.3", "channel::some_package<1.2,>=1.0.1"] + 
env.pip_requirements = ["pip-package<1.2,>=1.0.1"] + env.python_version = "3.10.2" + env.cuda_version = "11.7.1" + env.snowpark_ml_version = "1.1.0" + + self.assertLen(env.validate_with_local_env(), 4) + mock_validate_py_runtime_version.assert_called_once_with("3.10.2") + mock_validate_local_installed_version_of_pip_package.assert_has_calls( + [ + mock.call(requirements.Requirement("torch==1.3")), + mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), + ] + ) + + with mock.patch.object( + env_utils, "validate_py_runtime_version" + ) as mock_validate_py_runtime_version, mock.patch.object( + env_utils, "validate_local_installed_version_of_pip_package" + ) as mock_validate_local_installed_version_of_pip_package: + env = model_env.ModelEnv() + env.conda_dependencies = ["pytorch==1.3", "channel::some_package<1.2,>=1.0.1"] + env.pip_requirements = ["pip-package<1.2,>=1.0.1"] + env.python_version = "3.10.2" + env.cuda_version = "11.7.1" + env.snowpark_ml_version = f"{snowml_env.VERSION}+abcdef" + + self.assertListEqual(env.validate_with_local_env(check_snowpark_ml_version=True), []) + mock_validate_py_runtime_version.assert_called_once_with("3.10.2") + mock_validate_local_installed_version_of_pip_package.assert_has_calls( + [ + mock.call(requirements.Requirement("torch==1.3")), + mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), + ] + ) + + with mock.patch.object( + env_utils, "validate_py_runtime_version", side_effect=env_utils.IncorrectLocalEnvironmentError() + ) as mock_validate_py_runtime_version, mock.patch.object( + env_utils, + "validate_local_installed_version_of_pip_package", + side_effect=env_utils.IncorrectLocalEnvironmentError(), + ) as mock_validate_local_installed_version_of_pip_package: + env = model_env.ModelEnv() + env.conda_dependencies = ["pytorch==1.3", "channel::some_package<1.2,>=1.0.1"] + env.pip_requirements = ["pip-package<1.2,>=1.0.1"] + env.python_version = "3.10.2" + env.cuda_version = "11.7.1" + env.snowpark_ml_version = "0.0.0" + + self.assertLen(env.validate_with_local_env(check_snowpark_ml_version=True), 5) + mock_validate_py_runtime_version.assert_called_once_with("3.10.2") + mock_validate_local_installed_version_of_pip_package.assert_has_calls( + [ + mock.call(requirements.Requirement("torch==1.3")), + mock.call(requirements.Requirement("some-package<1.2,>=1.0.1")), + mock.call(requirements.Requirement("pip-package<1.2,>=1.0.1")), + ] + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_packager/model_meta/model_meta.py b/snowflake/ml/model/_packager/model_meta/model_meta.py index 4f03dbd7..1ab04be3 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta.py @@ -320,11 +320,7 @@ def save(self, model_dir_path: str) -> None: with open(model_yaml_path, "w", encoding="utf-8") as out: yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign] - yaml.safe_dump( - model_dict, - stream=out, - default_flow_style=False, - ) + yaml.safe_dump(model_dict, stream=out, default_flow_style=False) @staticmethod def _validate_model_metadata(loaded_meta: Any) -> model_meta_schema.ModelMetadataDict: diff --git a/snowflake/ml/model/_packager/model_packager.py b/snowflake/ml/model/_packager/model_packager.py index ea9e7624..1913a2dc 100644 --- a/snowflake/ml/model/_packager/model_packager.py +++ 
b/snowflake/ml/model/_packager/model_packager.py @@ -4,7 +4,6 @@ from absl import logging -from snowflake.ml._internal import env_utils from snowflake.ml._internal.exceptions import ( error_codes, exceptions as snowml_exceptions, @@ -129,8 +128,6 @@ def load( model_meta.load_code_path(self.local_dir_path) - env_utils.validate_py_runtime_version(self.meta.env.python_version) - handler = model_handler.load_handler(self.meta.model_type) if handler is None: raise snowml_exceptions.SnowflakeMLException( diff --git a/snowflake/ml/model/_packager/model_packager_test.py b/snowflake/ml/model/_packager/model_packager_test.py index 7323cfb7..ca7448fb 100644 --- a/snowflake/ml/model/_packager/model_packager_test.py +++ b/snowflake/ml/model/_packager/model_packager_test.py @@ -182,43 +182,6 @@ def test_save_validation_2(self) -> None: assert isinstance(pk.model, LinearRegression) np.testing.assert_allclose(predictions, desired=pk.model.predict(df[:1])[[OUTPUT_COLUMNS]]) - def test_bad_save_model(self) -> None: - with tempfile.TemporaryDirectory() as workspace: - os.mkdir(os.path.join(workspace, "bias")) - with open(os.path.join(workspace, "bias", "bias1"), "w", encoding="utf-8") as f: - f.write("25") - with open(os.path.join(workspace, "bias", "bias2"), "w", encoding="utf-8") as f: - f.write("68") - lm = DemoModelWithManyArtifacts( - custom_model.ModelContext(models={}, artifacts={"bias": os.path.join(workspace, "bias")}) - ) - arr = np.array([[1, 2, 3], [4, 2, 5]]) - d = pd.DataFrame(arr, columns=["c1", "c2", "c3"]) - s = {"predict": model_signature.infer_signature(d, lm.predict(d))} - - with self.assertRaises(ValueError): - model_packager.ModelPackager(os.path.join(workspace, "model1")).save( - name="model1", - model=lm, - signatures={**s, "another_predict": s["predict"]}, - metadata={"author": "halu", "version": "1"}, - ) - - model_packager.ModelPackager(os.path.join(workspace, "model1")).save( - name="model1", - model=lm, - signatures=s, - metadata={"author": "halu", "version": "1"}, - python_version="3.5.2", - ) - - pk = model_packager.ModelPackager(os.path.join(workspace, "model1")) - pk.load(meta_only=True) - - with exception_utils.assert_snowml_exceptions(self, expected_original_error_type=RuntimeError): - pk = model_packager.ModelPackager(os.path.join(workspace, "model1")) - pk.load() - if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py b/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py index f2214623..151ad2b3 100644 --- a/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +++ b/snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py @@ -3,6 +3,8 @@ import pandas as pd +from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result + class PandasModelTrainer: """ @@ -72,11 +74,61 @@ def train_fit_predict( Tuple[pd.DataFrame, object]: [predicted dataset, estimator] """ assert hasattr(self.estimator, "fit_predict") # make type checker happy - args = {"X": self.dataset[self.input_cols]} - result = self.estimator.fit_predict(**args) + result = self.estimator.fit_predict(X=self.dataset[self.input_cols]) result_df = pd.DataFrame(data=result, columns=expected_output_cols_list) if drop_input_cols: result_df = result_df else: - result_df = pd.concat([self.dataset, result_df], axis=1) + # in case the output column name overlap with the input column names, + # remove the ones in input column names + remove_dataset_col_name_exist_in_output_col = 
list( + set(self.dataset.columns) - set(expected_output_cols_list) + ) + result_df = pd.concat([self.dataset[remove_dataset_col_name_exist_in_output_col], result_df], axis=1) + return (result_df, self.estimator) + + def train_fit_transform( + self, + expected_output_cols_list: List[str], + drop_input_cols: Optional[bool] = False, + ) -> Tuple[pd.DataFrame, object]: + """Trains the model using specified features and target columns from the dataset. + This API is different from fit itself because it would also provide the transform + output. + + Args: + expected_output_cols_list (List[str]): The output columns + name as a list. Defaults to None. + drop_input_cols (Optional[bool]): Boolean to determine whether to + drop the input columns from the output dataset. + + Returns: + Tuple[pd.DataFrame, object]: [transformed dataset, estimator] + """ + assert hasattr(self.estimator, "fit") # make type checker happy + assert hasattr(self.estimator, "fit_transform") # make type checker happy + + argspec = inspect.getfullargspec(self.estimator.fit) + args = {"X": self.dataset[self.input_cols]} + if self.label_cols: + label_arg_name = "Y" if "Y" in argspec.args else "y" + args[label_arg_name] = self.dataset[self.label_cols].squeeze() + + if self.sample_weight_col is not None and "sample_weight" in argspec.args: + args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze() + + inference_res = self.estimator.fit_transform(**args) + + transformed_numpy_array, output_cols = handle_inference_result( + inference_res=inference_res, output_cols=expected_output_cols_list, inference_method="fit_transform" + ) + + result_df = pd.DataFrame(data=transformed_numpy_array, columns=output_cols) + if drop_input_cols: + result_df = result_df + else: + # in case the output column name overlap with the input column names, + # remove the ones in input column names + remove_dataset_col_name_exist_in_output_col = list(set(self.dataset.columns) - set(output_cols)) + result_df = pd.concat([self.dataset[remove_dataset_col_name_exist_in_output_col], result_df], axis=1) return (result_df, self.estimator) diff --git a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py index cb562d57..d3ded488 100644 --- a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +++ b/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py @@ -72,24 +72,40 @@ def batch_inference( """ - handler = SnowparkTransformHandlers( - dataset=self.dataset, - estimator=self.estimator, - class_name=self._class_name, - subproject=self._subproject, - autogenerated=self._autogenerated, - ) - return handler.batch_inference( - inference_method, - input_cols, - expected_output_cols, - session, - dependencies, - drop_input_cols, - expected_output_cols_type, - *args, - **kwargs, - ) + mlrs_inference_methods = ["predict", "predict_proba", "predict_log_proba"] + + if inference_method in mlrs_inference_methods: + result_df = self.client.inference( + estimator=self.estimator, + dataset=self.dataset, + inference_method=inference_method, + input_cols=input_cols, + output_cols=expected_output_cols, + drop_input_cols=drop_input_cols, + ) + + else: + handler = SnowparkTransformHandlers( + dataset=self.dataset, + estimator=self.estimator, + class_name=self._class_name, + subproject=self._subproject, + autogenerated=self._autogenerated, + ) + result_df = handler.batch_inference( + inference_method, + 
input_cols, + expected_output_cols, + session, + dependencies, + drop_input_cols, + expected_output_cols_type, + *args, + **kwargs, + ) + + assert isinstance(result_df, DataFrame) # mypy - The MLRS return types are annotated as `object`. + return result_df def score( self, diff --git a/snowflake/ml/modeling/_internal/model_trainer.py b/snowflake/ml/modeling/_internal/model_trainer.py index d1863a05..7896121b 100644 --- a/snowflake/ml/modeling/_internal/model_trainer.py +++ b/snowflake/ml/modeling/_internal/model_trainer.py @@ -22,3 +22,10 @@ def train_fit_predict( drop_input_cols: Optional[bool] = False, ) -> Tuple[Union[DataFrame, pd.DataFrame], object]: raise NotImplementedError + + def train_fit_transform( + self, + expected_output_cols_list: List[str], + drop_input_cols: Optional[bool] = False, + ) -> Tuple[Union[DataFrame, pd.DataFrame], object]: + raise NotImplementedError diff --git a/snowflake/ml/modeling/_internal/model_trainer_builder.py b/snowflake/ml/modeling/_internal/model_trainer_builder.py index 2be7d455..57bfb880 100644 --- a/snowflake/ml/modeling/_internal/model_trainer_builder.py +++ b/snowflake/ml/modeling/_internal/model_trainer_builder.py @@ -138,21 +138,13 @@ def build_fit_predict( cls, estimator: object, dataset: Union[DataFrame, pd.DataFrame], - input_cols: Optional[List[str]] = None, + input_cols: List[str], autogenerated: bool = False, subproject: str = "", ) -> ModelTrainer: """ Builder method that creates an appropriate ModelTrainer instance based on the given params. """ - if input_cols is None: - raise exceptions.SnowflakeMLException( - error_code=error_codes.NOT_FOUND, - original_exception=ValueError( - "The input column names (input_cols) is None.\n" - "Please put your input_cols when initializing the estimator\n" - ), - ) if isinstance(dataset, pd.DataFrame): return PandasModelTrainer( estimator=estimator, @@ -179,3 +171,44 @@ def build_fit_predict( f"Unexpected dataset type: {type(dataset)}." "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." ) + + @classmethod + def build_fit_transform( + cls, + estimator: object, + dataset: Union[DataFrame, pd.DataFrame], + input_cols: List[str], + label_cols: Optional[List[str]] = None, + sample_weight_col: Optional[str] = None, + autogenerated: bool = False, + subproject: str = "", + ) -> ModelTrainer: + """ + Builder method that creates an appropriate ModelTrainer instance based on the given params. + """ + if isinstance(dataset, pd.DataFrame): + return PandasModelTrainer( + estimator=estimator, + dataset=dataset, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + ) + elif isinstance(dataset, DataFrame): + trainer_klass = SnowparkModelTrainer + init_args = { + "estimator": estimator, + "dataset": dataset, + "session": dataset._session, + "input_cols": input_cols, + "label_cols": label_cols, + "sample_weight_col": sample_weight_col, + "autogenerated": autogenerated, + "subproject": subproject, + } + return trainer_klass(**init_args) # type: ignore[arg-type] + else: + raise TypeError( + f"Unexpected dataset type: {type(dataset)}." + "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." 
+ ) diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/BUILD.bazel b/snowflake/ml/modeling/_internal/snowpark_implementations/BUILD.bazel index 18639e52..6bbb3940 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/BUILD.bazel +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/BUILD.bazel @@ -11,6 +11,7 @@ py_library( "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/exceptions:modeling_error_messages", "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:pkg_version_utils", "//snowflake/ml/_internal/utils:query_result_checker", "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", "//snowflake/ml/_internal/utils:temp_file_utils", diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py index bddb2097..a6528fb7 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py @@ -9,7 +9,11 @@ import pandas as pd from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils +from snowflake.ml._internal.utils import ( + identifier, + pkg_version_utils, + snowpark_dataframe_utils, +) from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator from snowflake.ml._internal.utils.temp_file_utils import ( cleanup_temp_files, @@ -91,6 +95,7 @@ def batch_inference( A new dataset of the same type as the input dataset. """ + dependencies = self._get_validated_snowpark_dependencies(session, dependencies) dataset = self.dataset estimator = self.estimator # Register vectorized UDF for batch inference @@ -210,7 +215,8 @@ def score( Returns: An accuracy score for the model on the given test data. """ - + dependencies = self._get_validated_snowpark_dependencies(session, dependencies) + dependencies.append("snowflake-snowpark-python") dataset = self.dataset estimator = self.estimator dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(dataset) @@ -335,3 +341,19 @@ def score_wrapper_sproc( cleanup_temp_files([local_score_file_name]) return score + + def _get_validated_snowpark_dependencies(self, session: Session, dependencies: List[str]) -> List[str]: + """A helper function to validate dependencies and return the available packages that exists + in the snowflake anaconda channel + + Args: + session: the active snowpark Session + dependencies: unvalidated dependencies + + Returns: + A list of packages present in the snoflake conda channel. 
+ """ + + return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=dependencies, session=session, subproject=self._subproject + ) diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py index c1208a47..36953f55 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py @@ -23,20 +23,26 @@ cleanup_temp_files, get_temp_file_path, ) +from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result from snowflake.ml.modeling._internal.model_specifications import ( ModelSpecifications, ModelSpecificationsBuilder, ) -from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions +from snowflake.snowpark import ( + DataFrame, + Session, + exceptions as snowpark_exceptions, + functions as F, +) from snowflake.snowpark._internal.utils import ( TempObjectType, random_name_for_temp_object, ) -from snowflake.snowpark.functions import sproc from snowflake.snowpark.stored_procedure import StoredProcedure cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path)) cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name)) +cp.register_pickle_by_value(inspect.getmodule(handle_inference_result)) _PROJECT = "ModelDevelopment" @@ -122,7 +128,7 @@ def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]: project=_PROJECT, subproject=self._subproject, function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), - api_calls=[sproc], + api_calls=[F.sproc], custom_tags=dict([("autogen", True)]) if self._autogenerated else None, ) # Put locally serialized transform on stage. @@ -292,7 +298,7 @@ def _build_fit_predict_wrapper_sproc( """ imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml - def fit_wrapper_function( + def fit_predict_wrapper_function( session: Session, sql_queries: List[str], stage_transform_file_name: str, @@ -329,7 +335,7 @@ def fit_wrapper_function( with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: estimator = cp.load(local_transform_file_obj) - fit_predict_result = estimator.fit_predict(df[input_cols]) + fit_predict_result = estimator.fit_predict(X=df[input_cols]) local_result_file_name = get_temp_file_path() @@ -349,8 +355,16 @@ def fit_wrapper_function( fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list) else: df = df.copy() - fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list) - fit_predict_result_pd = pd.concat([df, fit_predict_result_pd], axis=1) + # in case the output column name overlap with the input column names, + # remove the ones in input column names + remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(expected_output_cols_list)) + fit_predict_result_pd = pd.concat( + [ + df[remove_dataset_col_name_exist_in_output_col], + pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list), + ], + axis=1, + ) # write into a temp table in sproc and load the table from outside session.write_pandas( @@ -361,17 +375,150 @@ def fit_wrapper_function( # to pass debug information to the caller. 
return str(os.path.basename(local_result_file_name)) - return fit_wrapper_function + return fit_predict_wrapper_function + + def _build_fit_transform_wrapper_sproc( + self, + model_spec: ModelSpecifications, + ) -> Callable[ + [ + Session, + List[str], + str, + str, + List[str], + Optional[List[str]], + Optional[str], + Dict[str, str], + bool, + List[str], + str, + ], + str, + ]: + """ + Constructs and returns a python stored procedure function to be used for training model. + + Args: + model_spec: ModelSpecifications object that contains model specific information + like required imports, package dependencies, etc. + + Returns: + A callable that can be registered as a stored procedure. + """ + imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml + + def fit_transform_wrapper_function( + session: Session, + sql_queries: List[str], + stage_transform_file_name: str, + stage_result_file_name: str, + input_cols: List[str], + label_cols: Optional[List[str]], + sample_weight_col: Optional[str], + statement_params: Dict[str, str], + drop_input_cols: bool, + expected_output_cols_list: List[str], + fit_transform_result_name: str, + ) -> str: + import os + + import cloudpickle as cp + import pandas as pd + + for import_name in imports: + importlib.import_module(import_name) + + # Execute snowpark queries and obtain the results as pandas dataframe + # NB: this implies that the result data must fit into memory. + for query in sql_queries[:-1]: + _ = session.sql(query).collect(statement_params=statement_params) + sp_df = session.sql(sql_queries[-1]) + df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params) + df.columns = sp_df.columns + + local_transform_file_name = get_temp_file_path() + + session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params) + + local_transform_file_path = os.path.join( + local_transform_file_name, os.listdir(local_transform_file_name)[0] + ) + with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: + estimator = cp.load(local_transform_file_obj) + + argspec = inspect.getfullargspec(estimator.fit) + args = {"X": df[input_cols]} + if label_cols: + label_arg_name = "Y" if "Y" in argspec.args else "y" + args[label_arg_name] = df[label_cols].squeeze() + + if sample_weight_col is not None and "sample_weight" in argspec.args: + args["sample_weight"] = df[sample_weight_col].squeeze() + + fit_transform_result = estimator.fit_transform(**args) + + local_result_file_name = get_temp_file_path() + + with open(local_result_file_name, mode="w+b") as local_result_file_obj: + cp.dump(estimator, local_result_file_obj) + + session.file.put( + local_result_file_name, + stage_result_file_name, + auto_compress=False, + overwrite=True, + statement_params=statement_params, + ) + + transformed_numpy_array, output_cols = handle_inference_result( + inference_res=fit_transform_result, + output_cols=expected_output_cols_list, + inference_method="fit_transform", + within_udf=True, + ) + + if len(transformed_numpy_array.shape) > 1: + if transformed_numpy_array.shape[1] != len(output_cols): + series = pd.Series(transformed_numpy_array.tolist()) + transformed_pandas_df = pd.DataFrame(series, columns=output_cols) + else: + transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=output_cols) + else: + transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=output_cols) + + # store the transform output + if not drop_input_cols: + df = df.copy() + # 
in case the output column name overlap with the input column names, + # remove the ones in input column names + remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(output_cols)) + transformed_pandas_df = pd.concat( + [df[remove_dataset_col_name_exist_in_output_col], transformed_pandas_df], axis=1 + ) + + # write into a temp table in sproc and load the table from outside + session.write_pandas( + transformed_pandas_df, + fit_transform_result_name, + auto_create_table=True, + table_type="temp", + quote_identifiers=False, + ) + + return str(os.path.basename(local_result_file_name)) + + return fit_transform_wrapper_function def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure: # If the sproc already exists, don't register. - if not hasattr(self.session, "_FIT_PRE_WRAPPER_SPROCS"): - self.session._FIT_PRE_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] + if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"): + self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] model_spec = ModelSpecificationsBuilder.build(model=self.estimator) - fit_predict_sproc_key = model_spec.__class__.__name__ - if fit_predict_sproc_key in self.session._FIT_PRE_WRAPPER_SPROCS: # type: ignore[attr-defined] - fit_sproc: StoredProcedure = self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined] + fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict" + if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined] + fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] fit_predict_sproc_key ] return fit_sproc @@ -392,12 +539,47 @@ def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> St statement_params=statement_params, ) - self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined] + self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] fit_predict_sproc_key ] = fit_predict_wrapper_sproc return fit_predict_wrapper_sproc + def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure: + # If the sproc already exists, don't register. 
+ if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"): + self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc] + + model_spec = ModelSpecificationsBuilder.build(model=self.estimator) + fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform" + if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined] + fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] + fit_transform_sproc_key + ] + return fit_sproc + + fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) + + relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=model_spec.pkgDependencies, session=self.session + ) + + fit_transform_wrapper_sproc = self.session.sproc.register( + func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec), + is_permanent=False, + name=fit_transform_sproc_name, + packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type] + replace=True, + session=self.session, + statement_params=statement_params, + ) + + self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined] + fit_transform_sproc_key + ] = fit_transform_wrapper_sproc + + return fit_transform_wrapper_sproc + def train(self) -> object: """ Trains the model by pushing down the compute into Snowflake using stored procedures. @@ -498,10 +680,10 @@ def train_fit_predict( custom_tags=dict([("autogen", True)]) if self._autogenerated else None, ) - fit_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params) + fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params) fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE) - sproc_export_file_name: str = fit_wrapper_sproc( + sproc_export_file_name: str = fit_predict_wrapper_sproc( self.session, queries, stage_transform_file_name, @@ -521,3 +703,66 @@ def train_fit_predict( ) return output_result_sp, fitted_estimator + + def train_fit_transform( + self, + expected_output_cols_list: List[str], + drop_input_cols: Optional[bool] = False, + ) -> Tuple[Union[DataFrame, pd.DataFrame], object]: + """Trains the model by pushing down the compute into Snowflake using stored procedures. + This API is different from fit itself because it would also provide the transform + output. + + Args: + expected_output_cols_list (List[str]): The output columns + name as a list. Defaults to None. + drop_input_cols (Optional[bool]): Boolean to determine whether to + drop the input columns from the output dataset. + + Returns: + Tuple[Union[DataFrame, pd.DataFrame], object]: [transformed dataset, estimator] + """ + dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset) + + # Extract query that generated the dataframe. We will need to pass it to the fit procedure. 
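For orientation, a sketch of how the new fit_transform path is driven end to end from the builder added in model_trainer_builder.py. The builder class name, the estimator, the Snowpark DataFrame, and the column names below are all assumptions for illustration:

```python
from snowflake.ml.modeling._internal.model_trainer_builder import (
    ModelTrainerBuilder,  # assumed name of the builder class in this module
)

trainer = ModelTrainerBuilder.build_fit_transform(
    estimator=sklearn_transformer,  # an unfitted sklearn-compatible object
    dataset=snowpark_df,            # a snowflake.snowpark.DataFrame
    input_cols=["C1", "C2"],
    label_cols=None,
)

# For a Snowpark dataset this dispatches to SnowparkModelTrainer.train_fit_transform,
# which pushes fit_transform() into a stored procedure and returns a Snowpark
# DataFrame over the temporary result table plus the fitted estimator.
transformed_df, fitted_estimator = trainer.train_fit_transform(
    expected_output_cols_list=["OUTPUT_0"],
    drop_input_cols=False,
)
```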
+ queries = dataset.queries["queries"] + + transform_stage_name = self._create_temp_stage() + (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage( + stage_name=transform_stage_name + ) + + # Call fit sproc + statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=self._subproject, + function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), + api_calls=[Session.call], + custom_tags=dict([("autogen", True)]) if self._autogenerated else None, + ) + + fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params) + fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE) + + sproc_export_file_name: str = fit_transform_wrapper_sproc( + self.session, + queries, + stage_transform_file_name, + stage_result_file_name, + self.input_cols, + self.label_cols, + self.sample_weight_col, + statement_params, + drop_input_cols, + expected_output_cols_list, + fit_transform_result_name, + ) + + output_result_sp = self.session.table(fit_transform_result_name) + fitted_estimator = self._fetch_model_from_stage( + dir_path=stage_result_file_name, + file_name=sproc_export_file_name, + statement_params=statement_params, + ) + + return output_result_sp, fitted_estimator diff --git a/snowflake/ml/modeling/framework/BUILD.bazel b/snowflake/ml/modeling/framework/BUILD.bazel index cd147f75..8ab86112 100644 --- a/snowflake/ml/modeling/framework/BUILD.bazel +++ b/snowflake/ml/modeling/framework/BUILD.bazel @@ -14,6 +14,7 @@ py_library( "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/exceptions:error_messages", "//snowflake/ml/_internal/exceptions:modeling_error_messages", + "//snowflake/ml/_internal/lineage:dataset_dataframe", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:parallelize", "//snowflake/ml/modeling/_internal:transformer_protocols", diff --git a/snowflake/ml/modeling/framework/base.py b/snowflake/ml/modeling/framework/base.py index 5ba6a667..4046a6fd 100644 --- a/snowflake/ml/modeling/framework/base.py +++ b/snowflake/ml/modeling/framework/base.py @@ -16,6 +16,7 @@ exceptions, modeling_error_messages, ) +from snowflake.ml._internal.lineage import data_source, dataset_dataframe from snowflake.ml._internal.utils import identifier, parallelize from snowflake.ml.modeling.framework import _utils from snowflake.snowpark import functions as F @@ -385,6 +386,7 @@ def __init__( self.file_names = file_names self.custom_states = custom_states self.sample_weight_col = sample_weight_col + self._data_sources: Optional[List[data_source.DataSource]] = None self.start_time = datetime.now().strftime(_utils.DATETIME_FORMAT)[:-3] @@ -419,12 +421,17 @@ def _get_dependencies(self) -> List[str]: """ return [] + def _get_data_sources(self) -> Optional[List[data_source.DataSource]]: + return self._data_sources + @telemetry.send_api_usage_telemetry( project=PROJECT, subproject=SUBPROJECT, ) def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "BaseEstimator": """Runs universal logics for all fit implementations.""" + if isinstance(dataset, dataset_dataframe.DatasetDataFrame): + self._data_sources = dataset._get_sources() return self._fit(dataset) @abstractmethod @@ -539,58 +546,78 @@ def _enforce_fit(self) -> None: ), ) - def _infer_input_output_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> None: + def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> 
List[str]: """ - Infer `self.input_cols` and `self.output_cols` if they are not explicitly set. + Infer input_cols from the dataset. Input column are all columns in the input dataset that are not + designated as label, passthrough, or sample weight columns. Args: dataset: Input dataset. + Returns: + The list of input columns. + """ + cols = [ + c + for c in dataset.columns + if (c not in self.get_label_cols() and c not in self.get_passthrough_cols() and c != self.sample_weight_col) + ] + return cols + + def _infer_output_cols(self) -> List[str]: + """Infer output column names from based on the estimator. + + Returns: + The list of output columns. + Raises: SnowflakeMLException: If unable to infer output columns + """ - if not self.input_cols: - cols = [ - c - for c in dataset.columns - if ( - c not in self.get_label_cols() - and c not in self.get_passthrough_cols() - and c != self.sample_weight_col - ) - ] - self.set_input_cols(input_cols=cols) - if not self.output_cols: - # keep mypy happy - assert self._sklearn_object is not None - - if hasattr(self._sklearn_object, "_estimator_type"): - # For supervised estimators, infer the output columns from the label columns - if self._sklearn_object._estimator_type in SKLEARN_SUPERVISED_ESTIMATORS: - cols = [identifier.concat_names(["OUTPUT_", c]) for c in self.label_cols] - self.set_output_cols(output_cols=cols) - - # For density estimators, clusterers, and outlier detectors, there is always exactly one output column. - elif self._sklearn_object._estimator_type in SKLEARN_SINGLE_OUTPUT_ESTIMATORS: - self.set_output_cols(output_cols=["OUTPUT_0"]) - - else: - raise exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - f"Unable to infer output columns for estimator type {self._sklearn_object._estimator_type}." - f"Please include `output_cols` explicitly." - ), - ) + # keep mypy happy + assert self._sklearn_object is not None + if hasattr(self._sklearn_object, "_estimator_type"): + # For supervised estimators, infer the output columns from the label columns + if self._sklearn_object._estimator_type in SKLEARN_SUPERVISED_ESTIMATORS: + cols = [identifier.concat_names(["OUTPUT_", c]) for c in self.label_cols] + return cols + + # For density estimators, clusterers, and outlier detectors, there is always exactly one output column. + elif self._sklearn_object._estimator_type in SKLEARN_SINGLE_OUTPUT_ESTIMATORS: + return ["OUTPUT_0"] + else: raise exceptions.SnowflakeMLException( error_code=error_codes.INVALID_ARGUMENT, original_exception=ValueError( - f"Unable to infer output columns for object {self._sklearn_object}." + f"Unable to infer output columns for estimator type {self._sklearn_object._estimator_type}." f"Please include `output_cols` explicitly." ), ) + else: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=ValueError( + f"Unable to infer output columns for object {self._sklearn_object}." + f"Please include `output_cols` explicitly." + ), + ) + + def _infer_input_output_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> None: + """ + Infer `self.input_cols` and `self.output_cols` if they are not explicitly set. + + Args: + dataset: Input dataset. 
+ """ + if not self.input_cols: + cols = self._infer_input_cols(dataset=dataset) + self.set_input_cols(input_cols=cols) + + if not self.output_cols: + cols = self._infer_output_cols() + self.set_output_cols(output_cols=cols) def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]: """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions. diff --git a/snowflake/ml/modeling/model_selection/grid_search_cv.py b/snowflake/ml/modeling/model_selection/grid_search_cv.py index 10b68897..27b1e7ba 100644 --- a/snowflake/ml/modeling/model_selection/grid_search_cv.py +++ b/snowflake/ml/modeling/model_selection/grid_search_cv.py @@ -334,9 +334,12 @@ def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "GridSearchCV": self._generate_model_signatures(dataset) return self - def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]: - """Util method to run validate that batch inference can be run on a snowpark dataframe and - return the available package that exists in the snowflake anaconda channel + def _batch_inference_validate_snowpark( + self, + dataset: DataFrame, + inference_method: str, + ) -> None: + """Util method to run validate that batch inference can be run on a snowpark dataframe. Args: dataset: snowpark dataframe @@ -346,8 +349,6 @@ def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_metho SnowflakeMLException: If the estimator is not fitted, raise error SnowflakeMLException: If the session is None, raise error - Returns: - A list of available package that exists in the snowflake anaconda channel """ if not self._is_fitted: raise exceptions.SnowflakeMLException( @@ -363,10 +364,6 @@ def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_metho error_code=error_codes.NOT_FOUND, original_exception=ValueError("Session must not specified for snowpark dataset."), ) - # Validate that key package version in user workspace are supported in snowflake conda channel - return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT - ) @available_if(original_estimator_has_callable("predict")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -415,10 +412,8 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p ) expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type()) - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session @@ -476,7 +471,8 @@ def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, inference_method = "transform" if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -535,7 +531,8 @@ def predict_proba( inference_method = "predict_proba" if isinstance(dataset, DataFrame): - self._deps = 
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -595,7 +592,8 @@ def predict_log_proba( inference_method = "predict_log_proba" if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -655,7 +653,8 @@ def decision_function( inference_method = "decision_function" if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -716,7 +715,8 @@ def score_samples( transform_kwargs: BatchInferenceKwargsTypedDict = dict() if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -767,17 +767,15 @@ def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: transform_kwargs: ScoreKwargsTypedDict = dict() if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method="score", - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score") + self._deps = self._get_dependencies() selected_cols = self._get_active_columns() if len(selected_cols) > 0: dataset = dataset.select(selected_cols) assert isinstance(dataset._session, Session) # keep mypy happy transform_kwargs = dict( session=dataset._session, - dependencies=["snowflake-snowpark-python"] + self._deps, + dependencies=self._deps, score_sproc_imports=["sklearn"], ) elif isinstance(dataset, pd.DataFrame): diff --git a/snowflake/ml/modeling/model_selection/randomized_search_cv.py b/snowflake/ml/modeling/model_selection/randomized_search_cv.py index b0c2dbc3..77d3ba8f 100644 --- a/snowflake/ml/modeling/model_selection/randomized_search_cv.py +++ b/snowflake/ml/modeling/model_selection/randomized_search_cv.py @@ -347,8 +347,22 @@ def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RandomizedSearchCV": self._generate_model_signatures(dataset) return self - def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]: - """Util method to run validate that batch inference can be run on a snowpark dataframe.""" + def _batch_inference_validate_snowpark( + self, + dataset: DataFrame, + inference_method: str, + ) -> None: + """Util method to run validate that batch inference can be run on a snowpark dataframe. 
+ + Args: + dataset: snowpark dataframe + inference_method: the inference method such as predict, score... + + Raises: + SnowflakeMLException: If the estimator is not fitted, raise error + SnowflakeMLException: If the session is None, raise error + + """ if not self._is_fitted: raise exceptions.SnowflakeMLException( error_code=error_codes.METHOD_NOT_ALLOWED, @@ -363,10 +377,6 @@ def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_metho error_code=error_codes.NOT_FOUND, original_exception=ValueError("Session must not specified for snowpark dataset."), ) - # Validate that key package version in user workspace are supported in snowflake conda channel - return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT - ) @available_if(original_estimator_has_callable("predict")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -414,10 +424,9 @@ def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, p ) expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type()) - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method=inference_method, - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() + assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -473,7 +482,9 @@ def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, inference_method = "transform" if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() + assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -531,7 +542,9 @@ def predict_proba( inference_method = "predict_proba" if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() + assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -591,7 +604,9 @@ def predict_log_proba( inference_method = "predict_log_proba" if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() + assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -650,7 +665,9 @@ def decision_function( inference_method = "decision_function" if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() + assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -711,7 +728,9 @@ def score_samples( 
transform_kwargs: BatchInferenceKwargsTypedDict = dict() if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method) + self._deps = self._get_dependencies() + assert isinstance( dataset._session, Session ) # mypy does not recognize the check in _batch_inference_validate_snowpark() @@ -761,10 +780,9 @@ def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: transform_kwargs: ScoreKwargsTypedDict = dict() if isinstance(dataset, DataFrame): - self._deps = self._batch_inference_validate_snowpark( - dataset=dataset, - inference_method="score", - ) + self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score") + self._deps = self._get_dependencies() + selected_cols = self._get_active_columns() if len(selected_cols) > 0: dataset = dataset.select(selected_cols) @@ -772,7 +790,7 @@ def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: assert isinstance(dataset._session, Session) # keep mypy happy transform_kwargs = dict( session=dataset._session, - dependencies=["snowflake-snowpark-python"] + self._deps, + dependencies=self._deps, score_sproc_imports=["sklearn"], ) elif isinstance(dataset, pd.DataFrame): diff --git a/snowflake/ml/modeling/pipeline/BUILD.bazel b/snowflake/ml/modeling/pipeline/BUILD.bazel index 876d3a64..b7a2de6d 100644 --- a/snowflake/ml/modeling/pipeline/BUILD.bazel +++ b/snowflake/ml/modeling/pipeline/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:py_rules.bzl", "py_library", "py_package") +load("//bazel:py_rules.bzl", "py_library", "py_package", "py_test") package(default_visibility = ["//visibility:public"]) @@ -19,9 +19,14 @@ py_library( ], deps = [ ":init", + "//snowflake/ml/_internal:file_utils", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", + "//snowflake/ml/_internal/utils:temp_file_utils", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/modeling/_internal:model_transformer_builder", + "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_handlers", ], ) @@ -32,3 +37,15 @@ py_package( ":pipeline", ], ) + +py_test( + name = "pipeline_test", + srcs = ["pipeline_test.py"], + deps = [ + ":pipeline", + "//snowflake/ml/modeling/lightgbm:lgbm_classifier", + "//snowflake/ml/modeling/linear_model:linear_regression", + "//snowflake/ml/modeling/preprocessing:min_max_scaler", + "//snowflake/ml/modeling/xgboost:xgb_regressor", + ], +) diff --git a/snowflake/ml/modeling/pipeline/pipeline.py b/snowflake/ml/modeling/pipeline/pipeline.py index ec8154e2..3c50bf74 100644 --- a/snowflake/ml/modeling/pipeline/pipeline.py +++ b/snowflake/ml/modeling/pipeline/pipeline.py @@ -1,7 +1,12 @@ #!/usr/bin/env python3 +import inspect +import os +import posixpath +import tempfile from itertools import chain from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +import cloudpickle as cp import numpy as np import pandas as pd from sklearn import __version__ as skversion, pipeline @@ -10,14 +15,20 @@ from sklearn.utils import metaestimators from snowflake import snowpark -from snowflake.ml._internal import telemetry +from snowflake.ml._internal import file_utils, telemetry from snowflake.ml._internal.exceptions import error_codes, exceptions -from snowflake.ml._internal.utils import snowpark_dataframe_utils +from snowflake.ml._internal.utils import 
snowpark_dataframe_utils, temp_file_utils from snowflake.ml.model.model_signature import ModelSignature, _infer_signature +from snowflake.ml.modeling._internal.model_transformer_builder import ( + ModelTransformerBuilder, +) from snowflake.ml.modeling.framework import _utils, base +from snowflake.snowpark import Session, functions as F +from snowflake.snowpark._internal import utils as snowpark_utils _PROJECT = "ModelDevelopment" _SUBPROJECT = "Framework" +IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME" def _final_step_has(attr: str) -> Callable[..., bool]: @@ -113,6 +124,8 @@ def __init__(self, steps: List[Tuple[str, Any]]) -> None: if isinstance(obj, base.BaseTransformer): deps = deps | set(obj._get_dependencies()) self._deps = list(deps) + self._sklearn_object = None + self.label_cols = self._get_label_cols() @staticmethod def _is_estimator(obj: object) -> bool: @@ -147,6 +160,33 @@ def _reset(self) -> None: self._n_features_in = [] self._transformers_to_input_indices = {} + def _is_convertible_to_sklearn_object(self) -> bool: + """Checks if the pipeline can be converted to a native sklearn pipeline. + - We can not create an sklearn pipeline if its label or sample weight column are + modified in the pipeline. + - We can not create an sklearn pipeline if any of its steps cannot be converted to an sklearn pipeline + - We can not create an sklearn pipeline if input columns are specified in any step other than + the first step + + Returns: + True if the pipeline can be converted to a native sklearn pipeline, else false. + """ + if self._is_pipeline_modifying_label_or_sample_weight(): + return False + + # check that nested pipelines can be converted to sklearn + for _, base_estimator in self.steps: + if hasattr(base_estimator, "_is_convertible_to_sklearn_object"): + if not base_estimator._is_convertible_to_sklearn_object(): + return False + + # check that no column after the first column has 'input columns' set. + for _, base_estimator in self.steps[1:]: + if base_estimator.get_input_cols(): + # We only want Falsy values - None and [] + return False + return True + def _is_pipeline_modifying_label_or_sample_weight(self) -> bool: """ Checks if pipeline is modifying label or sample_weight columns. @@ -214,27 +254,167 @@ def _fit_transform_dataset( self._append_step_feature_consumption_info( step_name=name, all_cols=transformed_dataset.columns[:], input_cols=trans.get_input_cols() ) - if has_callable_attr(trans, "fit_transform"): - transformed_dataset = trans.fit_transform(transformed_dataset) - else: - trans.fit(transformed_dataset) - transformed_dataset = trans.transform(transformed_dataset) + trans.fit(transformed_dataset) + transformed_dataset = trans.transform(transformed_dataset) return transformed_dataset + def _upload_model_to_stage(self, stage_name: str, estimator: object, session: Session) -> Tuple[str, str]: + """ + Util method to pickle and upload the model to a temp Snowflake stage. + + Args: + stage_name: Stage name to save model. + estimator: the pipeline estimator itself + session: Session object + + Returns: + a tuple containing stage file paths for pickled input model for training and location to store trained + models(response from training sproc). + """ + # Create a temp file and dump the transform to that file. 
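        # Editor's aside (a commented sketch, not part of this patch): the upload below is
        # just a cloudpickle dump followed by a stage PUT; the training sproc later GETs and
        # cp.load()s the same file. Under those assumptions, the round trip looks roughly like:
        #
        #     import cloudpickle as cp
        #     local_path = temp_file_utils.get_temp_file_path()
        #     with open(local_path, mode="w+b") as fh:
        #         cp.dump(estimator, fh)                      # serialize the pipeline
        #     session.file.put(local_path, stage_name, auto_compress=False, overwrite=True)
        #     # ...inside the sproc: session.file.get(...) then cp.load(...) restores it
        #
        # where `stage_name` stands in for the temporary stage created by the caller.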
+ local_transform_file_name = temp_file_utils.get_temp_file_path() + with open(local_transform_file_name, mode="w+b") as local_transform_file: + cp.dump(estimator, local_transform_file) + + # Use posixpath to construct stage paths + stage_transform_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name)) + stage_result_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name)) + + # Put locally serialized transform on stage. + session.file.put( + local_transform_file_name, + stage_transform_file_name, + auto_compress=False, + overwrite=True, + ) + + temp_file_utils.cleanup_temp_files([local_transform_file_name]) + return (stage_transform_file_name, stage_result_file_name) + + def _fit_snowpark_dataframe_within_one_sproc(self, session: Session, dataset: snowpark.DataFrame) -> None: + # Extract queries that generated the dataframe. We will need to pass it to score procedure. + sql_queries = dataset.queries["queries"] + + # Zip the current snowml package + with tempfile.TemporaryDirectory() as tmpdir: + snowml_zip_module_filename = os.path.join(tmpdir, "snowflake-ml-python.zip") + file_utils.zip_python_package(snowml_zip_module_filename, "snowflake.ml") + imports = [snowml_zip_module_filename] + + sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) + required_deps = self._deps + sproc_statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject="PIPELINE", + function_name=telemetry.get_statement_params_full_func_name( + inspect.currentframe(), self.__class__.__name__ + ), + api_calls=[F.sproc], + ) + transform_stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE) + stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};" + session.sql(stage_creation_query).collect() + (stage_estimator_file_name, stage_result_file_name) = self._upload_model_to_stage( + transform_stage_name, self, session + ) + + def pipeline_within_one_sproc( + session: Session, + sql_queries: List[str], + stage_estimator_file_name: str, + stage_result_file_name: str, + sproc_statement_params: Dict[str, str], + ) -> str: + import os + + import cloudpickle as cp + import pandas as pd + + for query in sql_queries[:-1]: + _ = session.sql(query).collect(statement_params=sproc_statement_params) + sp_df = session.sql(sql_queries[-1]) + df: pd.DataFrame = sp_df.to_pandas(statement_params=sproc_statement_params) + df.columns = sp_df.columns + + local_estimator_file_name = temp_file_utils.get_temp_file_path() + + session.file.get(stage_estimator_file_name, local_estimator_file_name) + + local_estimator_file_path = os.path.join( + local_estimator_file_name, os.listdir(local_estimator_file_name)[0] + ) + with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj: + estimator = cp.load(local_estimator_file_obj) + + estimator.fit(df) + + local_result_file_name = temp_file_utils.get_temp_file_path() + + with open(local_result_file_name, mode="w+b") as local_result_file_obj: + cp.dump(estimator, local_result_file_obj) + + session.file.put( + local_result_file_name, + stage_result_file_name, + auto_compress=False, + overwrite=True, + statement_params=sproc_statement_params, + ) + + return str(os.path.basename(local_result_file_name)) + + session.sproc.register( + func=pipeline_within_one_sproc, + is_permanent=False, + name=sproc_name, + packages=required_deps, # type: ignore[arg-type] + replace=True, + session=session, + 
anonymous=True, + imports=imports, # type: ignore[arg-type] + statement_params=sproc_statement_params, + ) + + sproc_export_file_name: str = pipeline_within_one_sproc( + session, + sql_queries, + stage_estimator_file_name, + stage_result_file_name, + sproc_statement_params, + ) + + local_result_file_name = temp_file_utils.get_temp_file_path() + session.file.get( + posixpath.join(stage_estimator_file_name, sproc_export_file_name), + local_result_file_name, + statement_params=sproc_statement_params, + ) + + with open(os.path.join(local_result_file_name, sproc_export_file_name), mode="r+b") as result_file_obj: + fit_estimator = cp.load(result_file_obj) + + temp_file_utils.cleanup_temp_files([local_result_file_name]) + for key, val in vars(fit_estimator).items(): + setattr(self, key, val) + @telemetry.send_api_usage_telemetry( project=_PROJECT, subproject=_SUBPROJECT, ) - def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "Pipeline": + def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame], squash: Optional[bool] = False) -> "Pipeline": """ Fit the entire pipeline using the dataset. Args: dataset: Input dataset. + squash: Run the whole pipeline within a stored procedure Returns: Fitted pipeline. + + Raises: + ValueError: A pipeline incompatible with sklearn is used on MLRS """ self._validate_steps() @@ -243,19 +423,33 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "Pipeline": if isinstance(dataset, snowpark.DataFrame) else dataset ) - transformed_dataset = self._fit_transform_dataset(dataset) - estimator = self._get_estimator() - if estimator: - all_cols = transformed_dataset.columns[:] - estimator[1].fit(transformed_dataset) + if self._can_be_trained_in_ml_runtime(dataset): + if not self._is_convertible_to_sklearn_object(): + raise ValueError("This pipeline cannot be converted to an sklearn pipeline.") + self._fit_ml_runtime(dataset) - self._append_step_feature_consumption_info( - step_name=estimator[0], all_cols=all_cols, input_cols=estimator[1].get_input_cols() - ) + elif squash and isinstance(dataset, snowpark.DataFrame): + session = dataset._session + assert session is not None + self._fit_snowpark_dataframe_within_one_sproc(session=session, dataset=dataset) + + else: + transformed_dataset = self._fit_transform_dataset(dataset) + + estimator = self._get_estimator() + if estimator: + all_cols = transformed_dataset.columns[:] + estimator[1].fit(transformed_dataset) + + self._append_step_feature_consumption_info( + step_name=estimator[0], all_cols=all_cols, input_cols=estimator[1].get_input_cols() + ) + + self._generate_model_signatures(dataset=dataset) - self._generate_model_signatures(dataset=dataset) self._is_fitted = True + return self @metaestimators.available_if(_final_step_has("transform")) # type: ignore[misc] @@ -280,6 +474,22 @@ def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[s else dataset ) + if self._sklearn_object is not None: + handler = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name="Pipeline", + subproject="", + autogenerated=False, + ) + return handler.batch_inference( + inference_method="transform", + input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset), + expected_output_cols=self._infer_output_cols(), + session=dataset._session, + dependencies=self._deps, + ) + transformed_dataset = self._transform_dataset(dataset=dataset) estimator = self._get_estimator() if estimator: @@ -389,8 +599,32 @@ def predict(self, 
dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[sno Returns: Output dataset. + + Raises: + ValueError: An sklearn object has not been fit and stored before calling this function. """ - return self._invoke_estimator_func("predict", dataset) + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + if self._sklearn_object is None: + raise ValueError("Model must be fit before inference.") + + expected_output_cols = self._infer_output_cols() + handler = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name="Pipeline", + subproject="", + autogenerated=False, + ) + return handler.batch_inference( + inference_method="predict", + input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset), + expected_output_cols=expected_output_cols, + session=dataset._session, + dependencies=self._deps, + ) + + else: + return self._invoke_estimator_func("predict", dataset) @metaestimators.available_if(_final_step_has("score_samples")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -408,8 +642,32 @@ def score_samples( Returns: Output dataset. + + Raises: + ValueError: An sklearn object has not been fit before calling this function """ - return self._invoke_estimator_func("score_samples", dataset) + + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + if self._sklearn_object is None: + raise ValueError("Model must be fit before inference.") + + expected_output_cols = self._get_output_column_names("score_samples") + handler = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name="Pipeline", + subproject="", + autogenerated=False, + ) + return handler.batch_inference( + inference_method="score_samples", + input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset), + expected_output_cols=expected_output_cols, + session=dataset._session, + dependencies=self._deps, + ) + else: + return self._invoke_estimator_func("score_samples", dataset) @metaestimators.available_if(_final_step_has("predict_proba")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -427,8 +685,32 @@ def predict_proba( Returns: Output dataset. + + Raises: + ValueError: An sklearn object has not been fit before calling this function """ - return self._invoke_estimator_func("predict_proba", dataset) + + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + if self._sklearn_object is None: + raise ValueError("Model must be fit before inference.") + expected_output_cols = self._get_output_column_names("predict_proba") + + handler = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name="Pipeline", + subproject="", + autogenerated=False, + ) + return handler.batch_inference( + inference_method="predict_proba", + input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset), + expected_output_cols=expected_output_cols, + session=dataset._session, + dependencies=self._deps, + ) + else: + return self._invoke_estimator_func("predict_proba", dataset) @metaestimators.available_if(_final_step_has("predict_log_proba")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -447,8 +729,31 @@ def predict_log_proba( Returns: Output dataset. 
+ + Raises: + ValueError: An sklearn object has not been fit before calling this function """ - return self._invoke_estimator_func("predict_log_proba", dataset) + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + if self._sklearn_object is None: + raise ValueError("Model must be fit before inference.") + + expected_output_cols = self._get_output_column_names("predict_log_proba") + handler = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name="Pipeline", + subproject="", + autogenerated=False, + ) + return handler.batch_inference( + inference_method="predict_log_proba", + input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset), + expected_output_cols=expected_output_cols, + session=dataset._session, + dependencies=self._deps, + ) + else: + return self._invoke_estimator_func("predict_log_proba", dataset) @metaestimators.available_if(_final_step_has("score")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( @@ -464,8 +769,30 @@ def score(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowp Returns: Output dataset. + + Raises: + ValueError: An sklearn object has not been fit before calling this function """ - return self._invoke_estimator_func("score", dataset) + + if os.environ.get(IN_ML_RUNTIME_ENV_VAR): + if self._sklearn_object is None: + raise ValueError("Model must be fit before scoreing.") + handler = ModelTransformerBuilder.build( + dataset=dataset, + estimator=self._sklearn_object, + class_name="Pipeline", + subproject="", + autogenerated=False, + ) + return handler.score( + input_cols=self._infer_input_cols(), + label_cols=self._get_label_cols(), + session=dataset._session, + dependencies=self._deps, + score_sproc_imports=[], + ) + else: + return self._invoke_estimator_func("score", dataset) def _invoke_estimator_func( self, func_name: str, dataset: Union[snowpark.DataFrame, pd.DataFrame] @@ -495,15 +822,6 @@ def _invoke_estimator_func( res: snowpark.DataFrame = getattr(estimator[1], func_name)(transformed_dataset) return res - def _create_unfitted_sklearn_object(self) -> pipeline.Pipeline: - sksteps = [] - for step in self.steps: - if isinstance(step[1], base.BaseTransformer): - sksteps.append(tuple([step[0], _utils.to_native_format(step[1])])) - else: - sksteps.append(tuple([step[0], step[1]])) - return pipeline.Pipeline(steps=sksteps) - def _construct_fitted_column_transformer_object( self, step_name_in_pipeline: str, @@ -562,6 +880,125 @@ def _construct_fitted_column_transformer_object( ct._name_to_fitted_passthrough = {step_name_in_ct: ft} return ct + def _fit_ml_runtime(self, dataset: snowpark.DataFrame) -> None: + """Train the pipeline in the ML Runtime. + + Args: + dataset: The training Snowpark dataframe + + Raises: + ModuleNotFoundError: The ML Runtime Client is not installed. + """ + try: + from snowflake.ml.runtime import MLRuntimeClient + except ModuleNotFoundError as e: + # The snowflake.ml.runtime module should always be present when + # the env var IN_SPCS_ML_RUNTIME is present. 
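            # Editor's aside (a commented sketch, not part of this patch): the same gate used
            # by fit/predict/score in this class is a plain environment-variable check, e.g.:
            #
            #     import os
            #     if os.environ.get(IN_ML_RUNTIME_ENV_VAR):   # i.e. "IN_SPCS_ML_RUNTIME"
            #         ...  # use the pre-fitted self._sklearn_object via ModelTransformerBuilder
            #     else:
            #         ...  # fall back to self._invoke_estimator_func(...)
            #
            # so this ModuleNotFoundError should only surface on a misconfigured runtime image.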
+ raise ModuleNotFoundError("ML Runtime Python Client is not installed.") from e + + client = MLRuntimeClient() + ml_runtime_compatible_pipeline = self._create_unfitted_sklearn_object() + + label_cols = self._get_label_cols() + all_df_cols = dataset.columns + input_cols = [col for col in all_df_cols if col not in label_cols] + + trained_pipeline = client.train( + estimator=ml_runtime_compatible_pipeline, + dataset=dataset, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=self.sample_weight_col, + ) + + self._sklearn_object = trained_pipeline + + def _get_label_cols(self) -> List[str]: + """Util function to get the label columns from the pipeline. + The label column is only present in the estimator + + Returns: + List of label columns, or empty list if no label cols. + """ + label_cols = [] + estimator = self._get_estimator() + if estimator is not None: + label_cols = estimator[1].get_label_cols() + + return label_cols + + def _can_be_trained_in_ml_runtime(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> bool: + """A utility function to determine if the pipeline cam be pushed down to the ML Runtime for training. + Currently, this is true if: + - The training dataset is a snowpark dataframe, + - The IN_SPCS_ML_RUNTIME environment is present and + - The pipeline can be converted to an sklearn pipeline. + + Args: + dataset: The training dataset + + Returns: + True if the dataset can be fit in the ml runtime, else false. + + """ + if not isinstance(dataset, snowpark.DataFrame): + return False + + if not os.environ.get(IN_ML_RUNTIME_ENV_VAR): + return False + + return self._is_convertible_to_sklearn_object() + + @staticmethod + def _wrap_transformer_in_column_transformer( + transformer_name: str, transformer: base.BaseTransformer + ) -> ColumnTransformer: + """A helper function to convert a transformer object to an sklearn object and wrap in an sklearn + ColumnTransformer. + + Args: + transformer_name: Name of the transformer to be wrapped. + transformer: The transformer object to be wrapped. + + Returns: + A column transformer sklearn object that uses the input columns from the initial snowpark ml transformer. + """ + column_transformer = ColumnTransformer( + transformers=[(transformer_name, Pipeline._get_native_object(transformer), transformer.get_input_cols())], + remainder="passthrough", + ) + return column_transformer + + def _create_unfitted_sklearn_object(self) -> pipeline.Pipeline: + """Create a sklearn pipeline from the current snowml pipeline. + ColumnTransformers are used to wrap transformers as their input columns can be specified + as a subset of the pipeline's input columns. + + Returns: + An unfit pipeline that can be fit using the ML runtime client. + """ + + sklearn_pipeline_steps = [] + + first_step_name, first_step_object = self.steps[0] + + # Only the first step can have the input_cols field not None/empty. 
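        # Editor's aside (a commented sketch, not part of this patch): for a first step such
        # as MinMaxScaler(input_cols=["COL1"]), the wrapping done by
        # _wrap_transformer_in_column_transformer above is roughly equivalent to:
        #
        #     from sklearn.compose import ColumnTransformer
        #     from sklearn.preprocessing import MinMaxScaler
        #
        #     ColumnTransformer(
        #         transformers=[("MMS", MinMaxScaler(), ["COL1"])],
        #         remainder="passthrough",
        #     )
        #
        # so columns outside the step's input_cols still flow through to later steps.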
+ if first_step_object.get_input_cols(): + first_step_column_transformer = Pipeline._wrap_transformer_in_column_transformer( + first_step_name, first_step_object + ) + first_step_skl = (first_step_name, first_step_column_transformer) + else: + first_step_skl = (first_step_name, Pipeline._get_native_object(first_step_object)) + + sklearn_pipeline_steps.append(first_step_skl) + + for step_name, step_object in self.steps[1:]: + skl_step = (step_name, Pipeline._get_native_object(step_object)) + sklearn_pipeline_steps.append(skl_step) + + return pipeline.Pipeline(sklearn_pipeline_steps) + def _create_sklearn_object(self) -> pipeline.Pipeline: if not self._is_fitted: return self._create_unfitted_sklearn_object() @@ -570,7 +1007,7 @@ def _create_sklearn_object(self) -> pipeline.Pipeline: raise exceptions.SnowflakeMLException( error_code=error_codes.METHOD_NOT_ALLOWED, original_exception=ValueError( - "The pipeline can't be converted to SKLearn equivalent because it processing label or " + "The pipeline can't be converted to SKLearn equivalent because it modifies processing label or " "sample_weight columns as part of pipeline preprocessing steps which is not allowed in SKLearn." ), ) @@ -631,3 +1068,48 @@ def model_signatures(self) -> Dict[str, ModelSignature]: original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"), ) return self._model_signature_dict + + @staticmethod + def _get_native_object(estimator: base.BaseEstimator) -> object: + """A helper function to get the native(sklearn, xgboost, or lightgbm) + object from a snowpark ml estimator. + TODO - better type hinting - is there a common base class for all xgb/lgbm estimators? + + Args: + estimator: the estimator from which to derive the native object. + + Returns: + a native estimator object + + Raises: + ValueError: The estimator is not an sklearn, xgboost, or lightgbm estimator. + """ + methods = ["to_sklearn", "to_xgboost", "to_lightgbm"] + for method_name in methods: + if hasattr(estimator, method_name): + try: + result = getattr(estimator, method_name)() + return result + except exceptions.SnowflakeMLException: + pass # Do nothing and continue to the next method + raise ValueError("The estimator must be an sklearn, xgboost, or lightgbm estimator.") + + def to_sklearn(self) -> pipeline.Pipeline: + """Returns an sklearn Pipeline representing the object, if possible. + + Returns: + previously fit sklearn Pipeline if present, else an unfit pipeline + + Raises: + ValueError: The pipeline cannot be represented as an sklearn pipeline. 
+ """ + if self._is_fitted: + if self._sklearn_object is not None: + return self._sklearn_object + else: + return self._create_sklearn_object() + else: + if self._is_convertible_to_sklearn_object(): + return self._create_unfitted_sklearn_object() + else: + raise ValueError("This pipeline can not be converted to an sklearn pipeline.") diff --git a/snowflake/ml/modeling/pipeline/pipeline_test.py b/snowflake/ml/modeling/pipeline/pipeline_test.py new file mode 100644 index 00000000..6ed8f4bc --- /dev/null +++ b/snowflake/ml/modeling/pipeline/pipeline_test.py @@ -0,0 +1,237 @@ +import os + +import pandas as pd +from absl.testing import absltest +from sklearn.compose import ColumnTransformer +from sklearn.linear_model import LinearRegression as sklearn_LR +from sklearn.pipeline import Pipeline as sklearn_Pipeline +from sklearn.preprocessing import MinMaxScaler as sklearn_MMS + +from snowflake.ml.modeling.lightgbm import LGBMClassifier +from snowflake.ml.modeling.linear_model import LinearRegression +from snowflake.ml.modeling.pipeline.pipeline import IN_ML_RUNTIME_ENV_VAR, Pipeline +from snowflake.ml.modeling.preprocessing import MinMaxScaler +from snowflake.ml.modeling.xgboost import XGBRegressor +from snowflake.snowpark import DataFrame + + +class PipelineTest(absltest.TestCase): + def setUp(self) -> None: + self.dataframe_snowpark = absltest.mock.MagicMock(spec=DataFrame) + self.simple_pipeline = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col2"]), + ), + ] + ) + + self.pipeline_two_steps_no_estimator = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1"]), + ), + ] + * 2 + ) + + self.pipeline_two_steps_with_estimator = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1"]), + ), + ("model", LinearRegression(label_cols=["col3"])), + ] + ) + return super().setUp() + + def test_dataset_can_be_trained_in_ml_runtime(self) -> None: + """Test that the pipeline can only be trained in ml runtime if correct dataset type and + environment variables present. + """ + + assert self.simple_pipeline._can_be_trained_in_ml_runtime(dataset=pd.DataFrame()) is False + assert self.simple_pipeline._can_be_trained_in_ml_runtime(dataset=self.dataframe_snowpark) is False + + os.environ[IN_ML_RUNTIME_ENV_VAR] = "True" + assert self.simple_pipeline._can_be_trained_in_ml_runtime(dataset=pd.DataFrame()) is False + assert self.simple_pipeline._can_be_trained_in_ml_runtime(dataset=self.dataframe_snowpark) is True + + del os.environ[IN_ML_RUNTIME_ENV_VAR] + + def test_pipeline_can_be_trained_in_ml_runtime(self) -> None: + """Test that the pipeline can be trained in the ml runtime if it has the correct configuration + of steps. 
+ """ + os.environ[IN_ML_RUNTIME_ENV_VAR] = "True" + + assert self.simple_pipeline._can_be_trained_in_ml_runtime(dataset=self.dataframe_snowpark) is True + + pipeline_three_steps = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1"]), + ), + ] + * 3 + ) + + assert pipeline_three_steps._can_be_trained_in_ml_runtime(dataset=self.dataframe_snowpark) is False + + assert ( + self.pipeline_two_steps_no_estimator._can_be_trained_in_ml_runtime(dataset=self.dataframe_snowpark) is False + ) + + assert ( + self.pipeline_two_steps_with_estimator._can_be_trained_in_ml_runtime(dataset=self.dataframe_snowpark) + is True + ) + + del os.environ[IN_ML_RUNTIME_ENV_VAR] + + def test_wrap_transformer_in_column_transformer(self): + input_cols = ["col1"] + transformer = MinMaxScaler(input_cols=input_cols, output_cols=["col2"]) + transformer_name = "MMS" + + wrapped_transformer = Pipeline._wrap_transformer_in_column_transformer(transformer_name, transformer) + assert isinstance(wrapped_transformer, ColumnTransformer) + assert len(wrapped_transformer.transformers) == 1 + + inner_transformer = wrapped_transformer.transformers[0] + assert inner_transformer[0] == transformer_name + assert isinstance(inner_transformer[1], sklearn_MMS) + assert inner_transformer[2] == input_cols + + def test_create_unfitted_sklearn_object(self) -> None: + sklearn_pipeline = self.pipeline_two_steps_with_estimator._create_unfitted_sklearn_object() + assert isinstance(sklearn_pipeline, sklearn_Pipeline) + for transformer_step in sklearn_pipeline.steps[:-1]: + assert isinstance(transformer_step[1], ColumnTransformer) + + pipeline_three_steps_with_estimator = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1"]), + ), + ( + "MMS2", + MinMaxScaler(), + ), + ("model", LinearRegression()), + ] + ) + sklearn_pipeline = pipeline_three_steps_with_estimator._create_unfitted_sklearn_object() + assert isinstance(sklearn_pipeline, sklearn_Pipeline) + + skl_pipeline_steps = sklearn_pipeline.steps + assert isinstance(skl_pipeline_steps[0][1], ColumnTransformer) + assert isinstance(skl_pipeline_steps[1][1], sklearn_MMS) + assert isinstance(skl_pipeline_steps[2][1], sklearn_LR) + + def test_get_native_object(self) -> None: + sklearn_type = LinearRegression(input_cols=["col1, col2"], label_cols=["col3"]) + Pipeline._get_native_object(sklearn_type) + + xgb_type = XGBRegressor(input_cols=["col1, col2"], label_cols=["col3"]) + Pipeline._get_native_object(xgb_type) + + lgbm_type = LGBMClassifier(input_cols=["col1, col2"], label_cols=["col3"]) + Pipeline._get_native_object(lgbm_type) + + with self.assertRaises(ValueError): + Pipeline._get_native_object(pd.DataFrame()) + + def test_get_label_cols(self) -> None: + assert self.simple_pipeline._get_label_cols() == [] + + assert self.pipeline_two_steps_no_estimator._get_label_cols() == [] + + assert len(self.pipeline_two_steps_with_estimator._get_label_cols()) == 1 + + def test_is_pipeline_modifying_label_or_sample_weight(self) -> None: + """Tests whether the pipeline modifies either the label or sample weight columns.""" + assert self.simple_pipeline._is_pipeline_modifying_label_or_sample_weight() is False + + pipeline_modifying_label = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1_out"]), + ), + ( + "MMS2", + MinMaxScaler(input_cols=["col3"], output_cols=["col3_out"]), + ), + ("model", LinearRegression(input_cols=["col1"], label_cols=["col3"])), + ] + ) + assert 
pipeline_modifying_label._is_pipeline_modifying_label_or_sample_weight() is True + + pipeline_modifying_sample_weight = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1_out"]), + ), + ( + "MMS2", + MinMaxScaler(input_cols=["col3"], output_cols=["col3_out"]), + ), + ("model", LinearRegression(input_cols=["col1"], sample_weight_col="col3")), + ] + ) + assert pipeline_modifying_sample_weight._is_pipeline_modifying_label_or_sample_weight() is True + + def test_is_convertible_to_sklearn_object(self) -> None: + assert self.simple_pipeline._is_convertible_to_sklearn_object() is True + assert self.pipeline_two_steps_with_estimator._is_convertible_to_sklearn_object() is True + + pipeline_second_step_uses_input_cols = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1_out"]), + ), + ( + "MMS", + MinMaxScaler(input_cols=["col_2"], output_cols=["col_2_out"]), + ), + ] + ) + assert pipeline_second_step_uses_input_cols._is_convertible_to_sklearn_object() is False + + pipeline_modifying_label = Pipeline( + steps=[ + ( + "MMS", + MinMaxScaler(input_cols=["col1"], output_cols=["col1_out"]), + ), + ( + "MMS2", + MinMaxScaler(input_cols=["col3"], output_cols=["col3_out"]), + ), + ("model", LinearRegression(input_cols=["col1"], label_cols=["col3"])), + ] + ) + assert pipeline_second_step_uses_input_cols._is_convertible_to_sklearn_object() is False + + pipeline_inner_step_is_not_convertible = Pipeline(steps=[("pipeline", pipeline_modifying_label)]) + assert pipeline_inner_step_is_not_convertible._is_convertible_to_sklearn_object() is False + + def test_to_sklearn(self) -> None: + """Tests behavior for converting the pipeline to an sklearn pipeline""" + assert isinstance(self.simple_pipeline.to_sklearn(), sklearn_Pipeline) + + def tearDown(self) -> None: + os.environ.pop(IN_ML_RUNTIME_ENV_VAR, None) + return super().tearDown() + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/preprocessing/one_hot_encoder.py b/snowflake/ml/modeling/preprocessing/one_hot_encoder.py index 6422e6d6..c82354a6 100644 --- a/snowflake/ml/modeling/preprocessing/one_hot_encoder.py +++ b/snowflake/ml/modeling/preprocessing/one_hot_encoder.py @@ -832,6 +832,18 @@ def map_encoded_value(row: pd.Series) -> List[int]: # columns: COLUMN_NAME, CATEGORY, COUNT, FITTED_CATEGORY, ENCODING, N_FEATURES_OUT, ENCODED_VALUE, OUTPUT_CATs assert dataset._session is not None + + def convert_to_string_excluding_nan(item: Any) -> Union[None, str]: + if pd.isna(item): + return None # or np.nan if you prefer to keep as NaN + else: + return str(item) + + # In case of fitting with pandas dataframe and transforming with snowpark dataframe + # state_pandas cannot recognize the datatype of _CATEGORY and _FITTED_CATEGORY column + # Therefore, apply the convert_to_string_excluding_nan function to _CATEGORY and _FITTED_CATEGORY + state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].applymap(convert_to_string_excluding_nan) + state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].applymap(convert_to_string_excluding_nan) state_df = dataset._session.create_dataframe(state_pandas) transformed_dataset = dataset diff --git a/snowflake/ml/registry/BUILD.bazel b/snowflake/ml/registry/BUILD.bazel index 19d7c1a1..08b6eb7e 100644 --- a/snowflake/ml/registry/BUILD.bazel +++ b/snowflake/ml/registry/BUILD.bazel @@ -5,11 +5,9 @@ package(default_visibility = ["//visibility:public"]) py_library( name = "model_registry", srcs = [ - "artifact.py", 
"model_registry.py", ], deps = [ - ":artifact_manager", ":schema", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/utils:formatting", @@ -53,30 +51,6 @@ py_library( ], ) -py_library( - name = "artifact_manager", - srcs = [ - "_artifact_manager.py", - "artifact.py", - ], - deps = [ - ":schema", - "//snowflake/ml/_internal/utils:formatting", - "//snowflake/ml/_internal/utils:table_manager", - ], -) - -py_test( - name = "_artifact_test", - srcs = ["_artifact_test.py"], - deps = [ - ":artifact_manager", - "//snowflake/ml/_internal/utils:identifier", - "//snowflake/ml/test_utils:mock_data_frame", - "//snowflake/ml/test_utils:mock_session", - ], -) - py_library( name = "registry_impl", srcs = [ @@ -99,7 +73,6 @@ py_library( "__init__.py", ], deps = [ - ":artifact_manager", ":model_registry", ":registry_impl", ":schema", diff --git a/snowflake/ml/registry/_artifact_manager.py b/snowflake/ml/registry/_artifact_manager.py deleted file mode 100644 index 4927e0aa..00000000 --- a/snowflake/ml/registry/_artifact_manager.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Optional, cast - -from snowflake import connector, snowpark -from snowflake.ml._internal.utils import formatting, table_manager -from snowflake.ml.registry import _initial_schema, artifact - - -class ArtifactManager: - """It manages artifacts in model registry.""" - - def __init__( - self, - session: snowpark.Session, - database_name: str, - schema_name: str, - ) -> None: - """Initializer of artifact manager. - - Args: - session: Session object to communicate with Snowflake. - database_name: Desired name of the model registry database. - schema_name: Desired name of the schema used by this model registry inside the database. - """ - self._session = session - self._database_name = database_name - self._schema_name = schema_name - self._fully_qualified_table_name = table_manager.get_fully_qualified_table_name( - self._database_name, self._schema_name, _initial_schema._ARTIFACT_TABLE_NAME - ) - - def exists( - self, - artifact_name: str, - artifact_version: Optional[str] = None, - ) -> bool: - """Validate if an artifact exists. - - Args: - artifact_name: Name of artifact. - artifact_version: Version of artifact. - - Returns: - bool: True if the artifact exists, False otherwise. - """ - selected_artifact = self.get(artifact_name, artifact_version).collect() - - assert ( - len(selected_artifact) < 2 - ), f"""Multiple records found for artifact with name/version: {artifact_name}/{artifact_version}!""" - - return len(selected_artifact) == 1 - - def add( - self, - artifact: artifact.Artifact, - artifact_id: str, - artifact_name: str, - artifact_version: Optional[str] = None, - ) -> artifact.Artifact: - """ - Add a new artifact. - - Args: - artifact: artifact object. - artifact_id: id of artifact. - artifact_name: name of artifact. - artifact_version: version of artifact. - - Returns: - A reference to artifact. - """ - if artifact_version is None: - artifact_version = "" - assert artifact_id != "", "Artifact id can't be empty." - - new_artifact = { - "ID": artifact_id, - "TYPE": artifact.type.value, - "NAME": artifact_name, - "VERSION": artifact_version, - "CREATION_ROLE": self._session.get_current_role(), - "CREATION_TIME": formatting.SqlStr("CURRENT_TIMESTAMP()"), - "ARTIFACT_SPEC": artifact._spec, - } - - # TODO: Consider updating the METADATA table for artifact history tracking as well. 
- table_manager.insert_table_entry(self._session, self._fully_qualified_table_name, new_artifact) - artifact._log(name=artifact_name, version=artifact_version, id=artifact_id) - return artifact - - def delete( - self, - artifact_name: str, - artifact_version: Optional[str] = None, - error_if_not_exist: bool = False, - ) -> None: - """ - Remove an artifact. - - Args: - artifact_name: Name of artifact. - artifact_version: Version of artifact. - error_if_not_exist: Whether to raise errors if the target entry doesn't exist. Default to be false. - - Raises: - DataError: If error_if_not_exist is true and the artifact doesn't exist in the database. - RuntimeError: If the artifact deletion failed. - """ - if not self.exists(artifact_name, artifact_version): - if error_if_not_exist: - raise connector.DataError( - f"Artifact {artifact_name}/{artifact_version} doesn't exist. Deletion failed." - ) - else: - return - - if artifact_version is None: - artifact_version = "" - delete_query = f"""DELETE FROM {self._fully_qualified_table_name} - WHERE NAME='{artifact_name}' AND VERSION='{artifact_version}' - """ - - # TODO: Consider updating the METADATA table for artifact history tracking as well. - try: - self._session.sql(delete_query).collect() - except Exception as e: - raise RuntimeError(f"Delete artifact {artifact_name}/{artifact_version} failed due to {e}") - - def get( - self, - artifact_name: str, - artifact_version: Optional[str] = None, - ) -> snowpark.DataFrame: - """Retrieve the Snowpark dataframe of the artifact matching the provided artifact id and type. - - Given that ID and TYPE act as a compound primary key for the artifact table, - the resulting dataframe should have at most, one row. - - Args: - artifact_name: Name of artifact. - artifact_version: Version of artifact. - - Returns: - A Snowpark dataframe representing the artifacts that match the given constraints. - - WARNING: - The returned DataFrame is writable and shouldn't be made accessible to users. - """ - if artifact_version is None: - artifact_version = "" - - artifacts = self._session.sql(f"SELECT * FROM {self._fully_qualified_table_name}") - target_artifact = artifacts.filter(snowpark.Column("NAME") == artifact_name).filter( - snowpark.Column("VERSION") == artifact_version - ) - return cast(snowpark.DataFrame, target_artifact) diff --git a/snowflake/ml/registry/_artifact_test.py b/snowflake/ml/registry/_artifact_test.py deleted file mode 100644 index 2202c813..00000000 --- a/snowflake/ml/registry/_artifact_test.py +++ /dev/null @@ -1,158 +0,0 @@ -import datetime -from typing import List, cast - -from absl.testing import absltest - -from snowflake import connector, snowpark -from snowflake.ml._internal.utils import identifier, table_manager -from snowflake.ml.registry import _artifact_manager, artifact -from snowflake.ml.test_utils import mock_data_frame, mock_session - -_DATABASE_NAME = identifier.get_inferred_name("_SYSTEM_MODEL_REGISTRY") -_SCHEMA_NAME = identifier.get_inferred_name("_SYSTEM_MODEL_REGISTRY_SCHEMA") -_TABLE_NAME = identifier.get_inferred_name("_SYSTEM_REGISTRY_ARTIFACTS") -_FULLY_QUALIFIED_TABLE_NAME = table_manager.get_fully_qualified_table_name(_DATABASE_NAME, _SCHEMA_NAME, _TABLE_NAME) - - -class ArtifactTest(absltest.TestCase): - """Testing Artifact table related functions.""" - - def setUp(self) -> None: - """Creates Snowpark environments for testing.""" - self._session = mock_session.MockSession(conn=None, test_case=self) - - def tearDown(self) -> None: - """Complete test case. 
Ensure all expected operations have been observed.""" - self._session.finalize() - - def _get_show_tables_success( - self, name: str, database_name: str = _DATABASE_NAME, schema_name: str = _SCHEMA_NAME - ) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful listing of - tables.""" - return [ - snowpark.Row( - created_on=datetime.datetime(2022, 11, 4, 17, 1, 30, 153000), - name=name, - database_name=database_name, - schema_name=schema_name, - kind="TABLE", - comment="", - cluster_by="", - rows=0, - bytes=0, - owner="OWNER_ROLE", - retention_time=1, - change_tracking="OFF", - is_external="N", - enable_schema_evolution="N", - ) - ] - - def _get_select_artifact(self) -> List[snowpark.Row]: - """Helper method that returns a DataFrame that looks like the response of from a successful listing of - tables.""" - return [ - snowpark.Row( - id="FAKE_ID", - type=artifact.ArtifactType.TESTTYPE, - creation_time=datetime.datetime(2022, 11, 4, 17, 1, 30, 153000), - creation_role="OWNER_ROLE", - artifact_spec={}, - ) - ] - - def test_if_artifact_exists(self) -> None: - for mock_df_collect, expected_res in [ - (self._get_select_artifact(), True), - ([], False), - ]: - with self.subTest(): - artifact_name = "FAKE_ID" - artifact_version = "FAKE_VERSION" - expected_df = mock_data_frame.MockDataFrame() - expected_df.add_operation("filter") - expected_df.add_operation("filter") - expected_df.add_collect_result(cast(List[snowpark.Row], mock_df_collect)) - self._session.add_mock_sql(query=f"SELECT * FROM {_FULLY_QUALIFIED_TABLE_NAME}", result=expected_df) - self.assertEqual( - _artifact_manager.ArtifactManager( - session=cast(snowpark.Session, self._session), - database_name=_DATABASE_NAME, - schema_name=_SCHEMA_NAME, - ).exists( - artifact_name, - artifact_version, - ), - expected_res, - ) - - def test_add_artifact(self) -> None: - artifact_id = "FAKE_ID" - artifact_name = "FAKE_NAME" - artifact_version = "FAKE_VERSION" - art_obj = artifact.Artifact(type=artifact.ArtifactType.TESTTYPE, spec='{"description": "mock description"}') - - # Mock the insertion call - self._session.add_operation("get_current_role", result="current_role") - insert_query = ( - f"INSERT INTO {_FULLY_QUALIFIED_TABLE_NAME}" - " ( ARTIFACT_SPEC,CREATION_ROLE,CREATION_TIME,ID,NAME,TYPE,VERSION )" - " SELECT" - " '{\"description\": \"mock description\"}','current_role',CURRENT_TIMESTAMP()," - "'FAKE_ID','FAKE_NAME', 'TESTTYPE', 'FAKE_VERSION' " - ) - self._session.add_mock_sql( - query=insert_query, - result=mock_data_frame.MockDataFrame([snowpark.Row(**{"number of rows inserted": 1})]), - ) - _artifact_manager.ArtifactManager( - session=cast(snowpark.Session, self._session), - database_name=_DATABASE_NAME, - schema_name=_SCHEMA_NAME, - ).add( - artifact=art_obj, - artifact_id=artifact_id, - artifact_name=artifact_name, - artifact_version=artifact_version, - ) - - def test_delete_artifact(self) -> None: - for error_if_not_exist in [True, False]: - with self.subTest(): - if error_if_not_exist: - artifact_name = "FAKE_NAME" - artifact_version = "FAKE_VERSION" - expected_df = mock_data_frame.MockDataFrame() - expected_df.add_operation("filter") - expected_df.add_operation("filter") - expected_df.add_collect_result([]) - self._session.add_mock_sql(query=f"SELECT * FROM {_FULLY_QUALIFIED_TABLE_NAME}", result=expected_df) - with self.assertRaises(connector.DataError): - _artifact_manager.ArtifactManager( - session=cast(snowpark.Session, self._session), - database_name=_DATABASE_NAME, - 
schema_name=_SCHEMA_NAME, - ).delete( - artifact_name, - artifact_version, - True, - ) - else: - expected_df = mock_data_frame.MockDataFrame() - expected_df.add_operation("filter") - expected_df.add_operation("filter") - expected_df.add_collect_result([]) - self._session.add_mock_sql(query=f"SELECT * FROM {_FULLY_QUALIFIED_TABLE_NAME}", result=expected_df) - _artifact_manager.ArtifactManager( - session=cast(snowpark.Session, self._session), - database_name=_DATABASE_NAME, - schema_name=_SCHEMA_NAME, - ).delete( - artifact_name, - artifact_version, - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/registry/artifact.py b/snowflake/ml/registry/artifact.py deleted file mode 100644 index f6aff3d5..00000000 --- a/snowflake/ml/registry/artifact.py +++ /dev/null @@ -1,46 +0,0 @@ -import enum -from typing import Optional - - -# Set of allowed artifact types. -class ArtifactType(enum.Enum): - TESTTYPE = "TESTTYPE" # A placeholder type just for unit test - DATASET = "DATASET" - - -class Artifact: - """ - A reference to artifact. - - Properties: - id: A globally unique id represents this artifact. - spec: Specification of artifact in json format. - type: Type of artifact. - name: Name of artifact. - version: Version of artifact. - """ - - def __init__(self, type: ArtifactType, spec: str) -> None: - """Create an artifact. - - Args: - type: type of artifact. - spec: specification in json format. - """ - self.type: ArtifactType = type - self.name: Optional[str] = None - self.version: Optional[str] = None - self._spec: str = spec - self._id: Optional[str] = None - - def _log(self, name: str, version: str, id: str) -> None: - """Additional information when this artifact is logged. - - Args: - name: name of artifact. - version: version of artifact. - id: A global unique id represents this artifact. - """ - self.name = name - self.version = version - self._id = id diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py index 0830e5c0..3e6a9862 100644 --- a/snowflake/ml/registry/model_registry.py +++ b/snowflake/ml/registry/model_registry.py @@ -29,19 +29,13 @@ table_manager, uri, ) -from snowflake.ml.dataset import dataset from snowflake.ml.model import ( _api as model_api, deploy_platforms, model_signature, type_hints as model_types, ) -from snowflake.ml.registry import ( - _artifact_manager, - _initial_schema, - _schema_version_manager, - artifact, -) +from snowflake.ml.registry import _initial_schema, _schema_version_manager from snowflake.snowpark._internal import utils as snowpark_utils if TYPE_CHECKING: @@ -142,7 +136,6 @@ def _create_registry_views( registry_table_name: str, metadata_table_name: str, deployment_table_name: str, - artifact_table_name: str, statement_params: Dict[str, Any], ) -> None: """Create views on underlying ModelRegistry tables. @@ -154,7 +147,6 @@ def _create_registry_views( registry_table_name: Name for the main model registry table. metadata_table_name: Name for the metadata table used by the model registry. deployment_table_name: Name for the deployment event table. - artifact_table_name: Name for the artifact table. statement_params: Function usage statement parameters used in sql query executions. """ fully_qualified_schema_name = table_manager.get_fully_qualified_schema_name(database_name, schema_name) @@ -235,23 +227,6 @@ def _create_registry_views( FROM {registry_table_name} {metadata_views_join}""" ).collect(statement_params=statement_params) - # Create artifact view. 
it joins artifact tables with registry table on model id. - artifact_view_name = identifier.concat_names([artifact_table_name, "_VIEW"]) - session.sql( - f"""CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{artifact_view_name} COPY GRANTS AS - SELECT - {registry_table_name}.NAME AS MODEL_NAME, - {registry_table_name}.VERSION AS MODEL_VERSION, - {artifact_table_name}.* - FROM {registry_table_name} - LEFT JOIN {artifact_table_name} - ON (ARRAY_CONTAINS( - {artifact_table_name}.ID::VARIANT, - {registry_table_name}.ARTIFACT_IDS) - ) - """ - ).collect(statement_params=statement_params) - def _create_active_permanent_deployment_view( session: snowpark.Session, @@ -337,11 +312,8 @@ def __init__( self._deployment_table = identifier.get_inferred_name(_DEPLOYMENT_TABLE_NAME) self._permanent_deployment_view = identifier.concat_names([self._deployment_table, "_VIEW"]) self._permanent_deployment_stage = identifier.concat_names([self._deployment_table, "_STAGE"]) - self._artifact_table = identifier.get_inferred_name(_initial_schema._ARTIFACT_TABLE_NAME) - self._artifact_view = identifier.concat_names([self._artifact_table, "_VIEW"]) self._session = session self._svm = _schema_version_manager.SchemaVersionManager(self._session, self._name, self._schema) - self._artifact_manager = _artifact_manager.ArtifactManager(self._session, self._name, self._schema) # A in-memory deployment info cache to store information of temporary deployments # TODO(zhe): Use a temporary table to replace the in-memory cache. @@ -359,7 +331,6 @@ def __init__( self._registry_table, self._metadata_table, self._deployment_table, - self._artifact_table, statement_params, ) @@ -399,9 +370,6 @@ def _fully_qualified_permanent_deployment_view_name(self) -> str: """Get the fully qualified name to the permanent deployment view.""" return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._permanent_deployment_view) - def _fully_qualified_artifact_view_name(self) -> str: - return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._artifact_view) - def _fully_qualified_schema_name(self) -> str: """Get the fully qualified name to the current registry schema.""" return table_manager.get_fully_qualified_schema_name(self._name, self._schema) @@ -858,7 +826,6 @@ def _register_model_with_id( output_spec: Optional[Dict[str, str]] = None, description: Optional[str] = None, tags: Optional[Dict[str, str]] = None, - artifacts: Optional[List[artifact.Artifact]] = None, ) -> None: """Helper function to register model metadata. @@ -878,10 +845,8 @@ def _register_model_with_id( description: A description for the model. The description can be changed later. tags: Key-value pairs of tags to be set for this model. Tags can be modified after model registration. - artifacts: A list of artifact references. Raises: - ValueError: Artifact ids not found in model registry. DataError: The given model already exists. DatabaseError: Unable to register the model properties into table. 
""" @@ -897,12 +862,6 @@ def _register_model_with_id( new_model["CREATION_ROLE"] = self._session.get_current_role() new_model["CREATION_ENVIRONMENT_SPEC"] = {"python": ".".join(map(str, sys.version_info[:3]))} - if artifacts is not None: - for atf in artifacts: - if not self._artifact_manager.exists(atf.name if atf.name is not None else "", atf.version): - raise ValueError(f"Artifact {atf.name}/{atf.version} not found in model registry.") - new_model["ARTIFACT_IDS"] = [art._id for art in artifacts] - existing_model_nums = self._list_selected_models(model_name=model_name, model_version=model_version).count() if existing_model_nums: raise connector.DataError( @@ -1356,42 +1315,6 @@ def get_metrics(self, model_name: str, model_version: str) -> Dict[str, object]: else: return dict() - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="1.0.10") - def log_artifact( - self, - artifact: artifact.Artifact, - name: str, - version: Optional[str] = None, - ) -> artifact.Artifact: - """Upload and register an artifact to the Model Registry. - - Args: - artifact: artifact object. - name: name of artifact. - version: version of artifact. - - Raises: - DataError: Artifact with same name and version already exists. - - Returns: - Return a reference to the artifact. - """ - - if self._artifact_manager.exists(name, version): - raise connector.DataError(f"Artifact {name}/{version} already exists.") - - artifact_id = self._get_new_unique_identifier() - return self._artifact_manager.add( - artifact=artifact, - artifact_id=artifact_id, - artifact_name=name, - artifact_version=version, - ) - # Combined Registry and Repository operations. @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, @@ -1410,7 +1333,6 @@ def log_model( pip_requirements: Optional[List[str]] = None, signatures: Optional[Dict[str, model_signature.ModelSignature]] = None, sample_input_data: Optional[Any] = None, - artifacts: Optional[List[artifact.Artifact]] = None, code_paths: Optional[List[str]] = None, options: Optional[model_types.BaseModelSaveOption] = None, ) -> Optional["ModelReference"]: @@ -1431,19 +1353,15 @@ def log_model( pip requirements. signatures: Signatures of the model, which is a mapping from target method name to signatures of input and output, which could be inferred by calling `infer_signature` method with sample input data. - sample_input_data: Sample of the input data for the model. If artifacts contains a feature store - generated dataset, then sample_input_data is not needed. If both sample_input_data and dataset provided - , then sample_input_data will be used to infer model signature. - artifacts: A list of artifact ids, which are generated from log_artifact(). + sample_input_data: Sample of the input data for the model. code_paths: Directory of code to import when loading and deploying the model. options: Additional options when saving the model. Raises: DataError: Raised when: 1) the given model already exists; - 2) given artifacts does not exists in this registry. ValueError: Raised when: # noqa: DAR402 - 1) Signatures, sample_input_data and artifact(dataset) are both not provided and model is not a + 1) Signatures and sample_input_data are both not provided and model is not a snowflake estimator. Exception: Raised when there is any error raised when saving the model. 
@@ -1458,18 +1376,6 @@ def log_model( self._model_identifier_is_nonempty_or_raise(model_name, model_version) - if artifacts is not None: - for atf in artifacts: - if not self._artifact_manager.exists(atf.name if atf.name is not None else "", atf.version): - raise connector.DataError(f"Artifact {atf.name}/{atf.version} does not exists.") - - if sample_input_data is None and artifacts is not None: - for atf in artifacts: - if atf.type == artifact.ArtifactType.DATASET: - ds = self.get_artifact(atf.name if atf.name is not None else "", atf.version) - sample_input_data = ds.features_df() - break - existing_model_nums = self._list_selected_models(model_name=model_name, model_version=model_version).count() if existing_model_nums: raise connector.DataError(f"Model {model_name}/{model_version} already exists. Unable to log the model.") @@ -1508,7 +1414,6 @@ def log_model( uri=uri.get_uri_from_snowflake_stage_path(stage_path), description=description, tags=tags, - artifacts=artifacts, ) return ModelReference(registry=self, model_name=model_name, model_version=model_version) @@ -1733,25 +1638,6 @@ def list_deployments(self, model_name: str, model_version: str) -> snowpark.Data ) return cast(snowpark.DataFrame, res) - @snowpark._internal.utils.private_preview(version="1.0.1") - def list_artifacts(self, model_name: str, model_version: Optional[str] = None) -> snowpark.DataFrame: - """List all artifacts that associated with given model name and version. - - Args: - model_name: Name of model. - model_version: Version of model. If version is none then only filter on name. - Defaults to none. - - Returns: - A snowpark dataframe that contains all artifacts that associated with the given model. - """ - artifacts = self._session.sql(f"SELECT * FROM {self._fully_qualified_artifact_view_name()}").filter( - snowpark.Column("MODEL_NAME") == model_name - ) - if model_version is not None: - artifacts = artifacts.filter(snowpark.Column("MODEL_VERSION") == model_version) - return cast(snowpark.DataFrame, artifacts) - @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, @@ -1782,38 +1668,6 @@ def get_deployment(self, model_name: str, model_version: str, *, deployment_name ) return cast(snowpark.DataFrame, deployment) - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - @snowpark._internal.utils.private_preview(version="1.0.11") - def get_artifact(self, name: str, version: Optional[str] = None) -> Optional[artifact.Artifact]: - """Get artifact with the given (name, version). - - Args: - name: Name of artifact. - version: Version of artifact. - - Returns: - A reference to artifact if found, otherwise none. 
- """ - artifacts = self._artifact_manager.get( - name, - version, - ).collect() - - if len(artifacts) == 0: - return None - - atf = artifacts[0] - if atf["TYPE"] == artifact.ArtifactType.DATASET.value: - ds = dataset.Dataset.from_json(atf["ARTIFACT_SPEC"], self._session) - ds._log(name=atf["NAME"], version=atf["VERSION"], id=atf["ID"]) - return ds - - assert f"Unrecognized artifact type: {atf['TYPE']}" - return None - @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, diff --git a/snowflake/ml/registry/model_registry_test.py b/snowflake/ml/registry/model_registry_test.py index 756c5e77..ba7ce9ab 100644 --- a/snowflake/ml/registry/model_registry_test.py +++ b/snowflake/ml/registry/model_registry_test.py @@ -555,25 +555,6 @@ def setup_create_views_call(self) -> None: [snowpark.Row(status=f"View {_REGISTRY_TABLE_NAME}_VIEW successfully created.")] ), ) - self.add_session_mock_sql( - query=( - f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_ARTIFACTS_TABLE_NAME}_VIEW - COPY GRANTS AS - SELECT - {_REGISTRY_TABLE_NAME}.NAME AS MODEL_NAME, - {_REGISTRY_TABLE_NAME}.VERSION AS MODEL_VERSION, - {_ARTIFACTS_TABLE_NAME}.* - FROM {_REGISTRY_TABLE_NAME} - LEFT JOIN {_ARTIFACTS_TABLE_NAME} - ON (ARRAY_CONTAINS( - {_ARTIFACTS_TABLE_NAME}.ID::VARIANT, - {_REGISTRY_TABLE_NAME}.ARTIFACT_IDS)) - """ - ), - result=mock_data_frame.MockDataFrame( - [snowpark.Row(status=f"View {_ARTIFACTS_TABLE_NAME}_VIEW successfully created.")] - ), - ) def setup_open_existing(self) -> None: self.add_session_mock_sql( @@ -1169,7 +1150,6 @@ def test_log_model(self) -> None: uri=uri.get_uri_from_snowflake_stage_path(model_path), description="description", tags=None, - artifacts=None, ) self._mock_show_version_table_exists({}) diff --git a/snowflake/ml/registry/package_visibility_test.py b/snowflake/ml/registry/package_visibility_test.py index 8dd064c9..8ae69986 100644 --- a/snowflake/ml/registry/package_visibility_test.py +++ b/snowflake/ml/registry/package_visibility_test.py @@ -3,7 +3,7 @@ from absl.testing import absltest from snowflake.ml import registry -from snowflake.ml.registry import artifact, model_registry +from snowflake.ml.registry import model_registry class PackageVisibilityTest(absltest.TestCase): @@ -14,7 +14,6 @@ def test_class_visible(self) -> None: def test_module_visible(self) -> None: self.assertIsInstance(model_registry, ModuleType) - self.assertIsInstance(artifact, ModuleType) if __name__ == "__main__": diff --git a/snowflake/ml/test_utils/mock_data_frame.py b/snowflake/ml/test_utils/mock_data_frame.py index d8700763..90d5089e 100644 --- a/snowflake/ml/test_utils/mock_data_frame.py +++ b/snowflake/ml/test_utils/mock_data_frame.py @@ -163,6 +163,10 @@ def collect(self, *args: Any, **kwargs: Any) -> Any: mdfo = self._check_operation("collect", args, kwargs) return mdfo.result + def collect_nowait(self, *args: Any, **kwargs: Any) -> Any: + """Collect a dataframe. Corresponds to DataFrame.collect_nowait.""" + return self.collect(*args, **kwargs) + def filter(self, *args: Any, **kwargs: Any) -> Any: """Filter a dataframe. Corresponds to DataFrame.filter. diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index 448df7bf..01286942 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. 
-VERSION = "1.4.1" +VERSION = "1.5.0" diff --git a/tests/integ/snowflake/ml/_internal/BUILD.bazel b/tests/integ/snowflake/ml/_internal/BUILD.bazel index 5f6bd10c..ee91f525 100644 --- a/tests/integ/snowflake/ml/_internal/BUILD.bazel +++ b/tests/integ/snowflake/ml/_internal/BUILD.bazel @@ -28,6 +28,7 @@ py_test( "//snowflake/ml/_internal:env_utils", "//snowflake/ml/modeling/_internal:estimator_utils", "//snowflake/ml/modeling/_internal/snowpark_implementations:snowpark_handlers", + "//snowflake/ml/modeling/linear_model:linear_regression", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/test_utils:common_test_base", ], diff --git a/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py b/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py index 10085345..220844df 100644 --- a/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py +++ b/tests/integ/snowflake/ml/_internal/snowpark_handlers_test.py @@ -10,6 +10,7 @@ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import ( SnowparkTransformHandlers, ) +from snowflake.ml.modeling.linear_model import LinearRegression from tests.integ.snowflake.ml.test_utils import common_test_base @@ -28,6 +29,7 @@ def setUp(self) -> None: self._handlers = SnowparkTransformHandlers( dataset=self.input_df, estimator=self.fit_estimator, class_name="test", subproject="subproject" ) + self.dependencies = LinearRegression()._get_dependencies() def _get_test_dataset(self) -> Tuple[pd.DataFrame, List[str], List[str]]: """Constructs input dataset to be used in the integration test. @@ -67,7 +69,7 @@ def test_batch_inference(self) -> None: predictions = self._handlers.batch_inference( session=self.session, - dependencies=["snowflake-snowpark-python", "numpy", "scikit-learn", "cloudpickle"], + dependencies=self.dependencies, inference_method="predict", input_cols=self.input_cols, drop_input_cols=False, @@ -85,7 +87,7 @@ def test_score_snowpark(self) -> None: score = self._handlers.score( session=self.session, - dependencies=["snowflake-snowpark-python", "numpy", "scikit-learn", "cloudpickle"], + dependencies=self.dependencies, score_sproc_imports=["sklearn"], input_cols=self.input_cols, label_cols=self.label_cols, diff --git a/tests/integ/snowflake/ml/dataset/BUILD.bazel b/tests/integ/snowflake/ml/dataset/BUILD.bazel new file mode 100644 index 00000000..fd231a2d --- /dev/null +++ b/tests/integ/snowflake/ml/dataset/BUILD.bazel @@ -0,0 +1,39 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = [ + "//bazel:snowml_public_common", +]) + +py_library( + name = "dataset_integ_test_base", + testonly = True, + srcs = ["dataset_integ_test_base.py"], + deps = [ + "//snowflake/ml/dataset", + "//snowflake/ml/utils:connection_params", + "//tests/integ/snowflake/ml/fileset:fileset_integ_utils", + "//tests/integ/snowflake/ml/test_utils:common_test_base", + ], +) + +py_test( + name = "dataset_integ_test", + timeout = "long", + srcs = ["dataset_integ_test.py"], + shard_count = 6, + deps = [ + ":dataset_integ_test_base", + "//snowflake/ml/dataset", + ], +) + +py_test( + name = "dataset_tensorflow_integ_test", + timeout = "long", + srcs = ["dataset_tensorflow_integ_test.py"], + shard_count = 4, + deps = [ + ":dataset_integ_test_base", + "//snowflake/ml/dataset", + ], +) diff --git a/tests/integ/snowflake/ml/dataset/dataset_integ_test.py b/tests/integ/snowflake/ml/dataset/dataset_integ_test.py new file mode 100644 index 00000000..925399f0 --- /dev/null +++ 
b/tests/integ/snowflake/ml/dataset/dataset_integ_test.py @@ -0,0 +1,453 @@ +import os +import random +from typing import Any, Dict, Generator +from uuid import uuid4 + +import numpy as np +import pandas as pd +import torch +from absl.testing import absltest, parameterized +from numpy import typing as npt +from torch.utils import data + +from snowflake import snowpark +from snowflake.ml import dataset +from snowflake.ml._internal.exceptions import dataset_errors +from snowflake.snowpark import functions +from tests.integ.snowflake.ml.dataset import dataset_integ_test_base +from tests.integ.snowflake.ml.fileset import fileset_integ_utils +from tests.integ.snowflake.ml.test_utils import common_test_base + +np.random.seed(0) +random.seed(0) + + +class TestSnowflakeDataset(dataset_integ_test_base.TestSnowflakeDatasetBase): + """Integration tests for Snowflake Dataset.""" + + DS_INTEG_TEST_DB = "SNOWML_DATASET_TEST_DB" + DS_INTEG_TEST_SCHEMA = "DATASET_TEST" + + @common_test_base.CommonTestBase.sproc_test(local=True, additional_packages=[]) + def test_dataset_management(self) -> None: + """Test Dataset management APIs""" + dataset_name = f"dataset_integ_management_{uuid4().hex}" + ds = dataset.Dataset.create(self.session, dataset_name) + assert isinstance(ds, dataset.Dataset) # Use plain assert so type inferencing works + self.assertEmpty(ds.list_versions()) + + with self.assertRaises(dataset_errors.DatasetExistError): + dataset.Dataset.create(self.session, dataset_name) + + with self.assertRaises(dataset_errors.DatasetNotExistError): + dataset.Dataset.load(self.session, "dataset_not_exist") + + loaded_ds = dataset.Dataset.load(self.session, dataset_name) + self.assertEmpty(loaded_ds.list_versions()) + + # Create version. Should be reflected in both + dataset_version1 = "v1" + ds.create_version( + version=dataset_version1, + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + self.assertListEqual([dataset_version1], ds.list_versions()) + self.assertListEqual([dataset_version1], loaded_ds.list_versions()) + + # FIXME: Add DatasetVersionExistError + with self.assertRaises(dataset_errors.DatasetExistError): + ds.create_version( + version=dataset_version1, + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + + # Validate with two versions (including version list ordering) + dataset_version2 = "v2" + ds.create_version( + version=dataset_version2, + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + self.assertListEqual([dataset_version1, dataset_version2], ds.list_versions()) + self.assertListEqual([dataset_version1, dataset_version2], loaded_ds.list_versions()) + + # Don't run in sprocs to speed up tests + def test_dataset_case_sensitivity(self) -> None: + dataset_name = "dataset_integ_case_sensitive" + dataset_version = "v1" + ds = dataset.create_from_dataframe( + self.session, + f'"{dataset_name}"', + dataset_version, + self.session.table(self.test_table).limit(1000), + ) + + # Test loading with dataset.Dataset.load() + with self.assertRaises(dataset_errors.DatasetNotExistError): + dataset.Dataset.load(self.session, dataset_name) + + loaded_ds = dataset.load_dataset(self.session, f'"{dataset_name}"', dataset_version) + self.assertEqual([dataset_version], loaded_ds.list_versions()) + self.assertEqual(ds.selected_version.url(), loaded_ds.selected_version.url()) + + # Test loading with dataset.load_dataset() + with self.assertRaises(dataset_errors.DatasetNotExistError): + dataset.Dataset.load(self.session, dataset_name) + + # Test version 
case sensitivity + _ = ds.select_version("v1") + with self.assertRaises(dataset_errors.DatasetNotExistError): + _ = ds.select_version("V1") + + def test_dataset_properties(self) -> None: + """Test Dataset version property loading""" + from datetime import datetime, timezone + + current_time = datetime.now(timezone.utc) + dataset_name = "dataset_integ_metadata" + dataset_version = "v1" + ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.sql("SELECT 1"), + exclude_cols=["timestamp"], + comment="this is my dataset 'with quotes'", + ) + + self.assertListEqual(ds.read.data_sources[0].exclude_cols, ["timestamp"]) + self.assertEqual(ds.selected_version.comment, "this is my dataset 'with quotes'") + self.assertGreaterEqual(ds.selected_version.created_on, current_time) + + ds1 = ds.create_version("no_comment", self.session.sql("SELECT 1")) + self.assertEmpty(ds1.read.data_sources[0].exclude_cols) + self.assertIsNone(ds1.selected_version.comment) + + # Don't run in sprocs to speed up tests + def test_dataset_partition_by(self) -> None: + """Test Dataset creation from Snowpark DataFrame""" + dataset_name = f"{self.db}.{self.schema}.dataset_integ_partition" + ds1_version = "constant_partition" + ds1 = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=ds1_version, + input_dataframe=self.session.table(self.test_table).limit(1000), + partition_by="'subdir'", + ) + ds1_dirs = {os.path.dirname(f) for f in ds1.read.files()} + self.assertListEqual([f"snow://dataset/{dataset_name}/versions/{ds1_version}/subdir"], sorted(ds1_dirs)) + + ds2_version = "range_partition" + ds2 = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=ds2_version, + input_dataframe=self.session.sql( + "select seq4() as ID, uniform(1, 4, random(42)) as part from table(generator(rowcount => 10000))" + ), + partition_by="to_varchar(PART)", + ) + ds2_dirs = {os.path.dirname(f) for f in ds2.read.files()} + self.assertListEqual( + [ + f"snow://dataset/{dataset_name}/versions/{ds2_version}/1", + f"snow://dataset/{dataset_name}/versions/{ds2_version}/2", + f"snow://dataset/{dataset_name}/versions/{ds2_version}/3", + f"snow://dataset/{dataset_name}/versions/{ds2_version}/4", + ], + sorted(ds2_dirs), + ) + + # Don't run in sprocs to speed up tests + def test_create_from_dataframe(self) -> None: + """Test Dataset creation from Snowpark DataFrame""" + dataset_name = "dataset_integ_create_from_dataframe" + dataset_version = "v1" + ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + + df = ds.read.to_snowpark_dataframe() + df_count = df.count() + + # Verify that duplicate dataset is not allowed to be created. 
+ with self.assertRaises(dataset_errors.DatasetExistError): + dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + + # Verify that creating a different Dataset version works + dataset_version2 = "v2" + dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version2, + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + + # Ensure v1 contents unaffected by v2 + self.assertEqual(df_count, ds.read.to_snowpark_dataframe().count()) + + # Don't run in sprocs due to quirky schema handling in sproc (can't use USE SCHEMA but CREATE SCHEMA changes schema) + def test_create_from_dataframe_fqn(self) -> None: + """Test Dataset creation with fully qualified name""" + schema = dataset_integ_test_base.create_random_schema(self.session, self.DS_INTEG_TEST_SCHEMA) + self.session.use_schema(self.schema) # Keep session on main test schema + try: + dataset_name = f"{self.db}.{schema}.dataset_integ_create_from_dataframe_fqn" + dataset_version = "v1" + ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + + self.assertGreater(len(ds.read.files()), 0) + for file in ds.read.files(): + self.assertStartsWith(file, f"snow://dataset/{dataset_name}/versions/{dataset_version}/") + finally: + self.session.sql(f"drop schema {self.db}.{schema}").collect() + + @common_test_base.CommonTestBase.sproc_test(local=True, additional_packages=[]) + def test_dataset_from_dataset(self) -> None: + # Generate random prefixes due to race condition between sprocs causing dataset collision + dataset_name = f"dataset_integ_dataset_from_dataset_{uuid4().hex}" + dataset_version = "v1" + ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table), + ) + + ds_df = ds.read.to_snowpark_dataframe() + dataset_version2 = "v2" + ds2 = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version2, + input_dataframe=ds_df, + ) + + self._validate_snowpark_dataframe(ds2.read.to_snowpark_dataframe()) + + # Don't run in sprocs since sprocs don't have delete privilege + def test_dataset_delete(self) -> None: + """Test dataset deletion""" + dataset_name = "dataset_integ_delete" + dataset_version = "test" + ds = dataset.create_from_dataframe( + self.session, + dataset_name, + dataset_version, + self.session.table(self.test_table).limit(1000), + ) + dsv = ds.selected_version + + self.assertNotEmpty(ds.list_versions()) + self.assertNotEmpty(dsv.list_files()) + ds.delete_version(dataset_version) + self.assertEmpty(ds.list_versions()) + + # Delete dataset. 
Loaded Dataset should also be deleted + loaded_ds = dataset.Dataset.load(self.session, dataset_name) + ds.delete() + with self.assertRaises(dataset_errors.DatasetNotExistError): + ds.list_versions() + with self.assertRaises(dataset_errors.DatasetNotExistError): + loaded_ds.list_versions() + + # create_version should fail for deleted/nonexistent datasets + with self.assertRaises(dataset_errors.DatasetNotExistError): + ds.create_version( + version="new_version", + input_dataframe=self.session.table(self.test_table).limit(1000), + ) + + # create_version should fail for deleted/nonexistent datasets + with self.assertRaises(dataset_errors.DatasetNotExistError): + dataset.Dataset.load(self.session, dataset_name) + + # Don't run in sprocs to speed up tests + def test_restore_nonexistent_dataset(self) -> None: + """Test load of non-existent dataset""" + # Dataset not exist + dataset_name = "dataset_integ_notexist" + with self.assertRaises(dataset_errors.DatasetNotExistError): + dataset.load_dataset( + name=dataset_name, + version="test", + session=self.session, + ) + + # Version not exist + dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version="test", + input_dataframe=self.session.sql("select 1"), + ) + with self.assertRaises(dataset_errors.DatasetNotExistError): + dataset.load_dataset( + name=dataset_name, + version="not_exist", + session=self.session, + ) + + @common_test_base.CommonTestBase.sproc_test(local=True, additional_packages=[]) + def test_file_access(self) -> None: + import pyarrow.parquet as pq + + dataset_name = f"dataset_integ_file_access_{uuid4().hex}" + dataset_version = "v1" + ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table), + ) + + pq_ds = pq.ParquetDataset(ds.read.files(), filesystem=ds.read.filesystem()) + pq_table = pq_ds.read() + self.assertEqual(self.num_rows, len(pq_table)) + self._validate_pandas(pq_table.to_pandas()) + + @common_test_base.CommonTestBase.sproc_test(local=True, additional_packages=[]) + def test_to_pandas(self) -> None: + dataset_name = f"dataset_integ_pandas_{uuid4().hex}" + dataset_version = "v1" + ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table), + ) + + pd_df = ds.read.to_pandas() + self._validate_pandas(pd_df) + + # FIXME: This currently fails due to float64 -> float32 cast during Dataset creation + # Additionally may need to sort the Pandas DataFrame to align with Snowpark DataFrame + # df = ds.to_snowpark_dataframe() + # pd.testing.assert_frame_equal(df.to_pandas(), pd_df, check_index_type=False) + + @common_test_base.CommonTestBase.sproc_test(local=True, additional_packages=[]) + def test_to_dataframe(self) -> None: + all_columns = [col for col, _ in fileset_integ_utils._TEST_RESULTSET_SCHEMA] + exclude_cols = all_columns[:2] + label_cols = all_columns[1:3] # Intentionally overlap with exclude_cols (unintended but likely common behavior) + + # Generate random prefixes due to race condition between sprocs causing dataset collision + dataset_name = f"dataset_integ_to_dataframe_{uuid4().hex}" + dataset_version = "test" + ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table), + exclude_cols=exclude_cols, + label_cols=label_cols, + ) + + df = ds.read.to_snowpark_dataframe() + 
self._validate_snowpark_dataframe(df) + self.assertSameElements(all_columns, df.columns) + + features_df = ds.read.to_snowpark_dataframe(only_feature_cols=True) + non_feature_cols = set(exclude_cols + label_cols) + feature_cols = [col for col in all_columns if col not in non_feature_cols] + self.assertSameElements(feature_cols, features_df.columns) + + @parameterized.parameters( # type: ignore[misc] + {"dataset_shuffle": True, "datapipe_shuffle": False, "drop_last_batch": False}, + {"dataset_shuffle": False, "datapipe_shuffle": True, "drop_last_batch": False}, + {"dataset_shuffle": False, "datapipe_shuffle": False, "drop_last_batch": True}, + {"dataset_shuffle": True, "datapipe_shuffle": True, "drop_last_batch": True}, + ) + def test_dataset_connectors(self, dataset_shuffle: bool, datapipe_shuffle: bool, drop_last_batch: bool) -> None: + self._test_dataset_connectors(dataset_shuffle, datapipe_shuffle, drop_last_batch) + + @common_test_base.CommonTestBase.sproc_test(local=False, additional_packages=[]) + @parameterized.parameters( # type: ignore[misc] + {"dataset_shuffle": True, "datapipe_shuffle": True, "drop_last_batch": True}, + ) + def test_dataset_connectors_sproc( + self, dataset_shuffle: bool, datapipe_shuffle: bool, drop_last_batch: bool + ) -> None: + # Generate random prefixes due to race condition between sprocs causing dataset collision + self._test_dataset_connectors( + dataset_shuffle, datapipe_shuffle, drop_last_batch, dataset_prefix=f"dataset_integ_sproc_{uuid4().hex}" + ) + + def validate_dataset( + self, datapipe_shuffle: bool, drop_last_batch: bool, batch_size: int, ds: dataset.Dataset + ) -> None: + pt_dp = ds.read.to_torch_datapipe( + batch_size=batch_size, shuffle=datapipe_shuffle, drop_last_batch=drop_last_batch + ) + self._validate_torch_datapipe(pt_dp, batch_size, drop_last_batch) + + df = ds.read.to_snowpark_dataframe() + self._validate_snowpark_dataframe(df) + + def _validate_torch_datapipe( + self, datapipe: "data.IterDataPipe[Dict[str, npt.NDArray[Any]]]", batch_size: int, drop_last_batch: bool + ) -> None: + def numpy_batch_generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]: + for batch in data.DataLoader(datapipe, batch_size=None, num_workers=0): + numpy_batch = {} + for k, v in batch.items(): + self.assertIsInstance(v, torch.Tensor) + self.assertEqual(1, v.dim()) + numpy_batch[k] = v.numpy() + yield numpy_batch + + self._validate_batches(batch_size, drop_last_batch, numpy_batch_generator) + + def _validate_snowpark_dataframe(self, df: snowpark.DataFrame) -> None: + for key in ["NUMBER_INT_COL", "NUMBER_FIXED_POINT_COL"]: + self.assertAlmostEqual( + fileset_integ_utils.get_column_min(key), + df.select(functions.min(key)).collect()[0][0], + 1, + ) + self.assertAlmostEqual( + fileset_integ_utils.get_column_max(key, self.num_rows), + df.select(functions.max(key)).collect()[0][0], + 1, + ) + self.assertAlmostEqual( + fileset_integ_utils.get_column_avg(key, self.num_rows), + df.select(functions.avg(key)).collect()[0][0], + 1, + ) + + def _validate_pandas(self, df: pd.DataFrame) -> None: + for key in ["NUMBER_INT_COL", "FLOAT_COL"]: + with self.subTest(key): + self.assertAlmostEqual( + fileset_integ_utils.get_column_min(key), + df[key].min(), + 1, + ) + self.assertAlmostEqual( + fileset_integ_utils.get_column_max(key, self.num_rows), + df[key].max(), + 1, + ) + self.assertAlmostEqual( + fileset_integ_utils.get_column_avg(key, self.num_rows), + df[key].mean(), + delta=1, # FIXME: We lose noticeable precision from data casting (~0.5 error) + ) + + +if 
__name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/dataset/dataset_integ_test_base.py b/tests/integ/snowflake/ml/dataset/dataset_integ_test_base.py new file mode 100644 index 00000000..224bc7cb --- /dev/null +++ b/tests/integ/snowflake/ml/dataset/dataset_integ_test_base.py @@ -0,0 +1,177 @@ +import random +from typing import Any, Callable, Dict, Generator, Optional +from uuid import uuid4 + +import numpy as np +from absl.testing import absltest +from numpy import typing as npt + +from snowflake.ml import dataset +from snowflake.snowpark import Session +from snowflake.snowpark._internal import utils as snowpark_utils +from tests.integ.snowflake.ml.fileset import fileset_integ_utils +from tests.integ.snowflake.ml.test_utils import common_test_base, test_env_utils + +np.random.seed(0) +random.seed(0) + + +class TestSnowflakeDatasetBase(common_test_base.CommonTestBase): + """Integration tests for Snowflake Dataset.""" + + DS_INTEG_TEST_DB: str + DS_INTEG_TEST_SCHEMA: str + + def setUp(self) -> None: + # Disable base class setup/teardown in favor of classmethods + pass + + def tearDown(self) -> None: + # Disable base class setup/teardown in favor of classmethods + pass + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.session = test_env_utils.get_available_session() + cls.num_rows = 10000 + cls.query = fileset_integ_utils.get_fileset_query(cls.num_rows) + cls.test_table = "test_table" + if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] + cls.session.sql(f"CREATE DATABASE IF NOT EXISTS {cls.DS_INTEG_TEST_DB}").collect() + cls.session.use_database(cls.DS_INTEG_TEST_DB) + + cls.db = cls.session.get_current_database() + cls.schema = create_random_schema(cls.session, cls.DS_INTEG_TEST_SCHEMA, database=cls.db) + else: + cls.db = cls.session.get_current_database() + cls.schema = cls.session.get_current_schema() + cls.session.sql(f"create table if not exists {cls.test_table} as ({cls.query})").collect() + + @classmethod + def tearDownClass(cls) -> None: + if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] + cls.session.sql(f"DROP SCHEMA IF EXISTS {cls.schema}").collect() + cleanup_schemas(cls.session, f"{cls.DS_INTEG_TEST_SCHEMA}%", cls.db) + cls.session.close() + super().tearDownClass() + + def validate_dataset( + self, datapipe_shuffle: bool, drop_last_batch: bool, batch_size: int, ds: dataset.Dataset + ) -> None: + raise NotImplementedError + + def _test_dataset_connectors( + self, + dataset_shuffle: bool, + datapipe_shuffle: bool, + drop_last_batch: bool, + batch_size: int = 2048, + dataset_prefix: str = "dataset_integ_connector", + ) -> None: + """Test if dataset create() can materialize a dataframe, and create a ready-to-use Dataset object.""" + dataset_name = f"{dataset_prefix}_{dataset_shuffle}_{datapipe_shuffle}_{drop_last_batch}" + dataset_version = "test" + created_ds = dataset.create_from_dataframe( + session=self.session, + name=dataset_name, + version=dataset_version, + input_dataframe=self.session.table(self.test_table), + shuffle=dataset_shuffle, + ) + + for file in created_ds.read.files(): + self.assertRegex( + file, rf"snow://dataset/{self.db}.{self.schema}.{dataset_name}/versions/{dataset_version}/.*[.]parquet" + ) + + # Verify that we can restore a Dataset object + ds = dataset.load_dataset( + name=dataset_name, + version=dataset_version, + session=self.session, + ) + for file in ds.read.files(): + self.assertRegex( + file, 
rf"snow://dataset/{self.db}.{self.schema}.{dataset_name}/versions/{dataset_version}/.*[.]parquet" + ) + + self.validate_dataset(datapipe_shuffle, drop_last_batch, batch_size, ds) + + def _validate_batches( + self, + batch_size: int, + drop_last_batch: bool, + numpy_batch_generator: Callable[[], Generator[Dict[str, npt.NDArray[Any]], None, None]], + ) -> None: + if drop_last_batch: + expected_num_rows = self.num_rows - self.num_rows % batch_size + else: + expected_num_rows = self.num_rows + + actual_min_counter = { + "NUMBER_INT_COL": float("inf"), + "NUMBER_FIXED_POINT_COL": float("inf"), + } + actual_max_counter = { + "NUMBER_INT_COL": 0.0, + "NUMBER_FIXED_POINT_COL": 0.0, + } + actual_sum_counter = { + "NUMBER_INT_COL": 0.0, + "NUMBER_FIXED_POINT_COL": 0.0, + } + actual_num_rows = 0 + for iteration, batch in enumerate(numpy_batch_generator()): + # If drop_last_batch is False, the last batch might not have the same size as the other batches. + if not drop_last_batch and iteration == self.num_rows // batch_size: + expected_batch_size = self.num_rows % batch_size + else: + expected_batch_size = batch_size + + for col_name in ["NUMBER_INT_COL", "NUMBER_FIXED_POINT_COL"]: + col = batch[col_name] + self.assertEqual(col.size, expected_batch_size) + + actual_min_counter[col_name] = min(np.min(col), actual_min_counter[col_name]) + actual_max_counter[col_name] = max(np.max(col), actual_max_counter[col_name]) + actual_sum_counter[col_name] += np.sum(col) + + actual_num_rows += expected_batch_size + + self.assertEqual(actual_num_rows, expected_num_rows) + actual_avg_counter = {"NUMBER_INT_COL": 0.0, "NUMBER_FIXED_POINT_COL": 0.0} + for key, value in actual_sum_counter.items(): + actual_avg_counter[key] = value / actual_num_rows + + if not drop_last_batch: + # We can only get the whole set of data for comparison if drop_last_batch is False. 
+ for key in ["NUMBER_INT_COL", "NUMBER_FIXED_POINT_COL"]: + self.assertAlmostEqual(fileset_integ_utils.get_column_min(key), actual_min_counter[key], 1) + self.assertAlmostEqual( + fileset_integ_utils.get_column_max(key, expected_num_rows), actual_max_counter[key], 1 + ) + self.assertAlmostEqual( + fileset_integ_utils.get_column_avg(key, expected_num_rows), actual_avg_counter[key], 1 + ) + + +def create_random_schema( + session: Session, prefix: str, database: Optional[str] = None, additional_options: str = "" +) -> str: + database = database or session.get_current_database() + schema = f'"{prefix}_{uuid4().hex.upper()}"' + session.sql(f"CREATE SCHEMA IF NOT EXISTS {database}.{schema} {additional_options}").collect() + return schema + + +def cleanup_schemas(session: Session, schema: str, database: Optional[str] = None, expire_days: int = 1) -> None: + database = database or session.get_current_database() + schemas_df = session.sql(f"SHOW SCHEMAS LIKE '{schema}' IN DATABASE {database}") + stale_schemas = schemas_df.filter(f"\"created_on\" < dateadd('day', {-expire_days}, current_timestamp())").collect() + for stale_schema in stale_schemas: + session.sql(f"DROP SCHEMA IF EXISTS {database}.{stale_schema.name}").collect() + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/dataset/dataset_tensorflow_integ_test.py b/tests/integ/snowflake/ml/dataset/dataset_tensorflow_integ_test.py new file mode 100644 index 00000000..34b207e8 --- /dev/null +++ b/tests/integ/snowflake/ml/dataset/dataset_tensorflow_integ_test.py @@ -0,0 +1,65 @@ +import random +from typing import Any, Dict, Generator +from uuid import uuid4 + +import numpy as np +import tensorflow as tf +from absl.testing import absltest, parameterized +from numpy import typing as npt + +from snowflake.ml import dataset +from tests.integ.snowflake.ml.dataset import dataset_integ_test_base +from tests.integ.snowflake.ml.test_utils import common_test_base + +np.random.seed(0) +random.seed(0) + + +class TestSnowflakeDataseTensorflow(dataset_integ_test_base.TestSnowflakeDatasetBase): + """Integration tests for Snowflake Dataset.""" + + DS_INTEG_TEST_DB = "SNOWML_DATASET_TF_TEST_DB" + DS_INTEG_TEST_SCHEMA = "DATASET_TF_TEST" + + @parameterized.parameters( # type: ignore[misc] + {"dataset_shuffle": True, "datapipe_shuffle": False, "drop_last_batch": False}, + {"dataset_shuffle": False, "datapipe_shuffle": True, "drop_last_batch": False}, + {"dataset_shuffle": False, "datapipe_shuffle": False, "drop_last_batch": True}, + {"dataset_shuffle": True, "datapipe_shuffle": True, "drop_last_batch": True}, + ) + def test_dataset_connectors(self, dataset_shuffle: bool, datapipe_shuffle: bool, drop_last_batch: bool) -> None: + self._test_dataset_connectors(dataset_shuffle, datapipe_shuffle, drop_last_batch) + + @common_test_base.CommonTestBase.sproc_test(local=False, additional_packages=[]) + @parameterized.parameters( # type: ignore[misc] + {"dataset_shuffle": True, "datapipe_shuffle": True, "drop_last_batch": True}, + ) + def test_dataset_connectors_sproc( + self, dataset_shuffle: bool, datapipe_shuffle: bool, drop_last_batch: bool + ) -> None: + # Generate random prefixes due to race condition between sprocs causing dataset collision + self._test_dataset_connectors( + dataset_shuffle, datapipe_shuffle, drop_last_batch, dataset_prefix=f"dataset_integ_sproc_{uuid4().hex}" + ) + + def validate_dataset( + self, datapipe_shuffle: bool, drop_last_batch: bool, batch_size: int, ds: dataset.Dataset + ) -> None: + tf_ds = 
ds.read.to_tf_dataset(batch_size=batch_size, shuffle=datapipe_shuffle, drop_last_batch=drop_last_batch) + self._validate_tf_dataset(tf_ds, batch_size, drop_last_batch) + + def _validate_tf_dataset(self, dataset: "tf.data.Dataset", batch_size: int, drop_last_batch: bool) -> None: + def numpy_batch_generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]: + for batch in dataset: + numpy_batch = {} + for k, v in batch.items(): + self.assertIsInstance(v, tf.Tensor) + self.assertEqual(1, v.shape.rank) + numpy_batch[k] = v.numpy() + yield numpy_batch + + self._validate_batches(batch_size, drop_last_batch, numpy_batch_generator) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel index 7ec641e1..9ce48905 100644 --- a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel +++ b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel @@ -174,15 +174,3 @@ py_test( "//snowflake/ml/utils:connection_params", ], ) - -py_test( - name = "fit_transform_test", - srcs = ["fit_transform_test.py"], - shard_count = 3, - deps = [ - "//snowflake/ml/modeling/manifold:mds", - "//snowflake/ml/modeling/manifold:spectral_embedding", - "//snowflake/ml/modeling/manifold:tsne", - "//snowflake/ml/utils:connection_params", - ], -) diff --git a/tests/integ/snowflake/ml/extra_tests/fit_transform_test.py b/tests/integ/snowflake/ml/extra_tests/fit_transform_test.py deleted file mode 100644 index d29a611c..00000000 --- a/tests/integ/snowflake/ml/extra_tests/fit_transform_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np -import pandas as pd -from absl.testing.absltest import TestCase, main -from sklearn.datasets import load_digits -from sklearn.manifold import ( - MDS as SKMDS, - TSNE as SKTSNE, - SpectralEmbedding as SKSpectralEmbedding, -) - -from snowflake.ml.modeling.manifold import MDS, TSNE, SpectralEmbedding -from snowflake.ml.utils.connection_params import SnowflakeLoginOptions -from snowflake.snowpark import Session - - -class FitTransformTest(TestCase): - def _load_data(self): - X, _ = load_digits(return_X_y=True) - self._input_df_pandas = pd.DataFrame(X)[:100] - self._input_df_pandas.columns = [str(c) for c in self._input_df_pandas.columns] - self._input_df = self._session.create_dataframe(self._input_df_pandas) - self._input_cols = self._input_df.columns - self._output_cols = [str(c) for c in range(100)] - - def setUp(self): - """Creates Snowpark and Snowflake environments for testing.""" - self._session = Session.builder.configs(SnowflakeLoginOptions()).create() - self._load_data() - - def tearDown(self): - self._session.close() - - def testMDS(self): - sk_embedding = SKMDS(n_components=2, normalized_stress="auto", random_state=2024) - - embedding = MDS( - input_cols=self._input_cols, - output_cols=self._output_cols, - n_components=2, - normalized_stress="auto", - random_state=2024, - ) - sk_X_transformed = sk_embedding.fit_transform(self._input_df_pandas) - X_transformed = embedding.fit_transform(self._input_df) - np.testing.assert_allclose(sk_X_transformed, X_transformed, rtol=1.0e-1, atol=1.0e-2) - - def testSpectralEmbedding(self): - sk_embedding = SKSpectralEmbedding(n_components=2, random_state=2024) - sk_X_transformed = sk_embedding.fit_transform(self._input_df_pandas) - - embedding = SpectralEmbedding( - input_cols=self._input_cols, output_cols=self._output_cols, n_components=2, random_state=2024 - ) - X_transformed = embedding.fit_transform(self._input_df) - 
np.testing.assert_allclose(sk_X_transformed, X_transformed, rtol=1.0e-1, atol=1.0e-2) - - def testTSNE(self): - sk_embedding = SKTSNE(n_components=2, random_state=2024, n_jobs=1) - sk_X_transformed = sk_embedding.fit_transform(self._input_df_pandas) - - embedding = TSNE( - input_cols=self._input_cols, - output_cols=self._output_cols, - n_components=2, - random_state=2024, - n_jobs=1, - ) - X_transformed = embedding.fit_transform(self._input_df) - np.testing.assert_allclose(sk_X_transformed.shape, X_transformed.shape, rtol=1.0e-1, atol=1.0e-2) - - -if __name__ == "__main__": - main() diff --git a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py index 605dc95f..11b83f69 100644 --- a/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py +++ b/tests/integ/snowflake/ml/extra_tests/pipeline_with_ohe_and_xgbr_test.py @@ -230,6 +230,42 @@ def test_pipeline_with_limited_number_of_columns_in_estimator_export(self) -> No sk_results = sk_pipeline.predict(pd_df) np.testing.assert_allclose(snow_results.flatten(), sk_results.flatten(), rtol=1.0e-1, atol=1.0e-2) + def test_pipeline_squash(self) -> None: + pd_data = self._test_data + pd_data["ROW_INDEX"] = pd_data.reset_index().index + raw_data = self._session.create_dataframe(pd_data) + + pipeline = Pipeline( + steps=[ + ( + "OHE", + OneHotEncoder( + input_cols=categorical_columns, output_cols=categorical_columns, drop_input_cols=True + ), + ), + ( + "MMS", + MinMaxScaler( + clip=True, + input_cols=numerical_columns, + output_cols=numerical_columns, + ), + ), + ("KNNImputer", KNNImputer(input_cols=numerical_columns, output_cols=numerical_columns)), + ("regression", XGBClassifier(label_cols=label_column, passthrough_cols="ROW_INDEX")), + ] + ) + + p1 = pipeline.fit(raw_data) + results1 = p1.predict(raw_data).to_pandas().sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() + + p2 = pipeline.fit(raw_data, squash=True) + results2 = p2.predict(raw_data).to_pandas().sort_values(by=["ROW_INDEX"])["OUTPUT_LABEL"].to_numpy() + + self.assertEqual(hash(p1), hash(p2)) + + np.testing.assert_allclose(results1.flatten(), results2.flatten(), rtol=1.0e-1, atol=1.0e-2) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/feature_store/BUILD.bazel b/tests/integ/snowflake/ml/feature_store/BUILD.bazel index d8a15d5a..0f72613c 100644 --- a/tests/integ/snowflake/ml/feature_store/BUILD.bazel +++ b/tests/integ/snowflake/ml/feature_store/BUILD.bazel @@ -48,7 +48,7 @@ py_test( srcs = [ "feature_store_large_scale_test.py", ], - shard_count = 4, + shard_count = 2, deps = [ ":common_utils", "//snowflake/ml/feature_store:feature_store_lib", @@ -84,3 +84,16 @@ py_test( "//snowflake/ml/utils:connection_params", ], ) + +py_test( + name = "feature_store_compatibility_test", + srcs = [ + "feature_store_compatibility_test.py", + ], + shard_count = 1, + deps = [ + ":common_utils", + "//snowflake/ml/feature_store:feature_store_lib", + "//snowflake/ml/utils:connection_params", + ], +) diff --git a/tests/integ/snowflake/ml/feature_store/access_utils.py b/tests/integ/snowflake/ml/feature_store/access_utils.py index bae00f71..5fc51057 100644 --- a/tests/integ/snowflake/ml/feature_store/access_utils.py +++ b/tests/integ/snowflake/ml/feature_store/access_utils.py @@ -3,8 +3,7 @@ from snowflake.ml.feature_store.feature_store import ( _FEATURE_STORE_OBJECT_TAG, - _FEATURE_VIEW_ENTITY_TAG, - _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_VIEW_METADATA_TAG, 
FeatureStore, ) from snowflake.snowpark import Session, exceptions @@ -21,26 +20,24 @@ class FeatureStoreRole(Enum): _PRIVILEGE_LEVELS: Dict[FeatureStoreRole, Dict[str, List[str]]] = { FeatureStoreRole.ADMIN: { "database {database}": ["CREATE SCHEMA"], - f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_ENTITY_TAG}": ["OWNERSHIP"], - f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_TS_COL_TAG}": ["OWNERSHIP"], + f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_METADATA_TAG}": ["OWNERSHIP"], f"tag {{database}}.{{schema}}.{_FEATURE_STORE_OBJECT_TAG}": ["OWNERSHIP"], "schema {database}.{schema}": ["OWNERSHIP"], }, FeatureStoreRole.PRODUCER: { "schema {database}.{schema}": [ + "CREATE TABLE", # Necessary for testing only (create mock table) "CREATE DYNAMIC TABLE", - "CREATE TABLE", "CREATE TAG", "CREATE VIEW", + # FIXME(dhung): Dataset RBAC won't be ready until 8.17 release + # "CREATE DATASET", ], - f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_ENTITY_TAG}": ["APPLY"], - f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_TS_COL_TAG}": ["APPLY"], + f"tag {{database}}.{{schema}}.{_FEATURE_VIEW_METADATA_TAG}": ["APPLY"], f"tag {{database}}.{{schema}}.{_FEATURE_STORE_OBJECT_TAG}": ["APPLY"], # TODO: The below privileges should be granted on a per-resource level # between producers (e.g. PRODUCER_A grants PRODUCER_B operate access # to FEATURE_VIEW_0, but not FEATURE_VIEW_1) - "future tables in schema {database}.{schema}": ["INSERT"], - "all tables in schema {database}.{schema}": ["INSERT"], "future dynamic tables in schema {database}.{schema}": ["OPERATE"], "all dynamic tables in schema {database}.{schema}": ["OPERATE"], "future tasks in schema {database}.{schema}": ["OPERATE"], @@ -66,6 +63,9 @@ class FeatureStoreRole(Enum): "SELECT", "REFERENCES", ], + # FIXME(dhung): Dataset RBAC won't be ready until 8.17 release + # "future datasets in schema {database}.{schema}": ["USAGE"], + # "all datasets in schema {database}.{schema}": ["USAGE"], }, FeatureStoreRole.NONE: {}, } diff --git a/tests/integ/snowflake/ml/feature_store/common_utils.py b/tests/integ/snowflake/ml/feature_store/common_utils.py index cd6fb491..75b32440 100644 --- a/tests/integ/snowflake/ml/feature_store/common_utils.py +++ b/tests/integ/snowflake/ml/feature_store/common_utils.py @@ -42,7 +42,14 @@ def create_random_schema( return schema -def compare_dataframe(actual_df: pd.DataFrame, target_data: Dict[str, Any], sort_cols: List[str]) -> None: +def compare_dataframe( + actual_df: pd.DataFrame, target_data: Dict[str, Any], sort_cols: List[str], exclude_cols: Optional[List[str]] = None +) -> None: + if exclude_cols is not None: + for c in exclude_cols: + assert c.upper() in actual_df, f"{c.upper()} is missing in actual_df" + actual_df = actual_df.drop([c.upper() for c in exclude_cols], axis=1) + target_df = pd.DataFrame(target_data).sort_values(by=sort_cols) assert_frame_equal( actual_df.sort_values(by=sort_cols).reset_index(drop=True), target_df.reset_index(drop=True), check_dtype=False diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_access_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_access_test.py index 0060c974..5bf889a2 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_access_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_access_test.py @@ -81,7 +81,7 @@ def setUp(self) -> None: def _init_test_data(cls) -> str: prev_role = cls._session.get_current_role() try: - cls._session.use_role(cls._test_roles[Role.ADMIN]) + cls._session.use_role(cls._test_roles[Role.PRODUCER]) test_table: 
str = create_mock_table(cls._session, cls._test_database, cls._test_schema) # Create Entities @@ -325,6 +325,8 @@ def test_resume_feature_view(self, required_access: Role, test_access: Role) -> finally: self._feature_store.delete_feature_view(fv) + # FIXME(dhung): SNOW-1346923 + @absltest.skip("Dataset RBAC won't be ready until 8.17 release") # type: ignore[misc] @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] def test_generate_dataset(self, required_access: Role, test_access: Role) -> None: spine_df = self._session.sql(f"SELECT id FROM {self._mock_table}") @@ -333,12 +335,30 @@ def test_generate_dataset(self, required_access: Role, test_access: Role) -> Non dataset_name = f"FS_TEST_DATASET_{uuid4().hex.upper()}" self._test_access( - lambda: self._feature_store.generate_dataset(spine_df, [fv1, fv2], materialized_table=dataset_name), + lambda: self._feature_store.generate_dataset(dataset_name, spine_df, [fv1, fv2]), required_access, test_access, access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, ) + # FIXME(dhung): SNOW-1346923 + @absltest.skip("Dataset RBAC won't be ready until 8.17 release") # type: ignore[misc] + @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] + def test_access_dataset(self, required_access: Role, test_access: Role) -> None: + spine_df = self._session.sql(f"SELECT id FROM {self._mock_table}") + fv1 = self._feature_store.get_feature_view("fv1", "v1") + fv2 = self._feature_store.get_feature_view("fv2", "v1") + dataset_name = f"FS_TEST_DATASET_{uuid4().hex.upper()}" + dataset = self._feature_store.generate_dataset(dataset_name, spine_df, [fv1, fv2]) + + self._test_access( + lambda: dataset.to_pandas(), + required_access, + test_access, + expected_result=lambda _pd: self.assertNotEmpty(_pd), + access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, + ) + @parameterized.product(required_access=[Role.PRODUCER], test_access=list(Role)) # type: ignore[misc] def test_delete_feature_view(self, required_access: Role, test_access: Role) -> None: e = self._feature_store.get_entity("foo") @@ -399,11 +419,10 @@ def test_get_entity(self, required_access: Role, test_access: Role) -> None: @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] def test_list_feature_views(self, required_access: Role, test_access: Role) -> None: self._test_access( - lambda: self._feature_store.list_feature_views(as_dataframe=False), + lambda: self._feature_store.list_feature_views().collect(), required_access, test_access, expected_result=lambda rst: self.assertGreater(len(rst), 0), - access_exception_dict={Role.NONE: snowpark_exceptions.SnowparkSQLException}, ) @parameterized.product(required_access=[Role.CONSUMER], test_access=list(Role)) # type: ignore[misc] diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py index 90cf4b06..b7d43dea 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List from uuid import uuid4 from absl.testing import absltest, parameterized @@ -311,19 +311,19 @@ def test_join_keys_and_ts_col(self, equi_names: List[str], diff_names: List[str] @parameterized.parameters( [ ( - [("foo", 
"bar"), ("foo", "BAR"), ("FOO", "BAR"), ('"FOO"', "BAR")], - [('"foo"', "bar")], + ["foo", "Foo", "FOO"], + ['"foo"'], ), ( - [('"abc"', "def"), ('"abc"', "DEF")], - [("abc", "def")], + ['"abc"'], + ["abc", '"Abc"', '"aBC"'], ), ] ) # type: ignore[misc] - def test_feature_view_names_and_versions_combination( + def test_feature_view_names( self, - equi_full_names: List[Tuple[str, str]], - diff_full_names: List[Tuple[str, str]], + equi_full_names: List[str], + diff_full_names: List[str], ) -> None: current_schema = create_random_schema(self._session, "TEST_FEATURE_VIEW_NAMES") fs = FeatureStore( @@ -337,30 +337,26 @@ def test_feature_view_names_and_versions_combination( df = self._session.create_dataframe([1, 2, 3], schema=["a"]) e = Entity(name="my_cool_entity", join_keys=["a"]) - original_fv_name, original_version = equi_full_names[0] + original_fv_name = equi_full_names[0] fv_0 = FeatureView(name=original_fv_name, entities=[e], feature_df=df) fs.register_entity(e) - fs.register_feature_view(fv_0, original_version) + fs.register_feature_view(fv_0, "LATEST") # register with identical full name will fail - for equi_full_name in equi_full_names: - fv_name = equi_full_name[0] - version = equi_full_name[1] + for name in equi_full_names: with self.assertWarnsRegex(UserWarning, "FeatureView .* already exists..*"): - fv = FeatureView(name=fv_name, entities=[e], feature_df=df) - fs.register_feature_view(fv, version) + fv = FeatureView(name=name, entities=[e], feature_df=df) + fs.register_feature_view(fv, "LATEST") # register with different full name is fine - for diff_full_name in diff_full_names: - fv_name = diff_full_name[0] - version = diff_full_name[1] - fv = FeatureView(name=fv_name, entities=[e], feature_df=df) - fv = fs.register_feature_view(fv, version) + for name in diff_full_names: + fv = FeatureView(name=name, entities=[e], feature_df=df) + fv = fs.register_feature_view(fv, "LATEST") fs.read_feature_view(fv) - self.assertEqual(len(fs.list_feature_views(as_dataframe=False)), len(diff_full_names) + 1) + self.assertEqual(len(fs.list_feature_views().collect()), len(diff_full_names) + 1) self.assertEqual( - len(fs.list_feature_views(entity_name="my_cool_entity", as_dataframe=False)), + len(fs.list_feature_views(entity_name="my_cool_entity").collect()), len(diff_full_names) + 1, ) self.assertGreaterEqual( @@ -368,33 +364,26 @@ def test_feature_view_names_and_versions_combination( fs.list_feature_views( entity_name="my_cool_entity", feature_view_name=original_fv_name, - as_dataframe=False, - ) + ).collect() ), 1, ) - for diff_full_name in diff_full_names: - fv_name = diff_full_name[0] + for name in diff_full_names: self.assertGreaterEqual( len( fs.list_feature_views( entity_name="my_cool_entity", - feature_view_name=fv_name, - as_dataframe=False, - ) + feature_view_name=name, + ).collect() ), 1, ) - for equi_name in equi_full_names: - fv_name = equi_name[0] - version = equi_name[1] - fs.get_feature_view(fv_name, version) + for name in equi_full_names: + fs.get_feature_view(name, "LATEST") - for diff_name in diff_full_names: - fv_name = diff_name[0] - version = diff_name[1] - fs.get_feature_view(fv_name, version) + for name in diff_full_names: + fs.get_feature_view(name, "LATEST") @parameterized.parameters(TEST_NAMES) # type: ignore[misc] def test_find_objects(self, equi_names: List[str], diff_names: List[str]) -> None: @@ -431,27 +420,42 @@ def test_feature_view_version(self) -> None: fs.register_entity(e) fv = FeatureView(name="MY_FV", entities=[e], feature_df=df) - # 1: register with 
lowercase, get it back with lowercase/uppercase - fs.register_feature_view(fv, "a1") - fs.get_feature_view("MY_FV", "A1") - fs.get_feature_view("MY_FV", "a1") - - # 2: register with uppercase, get it back with lowercase/uppercase - fs.register_feature_view(fv, "B2") - fs.get_feature_view("MY_FV", "b2") - fs.get_feature_view("MY_FV", "B2") - - # 3. register with valid characters - fs.register_feature_view(fv, "V2_1") - fs.get_feature_view("MY_FV", "v2_1") - - # 4: register with invalid characters - with self.assertRaisesRegex(ValueError, "3 is not a valid SQL identifier: .*"): - fs.register_feature_view(fv, "3") - with self.assertRaisesRegex(ValueError, ".* is not allowed in version: .*"): - fs.register_feature_view(fv, "abc$") - with self.assertRaisesRegex(ValueError, "abc# is not a valid SQL identifier: .*"): - fs.register_feature_view(fv, "abc#") + valid_versions = [ + "v2", # start with letter + "3x", # start with digit + "1", # single digit + "2.1", # digit with period + "3_1", # digit with underscore + "4-1", # digit with hyphen + "4-1_2.3", # digit with period, underscore and hyphen + "x", # single letter + "4x_1", # digit, letter and underscore + "latest", # pure lowercase letters + "OLD", # pure uppercase letters + "OLap", # pure uppercase letters + "a" * 128, # within maximum allowed length + ] + + invalid_dataset_versions = [ + "", # empty + "_v1", # start with underscore + ".2", # start with period + "3/1", # digit with slash + "-4", # start with hyphen + "v1$", # start with letter, contains invalid character + "9^", # start with digit, contains invalid character + "a" * 129, # exceed maximum allowed length + ] + + for version in valid_versions: + fv_1 = fs.register_feature_view(fv, version) + self.assertTrue(("$" + version) in fv_1.fully_qualified_name()) + fv_2 = fs.get_feature_view("MY_FV", version) + self.assertTrue(("$" + version) in fv_2.fully_qualified_name()) + + for version in invalid_dataset_versions: + with self.assertRaisesRegex(ValueError, ".* is not a valid feature view version.*"): + fs.register_feature_view(fv, version) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_compatibility_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_compatibility_test.py new file mode 100644 index 00000000..a86e7fa2 --- /dev/null +++ b/tests/integ/snowflake/ml/feature_store/feature_store_compatibility_test.py @@ -0,0 +1,171 @@ +from absl.testing import absltest +from common_utils import compare_dataframe +from snowflake.ml.version import VERSION + +from snowflake.ml.feature_store import ( # type: ignore[attr-defined] + CreationMode, + Entity, + FeatureStore, + FeatureView, + _FeatureStoreObjTypes, +) +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +from snowflake.snowpark import Session + +FS_COMPATIBIILTY_TEST_DB = "FEATURE_STORE_COMPATIBILITY_TEST_DB" +FS_COMPATIBIILTY_TEST_SCHEMA = "FEATURE_STORE_COMPATIBILITY_TEST_SCHEMA" +TEST_DATA = "test_data" +# check backward compatibility with two pkg versions in the past +BC_VERSION_LIMITS = 2 + + +def _create_test_data(session: Session) -> str: + test_table = f"{FS_COMPATIBIILTY_TEST_DB}.{FS_COMPATIBIILTY_TEST_SCHEMA}.{TEST_DATA}" + session.sql( + f"""CREATE TABLE IF NOT EXISTS {test_table} + (name VARCHAR(64), id INT, title VARCHAR(128), age INT, dept VARCHAR(64), ts INT) + """ + ).collect() + session.sql( + f"""INSERT OVERWRITE INTO {test_table} (name, id, title, age, dept, ts) + VALUES + ('john', 1, 'boss', 20, 'sales', 100), + ('porter', 2, 'manager', 
30, 'engineer', 200) + """ + ).collect() + return test_table + + +class FeatureStoreCompatibilityTest(absltest.TestCase): + @classmethod + def setUpClass(self) -> None: + self._session = Session.builder.configs(SnowflakeLoginOptions()).create() + + try: + self._session.sql(f"CREATE DATABASE IF NOT EXISTS {FS_COMPATIBIILTY_TEST_DB}").collect() + self._session.use_database(FS_COMPATIBIILTY_TEST_DB) + self._session.sql(f"CREATE SCHEMA IF NOT EXISTS {FS_COMPATIBIILTY_TEST_SCHEMA}").collect() + self._session.use_schema(FS_COMPATIBIILTY_TEST_SCHEMA) + self._mock_table = _create_test_data(self._session) + except Exception as e: + raise Exception(f"Test setup failed: {e}") + + def test_cross_version_compatibilities(self) -> None: + fs = FeatureStore( + self._session, + FS_COMPATIBIILTY_TEST_DB, + FS_COMPATIBIILTY_TEST_SCHEMA, + self._session.get_current_warehouse(), + creation_mode=CreationMode.CREATE_IF_NOT_EXIST, + ) + self._maybe_create_feature_store_objects(fs) + + versions = fs._collapse_object_versions() + self.assertGreater(len(versions), 0) + for version in versions[: BC_VERSION_LIMITS + 1]: + self._check_per_version_access(str(version), fs) + + # TODO: update to check more than 1 with next version release + # obj from at least two versions should be listed + entity_df = fs.list_entities() + self.assertGreater(len(entity_df.collect()), 1) + fv_df = fs.list_feature_views() + self.assertGreater(len(fv_df.collect()), 1) + + def test_forward_compatibility_breakage(self) -> None: + with absltest.mock.patch( + "snowflake.ml.feature_store.feature_store._FeatureStoreObjTypes.parse", autospec=True + ) as MockFeatureStoreObjTypesParseFn: + MockFeatureStoreObjTypesParseFn.return_value = _FeatureStoreObjTypes.UNKNOWN + with self.assertRaisesRegex(RuntimeError, "The current snowflake-ml-python version .*"): + FeatureStore( + self._session, + FS_COMPATIBIILTY_TEST_DB, + FS_COMPATIBIILTY_TEST_SCHEMA, + self._session.get_current_warehouse(), + creation_mode=CreationMode.CREATE_IF_NOT_EXIST, + ) + + def test_pkg_version_falling_behind(self) -> None: + with absltest.mock.patch( + "snowflake.ml.feature_store.feature_store.snowml_version", autospec=True + ) as MockSnowMLVersion: + MockSnowMLVersion.VERSION = "1.0.0" + with self.assertWarnsRegex(UserWarning, "The current snowflake-ml-python version out of date.*"): + FeatureStore( + self._session, + FS_COMPATIBIILTY_TEST_DB, + FS_COMPATIBIILTY_TEST_SCHEMA, + self._session.get_current_warehouse(), + creation_mode=CreationMode.CREATE_IF_NOT_EXIST, + ) + + def _check_per_version_access(self, version: str, fs: FeatureStore) -> None: + entity_names = ["foo", "Bar"] + for e in entity_names: + fs.get_entity(self._get_versioned_object_name(e, version)) + + feature_view_names = ["unmanaged_fv", "MANAGED_fv"] + fvs = [] + for fv_name in feature_view_names: + fv_name = self._get_versioned_object_name(fv_name, version) + fv = fs.get_feature_view(fv_name, "V1") + data = fs.read_feature_view(fv) + self.assertEqual(len(data.collect()), 2) + fvs.append(fv) + + spine_df = self._session.create_dataframe( + [(1, "john", 101), (2, "porter", 202), (1, "john", 90)], schema=["id", "name", "ts"] + ) + data = fs.retrieve_feature_values( + spine_df=spine_df, + features=fvs, + spine_timestamp_col="ts", + ) + compare_dataframe( + actual_df=data.to_pandas(), + target_data={ + "ID": [1, 1, 2], + "NAME": ["john", "john", "porter"], + "TS": [90, 101, 202], + "TITLE": ["boss", "boss", "manager"], + "AGE": [None, 20, 30], + "DEPT": [None, "sales", "engineer"], + }, + sort_cols=["ID", "TS"], + 
) + + def _maybe_create_feature_store_objects(self, fs: FeatureStore) -> None: + e1 = Entity(self._get_versioned_object_name("foo", VERSION), ["id"], f"VERSION={VERSION}") + fs.register_entity(e1) + e2 = Entity(self._get_versioned_object_name("Bar", VERSION), ["id", "name"], f"VERSION={VERSION}") + fs.register_entity(e2) + + sql1 = f"select id, title from {TEST_DATA}" + fv1 = FeatureView( + name=self._get_versioned_object_name("unmanaged_fv", VERSION), + entities=[e1], + feature_df=self._session.sql(sql1), + ) + fs.register_feature_view(feature_view=fv1, version="V1") + + sql2 = f"select id, name, age, dept, ts from {TEST_DATA}" + fv2 = FeatureView( + name=self._get_versioned_object_name("MANAGED_fv", VERSION), + entities=[e1, e2], + feature_df=self._session.sql(sql2), + timestamp_col="ts", + refresh_freq="60m", + ) + fs.register_feature_view(feature_view=fv2, version="V1") + + def _get_versioned_object_name(self, prefix: str, version: str) -> str: + name = f"{prefix}_{version.replace('.', '_')}" + if prefix.islower(): + return name + else: + return f'"{name}"' + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_large_scale_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_large_scale_test.py index 79b14988..a4cde92a 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_large_scale_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_large_scale_test.py @@ -14,7 +14,6 @@ ) from pandas.testing import assert_frame_equal -from snowflake.ml._internal.utils.sql_identifier import SqlIdentifier from snowflake.ml.feature_store import ( # type: ignore[attr-defined] CreationMode, Entity, @@ -123,7 +122,8 @@ def test_external_table(self) -> None: location_features = fs.register_feature_view(feature_view=location_features, version="V1") def create_select_query(start: str, end: str) -> str: - return f"""SELECT DISTINCT DATE_TRUNC('second', TO_TIMESTAMP(TPEP_DROPOFF_DATETIME)) AS DROPOFF_TIME, + return f"""SELECT + DISTINCT DATE_TRUNC('second', TO_TIMESTAMP(TO_VARCHAR(TPEP_DROPOFF_DATETIME))) AS DROPOFF_TIME, PULOCATIONID, TIP_AMOUNT, TOTAL_AMOUNT FROM {raw_dataset} WHERE DROPOFF_TIME >= '{start}' AND DROPOFF_TIME < '{end}' @@ -131,41 +131,41 @@ def create_select_query(start: str, end: str) -> str: spine_df_1 = self._session.sql(create_select_query("2016-01-01 00:00:00", "2016-01-03 00:00:00")) - result_table_name = f"FS_INTEG_TEST_{uuid4().hex.upper()}" + dataset_name = f"FS_INTEG_TEST_{uuid4().hex.upper()}" + dataset_version = "test_version" fv_slice = location_features.slice(["F_AVG_TIP", "F_AVG_TOTAL_AMOUNT"]) ds0 = fs.generate_dataset( spine_df=spine_df_1, features=[fv_slice], - materialized_table=result_table_name, + name=dataset_name, + version=dataset_version, spine_timestamp_col="DROPOFF_TIME", spine_label_cols=None, - save_mode="merge", ) # verify dataset metadata is correct - self.assertEqual(ds0.materialized_table, f"{FS_INTEG_TEST_DB}.{current_schema}.{result_table_name}") - self.assertIsNotNone(ds0.feature_store_metadata) - self.assertEqual(len(ds0.feature_store_metadata.features), 1) - deserialized_fv_slice = FeatureViewSlice.from_json(ds0.feature_store_metadata.features[0], self._session) - # verify snapshot 0 rows count equal to spine df rows count + dsv0 = ds0.selected_version + dsv0_meta = dsv0._get_metadata() + self.assertEqual( + dsv0.url(), f"snow://dataset/{FS_INTEG_TEST_DB}.{current_schema}.{dataset_name}/versions/{dataset_version}/" + ) + 
self.assertIsNotNone(dsv0_meta.properties) + self.assertEqual(len(dsv0_meta.properties.serialized_feature_views), 1) + deserialized_fv_slice = FeatureViewSlice.from_json( + dsv0_meta.properties.serialized_feature_views[0], self._session + ) + # verify dataset rows count equal to spine df rows count df1_row_count = len(spine_df_1.collect()) - self.assertEqual(self._session.sql(f"SELECT COUNT(*) FROM {ds0.snapshot_table}").collect()[0][0], df1_row_count) - self.assertEqual(len(ds0.df.collect()), df1_row_count) + self.assertEqual(len(ds0.read.to_snowpark_dataframe().collect()), df1_row_count) self.assertEqual(deserialized_fv_slice, fv_slice) - self.assertIsNone(ds0.label_cols) - self.assertDictEqual( - ds0.feature_store_metadata.connection_params, - { - "database": SqlIdentifier(FS_INTEG_TEST_DB).identifier(), - "schema": SqlIdentifier(current_schema).identifier(), - }, - ) + self.assertIsNone(dsv0_meta.label_cols) # verify materialized table value is correct actual_pdf = ( - self._session.sql(f"SELECT PULOCATIONID, F_AVG_TIP, F_AVG_TOTAL_AMOUNT FROM {ds0.materialized_table}") + ds0.read.to_snowpark_dataframe() + .select(["PULOCATIONID", "F_AVG_TIP", "F_AVG_TOTAL_AMOUNT"]) .to_pandas() .sort_values(by="PULOCATIONID") .reset_index(drop=True) @@ -179,28 +179,6 @@ def create_select_query(start: str, end: str) -> str: ) assert_frame_equal(expected_pdf, actual_pdf, check_dtype=True) - # generate another dataset and merge with original materialized table - spine_df_2 = self._session.sql(create_select_query("2016-01-04 00:00:00", "2016-01-05 00:00:00")) - ds1 = fs.generate_dataset( - spine_df=spine_df_2, - features=[fv_slice], - materialized_table=result_table_name, - spine_timestamp_col="DROPOFF_TIME", - spine_label_cols=None, - save_mode="merge", - ) - - df2_row_count = len(spine_df_2.collect()) - # verify snapshot 1 rows count equal to 2x of spine df rows count (as it appends same amount of rows) - self.assertEqual( - self._session.sql(f"SELECT COUNT(*) FROM {ds1.snapshot_table}").collect()[0][0], - df1_row_count + df2_row_count, - ) - self.assertEqual(len(ds1.df.collect()), df1_row_count + df2_row_count) - # verify snapshort 0 is not impacted after new data is merged with materialized table. 
- self.assertEqual(self._session.sql(f"SELECT COUNT(*) FROM {ds0.snapshot_table}").collect()[0][0], df1_row_count) - self.assertEqual(len(ds0.df.collect()), df1_row_count) - if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_object_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_object_test.py index b4615d85..07fa4732 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_object_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_object_test.py @@ -2,11 +2,13 @@ from absl.testing import absltest +from snowflake.ml._internal.exceptions import exceptions as snowml_exceptions from snowflake.ml.feature_store import ( # type: ignore[attr-defined] Entity, FeatureView, FeatureViewSlice, FeatureViewStatus, + FeatureViewVersion, ) from snowflake.ml.feature_store.feature_view import ( _FEATURE_OBJ_TYPE, @@ -113,6 +115,43 @@ def test_feature_view_slice_serde(self) -> None: with self.assertRaisesRegex(ValueError, "Invalid json str for FeatureViewSlice.*"): FeatureViewSlice.from_json(malformed, self._session) + def test_feature_view_versions(self) -> None: + valid_versions = [ + "v2", # start with letter + "3x", # start with digit + "1", # single digit + "2.1", # digit with period + "3_1", # digit with underscore + "4-1", # digit with hyphen + "4-1_2.3", # digit with period, underscore and hyphen + "x", # single letter + "4x_1", # digit, letter and underscore + "latest", # pure lowercase letters + "OLD", # pure uppercase letters + "OLap", # pure uppercase letters + "a" * 128, # within maximum allowed length + ] + + invalid_dataset_versions = [ + "", # empty + "_v1", # start with underscore + ".2", # start with period + "3/1", # digit with slash + "-4", # start with hyphen + "v1$", # start with letter, contains invalid character + "9^", # start with digit, contains invalid character + "a" * 129, # exceed maximum allowed length + ] + + for version in valid_versions: + FeatureViewVersion(version) + + for version in invalid_dataset_versions: + with self.assertRaisesRegex( + snowml_exceptions.SnowflakeMLException, ".* is not a valid feature view version.*" + ): + FeatureViewVersion(version) + class EntityTest(absltest.TestCase): def test_invalid_entity_name(self) -> None: diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_test.py index c3bc0e09..5f5f1819 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_test.py @@ -2,27 +2,26 @@ from typing import List, Optional, Tuple, Union, cast from uuid import uuid4 -from absl.testing import absltest +from absl.testing import absltest, parameterized from common_utils import ( FS_INTEG_TEST_DATASET_SCHEMA, FS_INTEG_TEST_DB, FS_INTEG_TEST_DUMMY_DB, cleanup_temporary_objects, compare_dataframe, - compare_feature_views, create_mock_session, create_random_schema, get_test_warehouse_name, ) +from snowflake.ml.version import VERSION +from snowflake.ml import dataset from snowflake.ml._internal.utils.sql_identifier import SqlIdentifier -from snowflake.ml.dataset.dataset import Dataset from snowflake.ml.feature_store.entity import Entity from snowflake.ml.feature_store.feature_store import ( _ENTITY_TAG_PREFIX, _FEATURE_STORE_OBJECT_TAG, - _FEATURE_VIEW_ENTITY_TAG, - _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_VIEW_METADATA_TAG, CreationMode, FeatureStore, ) @@ -32,11 +31,11 @@ FeatureViewStatus, ) from 
snowflake.ml.utils.connection_params import SnowflakeLoginOptions -from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions +from snowflake.snowpark import Session, exceptions as snowpark_exceptions from snowflake.snowpark.functions import call_udf, col, udf -class FeatureStoreTest(absltest.TestCase): +class FeatureStoreTest(parameterized.TestCase): @classmethod def setUpClass(self) -> None: self._session = Session.builder.configs(SnowflakeLoginOptions()).create() @@ -85,7 +84,7 @@ def _create_mock_table(self, name: str) -> str: ).collect() return table_full_path - def _create_feature_store(self, name: Optional[str] = None) -> FeatureStore: + def _create_feature_store(self, name: Optional[str] = None, use_optimized_tag_ref: bool = False) -> FeatureStore: current_schema = create_random_schema(self._session, "FS_TEST") if name is None else name fs = FeatureStore( self._session, @@ -98,6 +97,7 @@ def _create_feature_store(self, name: Optional[str] = None) -> FeatureStore: # Intentionally point session to a different database to make sure feature store code is resilient to # session location. self._session.use_database(FS_INTEG_TEST_DUMMY_DB) + fs._use_optimized_tag_ref = use_optimized_tag_ref return fs def _check_tag_value( @@ -227,13 +227,14 @@ def test_clear_feature_store_system_error(self) -> None: fs.clear() fs._session = original_session - def test_create_and_delete_entities(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_create_and_delete_entities(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) entities = { "User": Entity("USER", ['"uid"']), "Ad": Entity('"aD"', ["aid"]), - "Product": Entity("Product", ["pid", "cid"]), + "Product": Entity("Product", ['"pid"', "cid"]), } # create new entities @@ -247,7 +248,7 @@ def test_create_and_delete_entities(self) -> None: actual_df=actual_result.drop(columns="OWNER"), target_data={ "NAME": ["aD", "PRODUCT", "USER"], - "JOIN_KEYS": ['["AID"]', '["CID","PID"]', '["uid"]'], + "JOIN_KEYS": ['["AID"]', '["pid,CID"]', '["uid"]'], "DESC": ["", "", ""], }, sort_cols=["NAME"], @@ -271,7 +272,7 @@ def test_create_and_delete_entities(self) -> None: actual_df=actual_result.drop(columns="OWNER"), target_data={ "NAME": ["PRODUCT", "USER"], - "JOIN_KEYS": ['["CID","PID"]', '["uid"]'], + "JOIN_KEYS": ['["pid,CID"]', '["uid"]'], "DESC": ["", ""], }, sort_cols=["NAME"], @@ -286,10 +287,10 @@ def test_create_and_delete_entities(self) -> None: # test delete entity failure with active feature views # create a new feature view - sql = f'SELECT name, id AS "uid" FROM {self._mock_table}' + sql = f'SELECT name, id AS "uid", id AS CID, id AS "pid" FROM {self._mock_table}' fv = FeatureView( name="fv", - entities=[entities["User"]], + entities=[entities["User"], entities["Product"]], feature_df=self._session.sql(sql), refresh_freq="1m", ) @@ -297,8 +298,9 @@ def test_create_and_delete_entities(self) -> None: with self.assertRaisesRegex(ValueError, "Cannot delete Entity .* due to active FeatureViews.*"): fs.delete_entity("User") - def test_retrieve_entity(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_retrieve_entity(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e1 = Entity(name="foo", join_keys=["a", "b"], desc="my foo") e2 = Entity(name="bar", 
join_keys=["c"]) @@ -321,7 +323,7 @@ def test_retrieve_entity(self) -> None: actual_df=actual_result.drop(columns="OWNER"), target_data={ "NAME": ["FOO", "BAR"], - "JOIN_KEYS": ['["A","B"]', '["C"]'], + "JOIN_KEYS": ['["A,B"]', '["C"]'], "DESC": ["my foo", ""], }, sort_cols=["NAME"], @@ -348,8 +350,9 @@ def test_register_entity_system_error(self) -> None: with self.assertRaisesRegex(RuntimeError, "Failed to find object .*"): fs.register_entity(e) - def test_register_feature_view_with_unregistered_entity(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_register_feature_view_with_unregistered_entity(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e = Entity("foo", ["id"]) @@ -364,7 +367,8 @@ def test_register_feature_view_with_unregistered_entity(self) -> None: with self.assertRaisesRegex(ValueError, "Entity .* has not been registered."): fs.register_feature_view(feature_view=fv, version="v1") - def test_register_feature_view_as_view(self) -> None: + @parameterized.parameters(True, False) # type: ignore[misc] + def test_register_feature_view_as_view(self, use_optimized_tag_ref: bool) -> None: """ APIs covered by test: 1. register_feature_view @@ -375,7 +379,7 @@ def test_register_feature_view_as_view(self) -> None: 6. generate_dataset (covers retrieve_feature_values) """ - fs = self._create_feature_store() + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e = Entity("foo", ["id"]) fs.register_entity(e) @@ -388,11 +392,17 @@ def test_register_feature_view_as_view(self) -> None: timestamp_col="ts", desc="foobar", ).attach_feature_desc({"AGE": "my age", "TITLE": '"my title"'}) - fv = fs.register_feature_view(feature_view=fv, version="v1") + fv = fs.register_feature_view(feature_view=fv, version="2.0") - self._check_tag_value(fs, fv.fully_qualified_name(), "table", _FEATURE_STORE_OBJECT_TAG, "FEATURE_VIEW") + self._check_tag_value( + fs, + fv.fully_qualified_name(), + "table", + _FEATURE_STORE_OBJECT_TAG, + f"""{{"type": "EXTERNAL_FEATURE_VIEW", "pkg_version": "{VERSION}"}}""", + ) - self.assertEqual(fv, fs.get_feature_view("fv", "v1")) + self.assertEqual(fv, fs.get_feature_view("fv", "2.0")) compare_dataframe( actual_df=fs.read_feature_view(fv).to_pandas(), @@ -405,8 +415,6 @@ def test_register_feature_view_as_view(self) -> None: }, sort_cols=["ID", "TS"], ) - compare_feature_views(fs.list_feature_views(as_dataframe=False), [fv]) - compare_feature_views(fs.list_feature_views(entity_name="FOO", as_dataframe=False), [fv]) # create another feature view new_fv = FeatureView( @@ -417,20 +425,37 @@ def test_register_feature_view_as_view(self) -> None: desc="foobar", ) new_fv = fs.register_feature_view(feature_view=new_fv, version="V1") - compare_feature_views(fs.list_feature_views(as_dataframe=False), [fv, new_fv]) + + compare_dataframe( + actual_df=fs.list_feature_views(entity_name="FOO").to_pandas(), + target_data={ + "NAME": ["FV", "NEW_FV"], + "VERSION": ["2.0", "V1"], + "DATABASE_NAME": [fs._config.database] * 2, + "SCHEMA_NAME": [fs._config.schema] * 2, + "DESC": ["foobar", "foobar"], + "ENTITIES": ['[\n "FOO"\n]'] * 2, + }, + sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) # generate data on multiple feature views spine_df = self._session.create_dataframe([(1, 101)], schema=["id", "ts"]) ds = fs.generate_dataset( - spine_df=spine_df, features=[fv, new_fv], spine_timestamp_col="ts", 
include_feature_view_timestamp_col=True + spine_df=spine_df, + features=[fv, new_fv], + spine_timestamp_col="ts", + include_feature_view_timestamp_col=True, + name="test_ds", ) compare_dataframe( - actual_df=ds.df.to_pandas(), + actual_df=ds.read.to_pandas(), target_data={ "ID": [1], "TS": [101], - "FV_V1_TS": [100], + "FV_2.0_TS": [100], "NAME": ["jonh"], "TITLE": ["boss"], "AGE": [20], @@ -471,8 +496,9 @@ def test_register_feature_view_system_error(self) -> None: with self.assertRaisesRegex(RuntimeError, "(?s)Create dynamic table .* failed.*"): fs.register_feature_view(feature_view=fv2, version="v2") - def test_create_and_delete_feature_views(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_create_and_delete_feature_views(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e1 = Entity("foo", ["aid"]) e2 = Entity("bar", ["uid"]) @@ -501,7 +527,13 @@ def test_create_and_delete_feature_views(self) -> None: ) # fv0.status == FeatureViewStatus.RUNNING can be removed after BCR 2024_02 gets fully deployed self.assertEqual(fv0.refresh_freq, "1 minute") self.assertEqual(fv0, fs.get_feature_view("fv0", "FIRST")) - self._check_tag_value(fs, fv0.fully_qualified_name(), "table", _FEATURE_STORE_OBJECT_TAG, "FEATURE_VIEW") + self._check_tag_value( + fs, + fv0.fully_qualified_name(), + "table", + _FEATURE_STORE_OBJECT_TAG, + f"""{{"type": "MANAGED_FEATURE_VIEW", "pkg_version": "{VERSION}"}}""", + ) # suspend feature view fv0 = fs.suspend_feature_view(fv0) @@ -533,7 +565,19 @@ def test_create_and_delete_feature_views(self) -> None: fs.update_default_warehouse(alternate_warehouse) fv1 = fs.register_feature_view(feature_view=fv1, version="FIRST") - compare_feature_views(fs.list_feature_views(as_dataframe=False), [fv0, new_fv0, fv1]) + compare_dataframe( + actual_df=fs.list_feature_views().to_pandas(), + target_data={ + "NAME": ["FV0", "FV0", "FV1"], + "VERSION": ["FIRST", "SECOND", "FIRST"], + "DATABASE_NAME": [fs._config.database] * 3, + "SCHEMA_NAME": [fs._config.schema] * 3, + "DESC": ["my_fv0", "my_new_fv0", "my_fv1"], + "ENTITIES": ['[\n "FOO",\n "BAR"\n]'] * 3, + }, + sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) # delete feature view with self.assertRaisesRegex(ValueError, "FeatureView .* has not been registered."): @@ -542,7 +586,19 @@ def test_create_and_delete_feature_views(self) -> None: fs.delete_feature_view(fs.get_feature_view("FV0", "FIRST")) - compare_feature_views(fs.list_feature_views(as_dataframe=False), [new_fv0, fv1]) + compare_dataframe( + actual_df=fs.list_feature_views().to_pandas(), + target_data={ + "NAME": ["FV0", "FV1"], + "VERSION": ["SECOND", "FIRST"], + "DATABASE_NAME": [fs._config.database] * 2, + "SCHEMA_NAME": [fs._config.schema] * 2, + "DESC": ["my_new_fv0", "my_fv1"], + "ENTITIES": ['[\n "FOO",\n "BAR"\n]'] * 2, + }, + sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) # test get feature view obj fv = fs.get_feature_view(name="fv1", version="FIRST") @@ -560,8 +616,9 @@ def test_create_and_delete_feature_views(self) -> None: fv = fs.get_feature_view(name="fv0", version="SECOND") self.assertEqual(str(fv.timestamp_col).upper(), "TS") - def test_create_duplicated_feature_view(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_create_duplicated_feature_view(self, use_optimized_tag_ref: bool) -> None: + fs = 
self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e = Entity("foo", ["id"]) fs.register_entity(e) @@ -573,10 +630,10 @@ def test_create_duplicated_feature_view(self) -> None: feature_df=self._session.sql(sql), refresh_freq="1m", ) - fv = fs.register_feature_view(feature_view=fv, version="v1") + fv = fs.register_feature_view(feature_view=fv, version="r1") - with self.assertWarnsRegex(UserWarning, "FeatureView FV/V1 already exists. Skip registration. .*"): - fv = fs.register_feature_view(feature_view=fv, version="v1") + with self.assertWarnsRegex(UserWarning, "FeatureView FV/r1 already exists. Skip registration. .*"): + fv = fs.register_feature_view(feature_view=fv, version="r1") self.assertIsNotNone(fv) fv = FeatureView( @@ -586,7 +643,7 @@ def test_create_duplicated_feature_view(self) -> None: refresh_freq="1m", ) with self.assertWarnsRegex(UserWarning, "FeatureView .* already exists..*"): - fv = fs.register_feature_view(feature_view=fv, version="v1") + fv = fs.register_feature_view(feature_view=fv, version="r1") def test_resume_and_suspend_feature_view(self) -> None: fs = self._create_feature_store() @@ -641,8 +698,9 @@ def test_resume_and_suspend_feature_view_system_error(self) -> None: fs._session = original_session - def test_read_feature_view(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_read_feature_view(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e = Entity("foo", ["id"]) fs.register_entity(e) @@ -669,8 +727,9 @@ def test_read_feature_view(self) -> None: sort_cols=["NAME"], ) - def test_register_with_cron_expr(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_register_with_cron_expr(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e = Entity("foo", ["id"]) fs.register_entity(e) @@ -686,13 +745,17 @@ def test_register_with_cron_expr(self) -> None: fv = fs.get_feature_view("my_fv", "v1") self.assertEqual(my_fv, fv) - task_name = FeatureView._get_physical_name(fv.name, fv.version) # type: ignore[arg-type] + task_name = FeatureView._get_physical_name(fv.name, fv.version).resolved() # type: ignore[arg-type] res = self._session.sql(f"SHOW TASKS LIKE '{task_name}' IN SCHEMA {fs._config.full_schema_path}").collect() self.assertEqual(len(res), 1) self.assertEqual(res[0]["state"], "started") self.assertEqual(fv.refresh_freq, "DOWNSTREAM") self._check_tag_value( - fs, fv.fully_qualified_name(), "task", _FEATURE_STORE_OBJECT_TAG, "FEATURE_VIEW_REFRESH_TASK" + fs, + fv.fully_qualified_name(), + "task", + _FEATURE_STORE_OBJECT_TAG, + f"""{{"type": "FEATURE_VIEW_REFRESH_TASK", "pkg_version": "{VERSION}"}}""", ) fv = fs.suspend_feature_view(fv) @@ -707,8 +770,9 @@ def test_register_with_cron_expr(self) -> None: res = self._session.sql(f"SHOW TASKS LIKE '{task_name}' IN SCHEMA {fs._config.full_schema_path}").collect() self.assertEqual(len(res), 0) - def test_retrieve_time_series_feature_values(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_retrieve_time_series_feature_values(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) e = Entity("foo", ["id"]) fs.register_entity(e) @@ -721,7 +785,7 @@ def 
test_retrieve_time_series_feature_values(self) -> None: timestamp_col="ts", refresh_freq="DOWNSTREAM", ) - fv1 = fs.register_feature_view(feature_view=fv1, version="v1") + fv1 = fs.register_feature_view(feature_view=fv1, version="1.0") sql2 = f"SELECT id, age, ts FROM {self._mock_table}" fv2 = FeatureView( @@ -731,7 +795,7 @@ def test_retrieve_time_series_feature_values(self) -> None: timestamp_col="ts", refresh_freq="DOWNSTREAM", ) - fv2 = fs.register_feature_view(feature_view=fv2, version="v1") + fv2 = fs.register_feature_view(feature_view=fv2, version="1.0") sql3 = f"SELECT id, dept FROM {self._mock_table}" fv3 = FeatureView( @@ -740,7 +804,7 @@ def test_retrieve_time_series_feature_values(self) -> None: feature_df=self._session.sql(sql3), refresh_freq="DOWNSTREAM", ) - fv3 = fs.register_feature_view(feature_view=fv3, version="v1") + fv3 = fs.register_feature_view(feature_view=fv3, version="1.0") spine_df = self._session.create_dataframe([(1, 101), (2, 202), (1, 90)], schema=["id", "ts"]) df = fs.retrieve_feature_values( @@ -755,10 +819,10 @@ def test_retrieve_time_series_feature_values(self) -> None: target_data={ "ID": [1, 1, 2], "TS": [90, 101, 202], - "FV1_V1_TS": [None, 100, 200], + "FV1_1.0_TS": [None, 100, 200], "NAME": [None, "jonh", "porter"], "TITLE": [None, "boss", "manager"], - "FV2_V1_TS": [None, 100, 200], + "FV2_1.0_TS": [None, 100, 200], "AGE": [None, 20, 30], "DEPT": ["sales", "sales", "engineer"], }, @@ -836,8 +900,9 @@ def test_retrieve_feature_values(self) -> None: # test retrieve_feature_values with serialized feature objects fv1_slice = fv1.slice(["name"]) - dataset = fs.generate_dataset(spine_df, features=[fv1_slice, fv2]) - df = fs.retrieve_feature_values(spine_df=spine_df, features=cast(List[str], dataset.load_features())) + dataset = fs.generate_dataset(spine_df=spine_df, features=[fv1_slice, fv2], name="test_ds") + features = fs.load_feature_views_from_dataset(dataset) + df = fs.retrieve_feature_values(spine_df=spine_df, features=features) compare_dataframe( actual_df=df.to_pandas(), target_data={ @@ -851,77 +916,155 @@ def test_retrieve_feature_values(self) -> None: def test_invalid_load_feature_views_from_dataset(self) -> None: fs = self._create_feature_store() - dataset = Dataset(self._session, self._session.create_dataframe([1, 2, 3], schema=["foo"])) + ds = dataset.create_from_dataframe( + self._session, + "test_ds", + uuid4().hex, + input_dataframe=self._session.create_dataframe([1, 2, 3], schema=["foo"]), + ) with self.assertRaisesRegex(ValueError, "Dataset.*does not contain valid feature view information."): - fs.load_feature_views_from_dataset(dataset) + fs.load_feature_views_from_dataset(ds) - def test_list_feature_views(self) -> None: - fs = self._create_feature_store() + @parameterized.parameters(True, False) # type: ignore[misc] + def test_list_feature_views(self, use_optimized_tag_ref: bool) -> None: + fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) - e = Entity("foo", ["id"]) - fs.register_entity(e) + e1 = Entity("foo", ["id"]) + fs.register_entity(e1) + e2 = Entity("bar", ["name"]) + fs.register_entity(e2) - self.assertEqual(fs.list_feature_views(entity_name="foo", as_dataframe=False), []) + compare_dataframe( + actual_df=fs.list_feature_views(entity_name="FOO").to_pandas(), + target_data={ + "NAME": [], + "VERSION": [], + "DATABASE_NAME": [], + "SCHEMA_NAME": [], + "CREATED_ON": [], + "OWNER": [], + "DESC": [], + "ENTITIES": [], + }, + sort_cols=["NAME"], + ) # 1. 
Right side is FeatureViewSlice sql1 = f"SELECT id, name, ts FROM {self._mock_table}" fv1 = FeatureView( name="fv1", - entities=[e], + entities=[e1], feature_df=self._session.sql(sql1), timestamp_col="ts", refresh_freq="DOWNSTREAM", ) fv1.attach_feature_desc({"name": "this is my name col"}) - fv1 = fs.register_feature_view(feature_view=fv1, version="v1") + fs.register_feature_view(feature_view=fv1, version="v1") - sql2 = f"SELECT id, title, age FROM {self._mock_table}" + sql2 = f"SELECT id, name, title, age FROM {self._mock_table}" fv2 = FeatureView( name="fv2", - entities=[e], + entities=[e2], feature_df=self._session.sql(sql2), refresh_freq="DOWNSTREAM", + desc="foobar", ) - fv2 = fs.register_feature_view(feature_view=fv2, version="v1") - self.assertEqual( - sorted( - cast( - List[FeatureView], - fs.list_feature_views(entity_name="Foo", as_dataframe=False), - ), - key=lambda fv: fv.name, - ), - [fv1, fv2], - ) - self.assertEqual( - fs.list_feature_views(entity_name="foo", feature_view_name="fv1", as_dataframe=False), - [fv1], - ) - - df = cast(DataFrame, fs.list_feature_views()) - self.assertListEqual( - df.columns, - [ - "NAME", - "ENTITIES", - "TIMESTAMP_COL", - "DESC", - "QUERY", - "VERSION", - "STATUS", - "FEATURE_DESC", - "REFRESH_FREQ", - "DATABASE", - "SCHEMA", - "WAREHOUSE", - "REFRESH_MODE", - "REFRESH_MODE_REASON", - "OWNER", - "PHYSICAL_NAME", - ], + fs.register_feature_view(feature_view=fv2, version="v1") + + fv3 = FeatureView( + name="fv3", + entities=[e1, e2], + feature_df=self._session.sql(sql2), + refresh_freq="DOWNSTREAM", + desc="foobar", + ) + fs.register_feature_view(feature_view=fv3, version="v1") + + compare_dataframe( + actual_df=fs.list_feature_views().to_pandas(), + target_data={ + "NAME": ["FV1", "FV2", "FV3"], + "VERSION": ["v1", "v1", "v1"], + "DATABASE_NAME": [fs._config.database] * 3, + "SCHEMA_NAME": [fs._config.schema] * 3, + "DESC": ["", "foobar", "foobar"], + "ENTITIES": ['[\n "FOO"\n]', '[\n "BAR"\n]', '[\n "FOO",\n "BAR"\n]'], + }, + sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) + + compare_dataframe( + actual_df=fs.list_feature_views(entity_name="FOO").to_pandas(), + target_data={ + "NAME": ["FV1", "FV3"], + "VERSION": ["v1", "v1"], + "DATABASE_NAME": [fs._config.database, fs._config.database], + "SCHEMA_NAME": [fs._config.schema, fs._config.schema], + "DESC": ["", "foobar"], + "ENTITIES": ['[\n "FOO"\n]', '[\n "FOO",\n "BAR"\n]'], + }, + sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) + + compare_dataframe( + actual_df=fs.list_feature_views(feature_view_name="FV2").to_pandas(), + target_data={ + "NAME": ["FV2"], + "VERSION": ["v1"], + "DATABASE_NAME": [fs._config.database], + "SCHEMA_NAME": [fs._config.schema], + "DESC": ["foobar"], + "ENTITIES": ['[\n "BAR"\n]'], + }, + sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) + + compare_dataframe( + actual_df=fs.list_feature_views(entity_name="BAR", feature_view_name="FV2").to_pandas(), + target_data={ + "NAME": ["FV2"], + "VERSION": ["v1"], + "DATABASE_NAME": [fs._config.database], + "SCHEMA_NAME": [fs._config.schema], + "DESC": ["foobar"], + "ENTITIES": ['[\n "BAR"\n]'], + }, + sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) + + compare_dataframe( + actual_df=fs.list_feature_views(entity_name="FOO", feature_view_name="FV3").to_pandas(), + target_data={ + "NAME": ["FV3"], + "VERSION": ["v1"], + "DATABASE_NAME": [fs._config.database], + "SCHEMA_NAME": [fs._config.schema], + "DESC": ["foobar"], + "ENTITIES": ['[\n "FOO",\n "BAR"\n]'], + }, + 
sort_cols=["NAME"], + exclude_cols=["CREATED_ON", "OWNER"], + ) + + compare_dataframe( + actual_df=fs.list_feature_views(entity_name="BAR", feature_view_name="BAZ").to_pandas(), + target_data={ + "NAME": [], + "VERSION": [], + "DATABASE_NAME": [], + "SCHEMA_NAME": [], + "CREATED_ON": [], + "OWNER": [], + "DESC": [], + "ENTITIES": [], + }, + sort_cols=["NAME"], ) - result = df.collect() - self.assertEqual(len(result), 2) + fs._check_feature_store_object_versions() def test_list_feature_views_system_error(self) -> None: fs = self._create_feature_store() @@ -952,30 +1095,6 @@ def test_list_feature_views_system_error(self) -> None: with self.assertRaisesRegex(RuntimeError, "Failed to find object"): fs.list_feature_views(entity_name="foo") - def test_create_and_cleanup_tags(self) -> None: - current_schema = create_random_schema(self._session, "TEST_CREATE_AND_CLEANUP_TAGS") - fs = FeatureStore( - self._session, - FS_INTEG_TEST_DB, - current_schema, - default_warehouse=self._test_warehouse_name, - creation_mode=CreationMode.CREATE_IF_NOT_EXIST, - ) - self.assertIsNotNone(fs) - - res = self._session.sql( - f"SHOW TAGS LIKE '{_FEATURE_VIEW_ENTITY_TAG}' IN SCHEMA {fs._config.full_schema_path}" - ).collect() - self.assertEqual(len(res), 1) - - self._session.sql(f"DROP SCHEMA IF EXISTS {FS_INTEG_TEST_DB}.{current_schema}").collect() - - row_list = self._session.sql( - f"SHOW TAGS LIKE '{_FEATURE_VIEW_ENTITY_TAG}' IN DATABASE {fs._config.database}" - ).collect() - for row in row_list: - self.assertNotEqual(row["schema_name"], current_schema) - def test_generate_dataset(self) -> None: fs = self._create_feature_store() @@ -1002,16 +1121,17 @@ def test_generate_dataset(self) -> None: fv2 = fs.register_feature_view(feature_view=fv2, version="v1") spine_df = self._session.create_dataframe([(1, 100), (1, 101)], schema=["id", "ts"]) - # Generate dataset the first time + # Generate dataset ds1 = fs.generate_dataset( spine_df=spine_df, features=[fv1, fv2], - materialized_table="foobar", + name="foobar", + version="test", spine_timestamp_col="ts", ) compare_dataframe( - actual_df=ds1.df.to_pandas(), + actual_df=ds1.read.to_pandas(), target_data={ "ID": [1, 1], "TS": [100, 101], @@ -1023,85 +1143,18 @@ def test_generate_dataset(self) -> None: ) self.assertEqual([fv1, fv2], fs.load_feature_views_from_dataset(ds1)) - # Re-generate dataset with same source should not cause any duplication - ds2 = fs.generate_dataset( - spine_df=spine_df, - features=[fv1, fv2], - materialized_table="foobar", - spine_timestamp_col="ts", - save_mode="merge", - ) - - compare_dataframe( - actual_df=ds2.df.to_pandas(), - target_data={ - "ID": [1, 1], - "TS": [100, 101], - "NAME": ["jonh", "jonh"], - "TITLE": ["boss", "boss"], - "AGE": [20, 20], - }, - sort_cols=["ID", "TS"], - ) - - # New data should properly appear - spine_df = self._session.create_dataframe([(2, 202)], schema=["id", "ts"]) - ds3 = fs.generate_dataset( - spine_df=spine_df, - features=[fv1, fv2], - materialized_table="foobar", - spine_timestamp_col="ts", - save_mode="merge", - ) - - compare_dataframe( - actual_df=ds3.df.to_pandas(), - target_data={ - "ID": [1, 1, 2], - "TS": [100, 101, 202], - "NAME": ["jonh", "jonh", "porter"], - "TITLE": ["boss", "boss", "manager"], - "AGE": [20, 20, 30], - }, - sort_cols=["ID", "TS"], - ) - - # Snapshot should remain the same - compare_dataframe( - actual_df=self._session.sql(f"SELECT * FROM {ds1.snapshot_table}").to_pandas(), - target_data={ - "ID": [1, 1], - "TS": [100, 101], - "NAME": ["jonh", "jonh"], - "TITLE": ["boss", "boss"], 
- "AGE": [20, 20], - }, - sort_cols=["ID", "TS"], - ) - compare_dataframe( - actual_df=self._session.sql(f"SELECT * FROM {ds3.snapshot_table}").to_pandas(), - target_data={ - "ID": [1, 1, 2], - "TS": [100, 101, 202], - "NAME": ["jonh", "jonh", "porter"], - "TITLE": ["boss", "boss", "manager"], - "AGE": [20, 20, 30], - }, - sort_cols=["ID", "TS"], - ) - # Generate dataset with exclude_columns and check both materialization and non-materialization path spine_df = self._session.create_dataframe([(1, 101), (2, 202)], schema=["id", "ts"]) ds4 = fs.generate_dataset( spine_df=spine_df, features=[fv1, fv2], - materialized_table="foobar2", + name="foobar2", spine_timestamp_col="ts", exclude_columns=["id", "ts"], ) compare_dataframe( - actual_df=ds4.df.to_pandas(), + actual_df=ds4.read.to_pandas(), target_data={ "NAME": ["jonh", "porter"], "TITLE": ["boss", "manager"], @@ -1115,9 +1168,10 @@ def test_generate_dataset(self) -> None: features=[fv1, fv2], spine_timestamp_col="ts", exclude_columns=["id", "ts"], + name="test_ds", ) compare_dataframe( - actual_df=ds5.df.to_pandas(), + actual_df=ds5.read.to_pandas(), target_data={ "NAME": ["jonh", "porter"], "TITLE": ["boss", "manager"], @@ -1127,23 +1181,22 @@ def test_generate_dataset(self) -> None: ) # Generate data should fail with errorifexists if table already exist - with self.assertRaisesRegex(ValueError, "Dataset table .* already exists."): + with self.assertRaisesRegex(ValueError, "already exists"): fs.generate_dataset( spine_df=spine_df, features=[fv1, fv2], - materialized_table="foobar", + name="foobar", + version="test", spine_timestamp_col="ts", - save_mode="errorifexists", ) - # registered table should fail with invalid char `.` - with self.assertRaisesRegex(ValueError, "materialized_table .* contains invalid char `.`"): + # Invalid dataset names should be rejected + with self.assertRaisesRegex(ValueError, "Invalid identifier"): fs.generate_dataset( spine_df=spine_df, features=[fv1, fv2], - materialized_table="foo.bar", + name=".bar", spine_timestamp_col="ts", - save_mode="errorifexists", ) # invalid columns in exclude_columns should fail @@ -1151,11 +1204,64 @@ def test_generate_dataset(self) -> None: fs.generate_dataset( spine_df=spine_df, features=[fv1, fv2], - materialized_table="foobar3", + name="foobar3", spine_timestamp_col="ts", exclude_columns=["foo"], ) + def test_generate_dataset_external_schema(self) -> None: + database_name = self._session.get_current_database() + schema_name = create_random_schema(self._session, "FS_TEST_EXTERNAL_SCHEMA", database=database_name) + fs = self._create_feature_store() + self.assertNotEqual(fs._config.schema, schema_name) + + e = Entity('"fOO"', ["id"]) + fs.register_entity(e) + + sql1 = f"SELECT id, name, title FROM {self._mock_table}" + fv1 = FeatureView( + name="fv1", + entities=[e], + feature_df=self._session.sql(sql1), + refresh_freq="DOWNSTREAM", + ) + fv1 = fs.register_feature_view(feature_view=fv1, version="v1", block=True) + + # Generate dataset on external schema + spine_df = self._session.create_dataframe([(1, 100), (1, 101)], schema=["id", "ts"]) + ds_name = "dataset_external_schema" + ds = fs.generate_dataset( + spine_df=spine_df, + features=[fv1], + name=f"{database_name}.{schema_name}.{ds_name}", + spine_timestamp_col="ts", + ) + + # Generated dataset should be in external schema + self.assertGreater(len(ds.read.files()), 0) + for file in ds.read.files(): + self.assertContainsExactSubsequence(file, f"{database_name}.{schema_name}.{ds_name}") + + ds_df = ds.read.to_snowpark_dataframe() 
+ compare_dataframe( + actual_df=ds_df.to_pandas(), + target_data={ + "ID": [1, 1], + "TS": [100, 101], + "NAME": ["jonh", "jonh"], + "TITLE": ["boss", "boss"], + }, + sort_cols=["ID", "TS"], + ) + + # Fail on non-existent schema + with self.assertRaisesRegex(RuntimeError, "does not exist"): + fs.generate_dataset( + spine_df=spine_df, + features=[fv1], + name="NONEXISTENT_SCHEMA.foobar", + ) + def test_clear_feature_store_in_existing_schema(self) -> None: current_schema = create_random_schema(self._session, "TEST_CLEAR_FEATURE_STORE_IN_EXISTING_SCHEMA") @@ -1191,21 +1297,19 @@ def test_clear_feature_store_in_existing_schema(self) -> None: fs.generate_dataset( spine_df=spine_df, features=[fv], - materialized_table="foo_mt", + name="foo_mt", spine_timestamp_col="ts", - save_mode="errorifexists", ) def check_fs_objects(expected_count: int) -> None: result = self._session.sql(f"SHOW DYNAMIC TABLES LIKE 'FV$V1' IN SCHEMA {full_schema_path}").collect() self.assertEqual(len(result), expected_count) - result = self._session.sql(f"SHOW TABLES LIKE 'foo_mt' IN SCHEMA {full_schema_path}").collect() + result = self._session.sql(f"SHOW DATASETS LIKE 'foo_mt' IN SCHEMA {full_schema_path}").collect() self.assertEqual(len(result), expected_count) result = self._session.sql(f"SHOW TASKS LIKE 'FV$V1' IN SCHEMA {full_schema_path}").collect() self.assertEqual(len(result), expected_count) expected_tags = [ - _FEATURE_VIEW_ENTITY_TAG, - _FEATURE_VIEW_TS_COL_TAG, + _FEATURE_VIEW_METADATA_TAG, _FEATURE_STORE_OBJECT_TAG, f"{_ENTITY_TAG_PREFIX}foo", ] @@ -1380,7 +1484,6 @@ def create_fvs(fs: FeatureStore, sql: str, overwrite: bool) -> Tuple[FeatureView }, sort_cols=["ID"], ) - compare_feature_views(fs.list_feature_views(as_dataframe=False), [fv1, fv2, fv3]) # Replace existing feature views sql = f"SELECT id, name, title FROM {self._mock_table}" @@ -1395,7 +1498,6 @@ def create_fvs(fs: FeatureStore, sql: str, overwrite: bool) -> Tuple[FeatureView }, sort_cols=["ID"], ) - compare_feature_views(fs.list_feature_views(as_dataframe=False), [fv1, fv2, fv3]) # Replace non-existing feature view non_existing_fv = FeatureView( @@ -1458,7 +1560,7 @@ def test_generate_dataset_point_in_time_join(self) -> None: refresh_freq=None, ) - customers_fv = fs.register_feature_view(feature_view=customers_fv, version="v1") + customers_fv = fs.register_feature_view(feature_view=customers_fv, version="V1") spine_df = self._session.create_dataframe( [ @@ -1474,12 +1576,12 @@ def test_generate_dataset_point_in_time_join(self) -> None: dataset = fs.generate_dataset( spine_df=spine_df, features=[customers_fv], - materialized_table="customer_frad_training_data", + name="customer_frad_training_data", spine_timestamp_col="EVENT_TS", spine_label_cols=[], include_feature_view_timestamp_col=True, ) - actual_df = dataset.df.to_pandas() + actual_df = dataset.read.to_pandas() actual_df["CUSTOMER_FV_V1_FEATURE_TS"] = actual_df["CUSTOMER_FV_V1_FEATURE_TS"].dt.date # CUST_AVG_AMOUNT_7 and CUST_AVG_AMOUNT_30 are expected to be same as the values @@ -1502,9 +1604,10 @@ def test_generate_dataset_point_in_time_join(self) -> None: sort_cols=["CUSTOMER_ID"], ) - def test_cross_feature_store_interop(self) -> None: + @parameterized.parameters(True, False) # type: ignore[misc] + def test_cross_feature_store_interop(self, use_optimized_tag_ref: bool) -> None: # create first feature store and register feature views - first_fs = self._create_feature_store() + first_fs = self._create_feature_store(use_optimized_tag_ref=use_optimized_tag_ref) first_entity = Entity("foo", 
["id"]) first_fs.register_entity(first_entity) @@ -1541,9 +1644,10 @@ def test_cross_feature_store_interop(self) -> None: spine_df=spine_df, features=[first_fv, second_fv], spine_timestamp_col="ts", + name="test_ds", ) compare_dataframe( - actual_df=ds.df.to_pandas(), + actual_df=ds.read.to_pandas(), target_data={ "ID": [1], "TS": [101], @@ -1587,10 +1691,11 @@ def test_generate_dataset_left_join(self) -> None: ds = fs.generate_dataset( spine_df=spine_df, features=[fv1, fv2], + name="test_ds", ) compare_dataframe( - actual_df=ds.df.to_pandas(), + actual_df=ds.read.to_pandas(), target_data={ "ID": [1, 2, 3], "NAME": ["jonh", "porter", "johnny"], diff --git a/tests/integ/snowflake/ml/fileset/snowfs_integ_test.py b/tests/integ/snowflake/ml/fileset/snowfs_integ_test.py index c120c5bb..6d2a03fa 100644 --- a/tests/integ/snowflake/ml/fileset/snowfs_integ_test.py +++ b/tests/integ/snowflake/ml/fileset/snowfs_integ_test.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + import fsspec from absl.testing import absltest @@ -17,7 +19,7 @@ class TestSnowFileSystem(absltest.TestCase): schema = snowpark_session.get_current_schema() domain = "dataset" - entity = f"{db}.{schema}.snowfs_fbe_integ" + entity = f"{db}.{schema}.snowfs_fbe_integ_{uuid4().hex}" version1 = "version1" version2 = "version2" row_counts = { @@ -27,7 +29,6 @@ class TestSnowFileSystem(absltest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.snowpark_session.sql("ALTER SESSION SET ENABLE_DATASET=true").collect() cls.snowpark_session.sql(f"CREATE OR REPLACE {cls.domain} {cls.entity}").collect() _create_file_based_entity( cls.snowpark_session, "dataset", cls.entity, cls.version1, row_count=cls.row_counts[cls.version1], seed=42 @@ -38,7 +39,8 @@ def setUpClass(cls) -> None: cls.files = [ f"{cls.domain}/{cls.entity}/{r['name']}" - for r in cls.snowpark_session.sql(f"LIST 'snow://{cls.domain}/{cls.entity}/versions'").collect() + for version in [cls.version1, cls.version2] + for r in cls.snowpark_session.sql(f"LIST 'snow://{cls.domain}/{cls.entity}/versions/{version}'").collect() ] assert len(cls.files) > 0, "LIST returned no files" diff --git a/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py index af082029..a2d11db4 100644 --- a/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py +++ b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py @@ -131,6 +131,32 @@ def test_tag(self) -> None: self._model.unset_tag(self._tag_name1) self.assertDictEqual({}, self._model.show_tags()) + def test_rename(self) -> None: + model, test_features, _ = model_factory.ModelFactory.prepare_sklearn_model() + self.registry.log_model( + model=model, + model_name="MODEL", + version_name="V1", + sample_input_data=test_features, + ) + model = self.registry.get_model(model_name="MODEL") + model.rename("MODEL2") + self.assertEqual(model.name, "MODEL2") + self.registry.delete_model("MODEL2") + + def test_rename_fully_qualified_name(self) -> None: + model, test_features, _ = model_factory.ModelFactory.prepare_sklearn_model() + self.registry.log_model( + model=model, + model_name="MODEL", + version_name="V1", + sample_input_data=test_features, + ) + model = self.registry.get_model(model_name="MODEL") + model.rename(f"{self._test_db}.{self._test_schema}.MODEL2") + self.assertEqual(model.name, "MODEL2") + self.registry.delete_model("MODEL2") + if __name__ == "__main__": absltest.main() diff --git 
a/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py
index acc8b876..bacd5e15 100644
--- a/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py
+++ b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py
@@ -1,7 +1,13 @@
+import glob
+import os
+import tempfile
 import uuid
 
+import numpy as np
 from absl.testing import absltest, parameterized
+from sklearn import svm
 
+from snowflake.ml.model import ExportMode
 from snowflake.ml.registry import registry
 from snowflake.ml.utils import connection_params
 from snowflake.snowpark import Session
@@ -36,12 +42,12 @@ def setUpClass(self) -> None:
         self._db_manager.cleanup_databases(expire_hours=6)
         self.registry = registry.Registry(self._session)
 
-        model, test_features, _ = model_factory.ModelFactory.prepare_sklearn_model()
+        self.model, self.test_features, _ = model_factory.ModelFactory.prepare_sklearn_model()
         self._mv = self.registry.log_model(
-            model=model,
+            model=self.model,
             model_name=MODEL_NAME,
             version_name=VERSION_NAME,
-            sample_input_data=test_features,
+            sample_input_data=self.test_features,
         )
 
     @classmethod
@@ -69,6 +75,20 @@ def test_metrics(self) -> None:
         with self.assertRaises(KeyError):
             self._mv.get_metric("b")
 
+    def test_export(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            self._mv.export(tmpdir)
+            self.assertLen(list(glob.iglob(os.path.join(tmpdir, "**", "*"), recursive=True)), 14)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            self._mv.export(tmpdir, export_mode=ExportMode.FULL)
+            self.assertLen(list(glob.iglob(os.path.join(tmpdir, "**", "*"), recursive=True)), 27)
+
+    def test_load(self) -> None:
+        loaded_model = self._mv.load()
+        assert isinstance(loaded_model, svm.SVC)
+        np.testing.assert_allclose(loaded_model.predict(self.test_features), self.model.predict(self.test_features))
+
 
 if __name__ == "__main__":
     absltest.main()
diff --git a/tests/integ/snowflake/ml/modeling/impute/simple_imputer_test.py b/tests/integ/snowflake/ml/modeling/impute/simple_imputer_test.py
index dbf9e24b..76892f74 100644
--- a/tests/integ/snowflake/ml/modeling/impute/simple_imputer_test.py
+++ b/tests/integ/snowflake/ml/modeling/impute/simple_imputer_test.py
@@ -455,6 +455,37 @@ def test_transform_sklearn(self) -> None:
 
         np.testing.assert_allclose(transformed_df[output_cols].to_numpy(), sklearn_transformed_dataset)
 
+    def test_transform_sklearn_categorical(self) -> None:
+        """
+        Verify transform of a pandas DataFrame with categorical columns.
+
+        Raises
+        ------
+        AssertionError
+            If the transformed result does not match the one generated by Sklearn.
+ """ + input_cols = CATEGORICAL_COLS + output_cols = OUTPUT_COLS + df_pandas, df = framework_utils.get_df(self._session, DATA, SCHEMA) + + simple_imputer = SimpleImputer( + input_cols=input_cols, + output_cols=output_cols, + strategy="constant", + fill_value="missing_value", + missing_values=None, + ) + simple_imputer.fit(df) + + df_none_nan_pandas, df_none_nan = framework_utils.get_df(self._session, DATA_NONE_NAN, SCHEMA) + transformed_df = simple_imputer.transform(df_none_nan_pandas) + + sklearn_simple_imputer = SklearnSimpleImputer(strategy="constant", missing_values=None) + sklearn_simple_imputer.fit(df_pandas[input_cols]) + sklearn_transformed_dataset = sklearn_simple_imputer.transform(df_none_nan_pandas[input_cols]) + + np.testing.assert_equal(transformed_df[output_cols].to_numpy(), sklearn_transformed_dataset) + def test_transform_sklearn_constant_string(self) -> None: """ Verify imputed data using a string constant. diff --git a/tests/integ/snowflake/ml/modeling/preprocessing/one_hot_encoder_test.py b/tests/integ/snowflake/ml/modeling/preprocessing/one_hot_encoder_test.py index 715c353c..92878151 100644 --- a/tests/integ/snowflake/ml/modeling/preprocessing/one_hot_encoder_test.py +++ b/tests/integ/snowflake/ml/modeling/preprocessing/one_hot_encoder_test.py @@ -112,7 +112,9 @@ def map_output(x: Optional[Dict[str, Any]]) -> Optional[List[int]]: base_encoding = 0 for idx in range(1, len(mapped_pandas.columns)): base_encoding += mapped_pandas.iloc[0, idx - 1][1] - mapped_pandas.iloc[:, idx] = mapped_pandas.iloc[:, idx].apply(lambda x: [x[0] + base_encoding, x[1]]) + mapped_pandas.iloc[:, idx] = mapped_pandas.iloc[:, idx].apply( + lambda x: [x[0] + base_encoding, x[1]] # noqa: B023 + ) values = [] for row_idx, row in mapped_pandas.iterrows(): @@ -1770,6 +1772,23 @@ def test_column_insensitivity(self) -> None: ohe = OneHotEncoder(input_cols=lower_cols, output_cols=cols, sparse=False).fit(snow_df) ohe.transform(snow_df) + def test_fit_pd_transform_sp(self) -> None: + pd_data = pd.read_csv(TEST_DATA_PATH, index_col=0) + snow_df = self._session.create_dataframe(pd_data) + cols = [ + "AGE", + "CAMPAIGN", + "CONTACT", + "DAY_OF_WEEK", + "EDUCATION", + "JOB", + "MONTH", + "DURATION", + ] + + ohe = OneHotEncoder(input_cols=cols, output_cols=cols, sparse=False).fit(pd_data) + ohe.transform(snow_df) + if __name__ == "__main__": main() diff --git a/tests/integ/snowflake/ml/registry/BUILD.bazel b/tests/integ/snowflake/ml/registry/BUILD.bazel index 71f14eea..5486c3b8 100644 --- a/tests/integ/snowflake/ml/registry/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/BUILD.bazel @@ -4,7 +4,6 @@ py_test( name = "model_registry_basic_integ_test", srcs = ["model_registry_basic_integ_test.py"], deps = [ - "//snowflake/ml/registry:artifact_manager", "//snowflake/ml/registry:model_registry", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/test_utils:db_manager", diff --git a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py index 626584af..d04ff0db 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_modeling_model_test.py @@ -1,11 +1,19 @@ +import os +import posixpath + import numpy as np +import yaml from absl.testing import absltest from sklearn import datasets +from snowflake.ml import dataset +from snowflake.ml.model._model_composer import model_composer from snowflake.ml.modeling.lightgbm import 
LGBMRegressor from snowflake.ml.modeling.linear_model import LogisticRegression from snowflake.ml.modeling.xgboost import XGBRegressor +from snowflake.snowpark import types as T from tests.integ.snowflake.ml.registry.model import registry_model_test_base +from tests.integ.snowflake.ml.test_utils import test_env_utils class TestRegistryModelingModelInteg(registry_model_test_base.RegistryModelTestBase): @@ -84,6 +92,80 @@ def test_snowml_model_deploy_lightgbm( }, ) + def test_dataset_to_model_lineage(self) -> None: + iris_X = datasets.load_iris(as_frame=True).frame + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] + + INPUT_COLUMNS = ["SEPALLENGTH", "SEPALWIDTH", "PETALLENGTH", "PETALWIDTH"] + LABEL_COLUMNS = "TARGET" + OUTPUT_COLUMNS = "PREDICTED_TARGET" + regr = LogisticRegression(input_cols=INPUT_COLUMNS, output_cols=OUTPUT_COLUMNS, label_cols=LABEL_COLUMNS) + schema = [ + T.StructField("SEPALLENGTH", T.DoubleType()), + T.StructField("SEPALWIDTH", T.DoubleType()), + T.StructField("PETALLENGTH", T.DoubleType()), + T.StructField("PETALWIDTH", T.DoubleType()), + T.StructField("TARGET", T.StringType()), + T.StructField("PREDICTED_TARGET", T.StringType()), + ] + test_features_df = self._session.create_dataframe(iris_X, schema=schema) + + test_features_dataset = dataset.create_from_dataframe( + session=self._session, + name="trainDataset", + version="v1", + input_dataframe=test_features_df, + ) + + test_df = test_features_dataset.read.to_snowpark_dataframe() + + regr.fit(test_df) + + # Case 1 : test generation of MANIFEST.yml file + + model_name = "some_name" + tmp_stage_path = posixpath.join(self._session.get_session_stage(), f"{model_name}_{1}") + conda_dependencies = [ + test_env_utils.get_latest_package_version_spec_in_server(self._session, "snowflake-snowpark-python!=1.12.0") + ] + mc = model_composer.ModelComposer(self._session, stage_path=tmp_stage_path) + + mc.save( + name=model_name, + model=regr, + signatures=None, + sample_input_data=None, + conda_dependencies=conda_dependencies, + metadata={"author": "rsureshbabu", "version": "1"}, + options={"relax_version": False}, + ) + + with open(os.path.join(tmp_stage_path, mc._workspace.name, "MANIFEST.yml"), encoding="utf-8") as f: + yaml_content = yaml.safe_load(f) + assert "lineage_sources" in yaml_content + assert isinstance(yaml_content["lineage_sources"], list) + assert len(yaml_content["lineage_sources"]) == 1 + + source = yaml_content["lineage_sources"][0] + assert isinstance(source, dict) + assert source.get("type") == "DATASET" + assert source.get("entity") == f"{test_features_dataset.fully_qualified_name}" + assert source.get("version") == f"{test_features_dataset._version.name}" + + # Case 2 : test remaining life cycle. 
+        self._test_registry_model(
+            model=regr,
+            prediction_assert_fns={
+                "predict": (
+                    iris_X,
+                    lambda res: np.testing.assert_allclose(
+                        res[OUTPUT_COLUMNS].values, regr.predict(iris_X)[OUTPUT_COLUMNS].values
+                    ),
+                ),
+            },
+            additional_dependencies=["fsspec", "aiohttp", "cryptography"],
+        )
+
 
 if __name__ == "__main__":
     absltest.main()
diff --git a/tests/integ/snowflake/ml/registry/model_registry_basic_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_basic_integ_test.py
index 7f9ce834..3728da6b 100644
--- a/tests/integ/snowflake/ml/registry/model_registry_basic_integ_test.py
+++ b/tests/integ/snowflake/ml/registry/model_registry_basic_integ_test.py
@@ -4,7 +4,6 @@
 from absl.testing import absltest, parameterized
 
 from snowflake.ml.registry import model_registry
-from snowflake.ml.registry.artifact import Artifact, ArtifactType
 from snowflake.ml.utils import connection_params
 from snowflake.snowpark import Session
 from tests.integ.snowflake.ml.test_utils import db_manager
@@ -167,68 +166,6 @@ def test_create_and_drop_model_registry(self, database_name: str, schema_name: O
         self.assertTrue(self._db_manager.assert_database_existence(database_name, exists=False))
         self._validate_restore_db_and_schema()
 
-    def test_add_and_delete_ml_artifacts(self) -> None:
-        """Test add() and delete() in `_artifact_manager.py` works as expected."""
-
-        artifact_registry = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(
-            _RUN_ID, "artifact_registry"
-        ).upper()
-        artifact_registry_schema = "PUBLIC"
-
-        try:
-            model_registry.create_model_registry(
-                session=self._session, database_name=artifact_registry, schema_name=artifact_registry_schema
-            )
-            registry = model_registry.ModelRegistry(
-                session=self._session, database_name=artifact_registry, schema_name=artifact_registry_schema
-            )
-        except Exception as e:
-            self._db_manager.drop_database(artifact_registry)
-            raise Exception(f"Test failed with exception:{e}")
-
-        artifact_id = "test_art_123"
-        artifact_version = "test_artifact_version"
-        artifact_name = "test_artifact"
-        artifact = Artifact(type=ArtifactType.DATASET, spec='{"test_property": "test_value"}')
-
-        try:
-            art_ref = registry._artifact_manager.add(
-                artifact=artifact,
-                artifact_id=artifact_id,
-                artifact_name=artifact_name,
-                artifact_version=artifact_version,
-            )
-
-            self.assertTrue(
-                registry._artifact_manager.exists(
-                    art_ref.name,
-                    art_ref.version,
-                )
-            )
-
-            # Validate the artifact_spec can be parsed as expected
-            retrieved_art_df = registry._artifact_manager.get(
-                art_ref.name,
-                art_ref.version,
-            )
-
-            actual_artifact_spec = retrieved_art_df.collect()[0]["ARTIFACT_SPEC"]
-            self.assertEqual(artifact._spec, actual_artifact_spec)
-
-            # Validate that `delete_artifact` can remove entries from the artifact table.
- registry._artifact_manager.delete( - art_ref.name, - art_ref.version, - ) - self.assertFalse( - registry._artifact_manager.exists( - art_ref.name, - art_ref.version, - ) - ) - finally: - self._db_manager.drop_database(artifact_registry, if_exists=True) - if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/registry/model_registry_integ_test.py b/tests/integ/snowflake/ml/registry/model_registry_integ_test.py index 2ae50924..2e7c6dad 100644 --- a/tests/integ/snowflake/ml/registry/model_registry_integ_test.py +++ b/tests/integ/snowflake/ml/registry/model_registry_integ_test.py @@ -7,9 +7,7 @@ from sklearn import metrics from snowflake import connector -from snowflake.ml.dataset import dataset from snowflake.ml.registry import model_registry -from snowflake.ml.registry.artifact import ArtifactType from snowflake.ml.utils import connection_params from snowflake.snowpark import Session from tests.integ.snowflake.ml.test_utils import ( @@ -363,97 +361,6 @@ def test_snowml_pipeline(self) -> None: local_prediction.to_pandas().astype(dtype={"OUTPUT_TARGET": np.float64}), ) - def test_log_model_with_dataset(self) -> None: - registry = model_registry.ModelRegistry(session=self._session, database_name=self.registry_name) - - model_name = "snowml_test_dataset" - model_version = self.run_id - model, test_features, dataset_df = model_factory.ModelFactory.prepare_snowml_model_xgb() - - dummy_materialized_table_full_path = f"{registry._fully_qualified_schema_name()}.DUMMY_MATERIALIZED_TABLE" - dummy_snapshot_table_full_path = f"{dummy_materialized_table_full_path}_SNAPSHOT" - self._session.create_dataframe(dataset_df).write.mode("overwrite").save_as_table( - f"{dummy_materialized_table_full_path}" - ) - self._session.create_dataframe(dataset_df).write.mode("overwrite").save_as_table( - f"{dummy_snapshot_table_full_path}" - ) - - spine_query = f"SELECT * FROM {dummy_materialized_table_full_path}" - - fs_metadata = dataset.FeatureStoreMetadata( - spine_query=spine_query, - connection_params={ - "database": "test_db", - "schema": "test_schema", - "default_warehouse": "test_warehouse", - }, - features=[], - ) - dummy_dataset = dataset.Dataset( - self._session, - df=self._session.sql(spine_query), - materialized_table=dummy_materialized_table_full_path, - snapshot_table=dummy_snapshot_table_full_path, - timestamp_col="ts", - label_cols=["TARGET"], - feature_store_metadata=fs_metadata, - desc="a dummy dataset metadata", - ) - cur_user = self._session.sql("SELECT CURRENT_USER()").collect()[0]["CURRENT_USER()"] - self.assertEqual(dummy_dataset.owner, cur_user) - self.assertIsNone(dummy_dataset.name) - self.assertIsNotNone(dummy_dataset.generation_timestamp) - - minimal_dataset = dataset.Dataset( - self._session, - df=self._session.sql(spine_query), - ) - self.assertEqual(minimal_dataset.owner, cur_user) - self.assertIsNone(minimal_dataset.name) - self.assertIsNone(minimal_dataset.version) - self.assertIsNotNone(minimal_dataset.generation_timestamp) - - test_combinations = [ - (model_version, dummy_dataset), - (f"{model_version}.2", dummy_dataset), - (f"{model_version}.3", minimal_dataset), - ] - for version, ds in test_combinations: - atf_ref = registry.log_artifact( - artifact=ds, - name=f"ds_{version}", - version=f"{version}.ds", - ) - self.assertEqual(atf_ref.name, f"ds_{version}") - self.assertEqual(atf_ref.version, f"{version}.ds") - - registry.log_model( - model_name=model_name, - model_version=version, - model=model, - conda_dependencies=[ - 
test_env_utils.get_latest_package_version_spec_in_server( - self._session, "snowflake-snowpark-python!=1.12.0" - ) - ], - options={"embed_local_ml_library": True}, - artifacts=[atf_ref], - ) - - # test deserialized dataset from get_artifact - des_ds_0 = registry.get_artifact(atf_ref.name, atf_ref.version) - self.assertIsNotNone(des_ds_0) - self.assertEqual(des_ds_0, ds) - - # test deserialized dataset from list_artifacts - rows_list = registry.list_artifacts(model_name, version).collect() - self.assertEqual(len(rows_list), 1) - self.assertEqual(rows_list[0]["ID"], des_ds_0._id) - self.assertEqual(ArtifactType[rows_list[0]["TYPE"]], ArtifactType.DATASET) - des_ds_1 = dataset.Dataset.from_json(rows_list[0]["ARTIFACT_SPEC"], self._session) - self.assertEqual(des_ds_1, ds) - if __name__ == "__main__": absltest.main()