diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5e487a89..ff975b24 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 ---
-exclude: ^(.*egg.info.*|.*/parameters.py$|.*\.py_template|.*/experimental/.*|docs/source/_themes/.*)
+exclude: ^(.*egg.info.*|.*/parameters.py$|.*\.py_template|.*/experimental/.*|.*/fixtures/.*|docs/source/_themes/.*)
 minimum_pre_commit_version: 3.4.0
 repos:
 - repo: https://github.com/asottile/pyupgrade
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d70fb5ba..da6550e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,32 @@
 # Release History
 
+## 1.2.0
+
+### Bug Fixes
+
+- Model Registry: Fix "XGBoost version not compiled with GPU support" error when running CPU inference against open-source
+  XGBoost models deployed to SPCS.
+- Model Registry: Fix model deployment to SPCS on Windows machines.
+
+### Behavior Changes
+
+### New Features
+
+- Model Development: Introduced the XGBoost external-memory training feature, which enables training XGBoost models
+  on large datasets that don't fit into memory.
+- Registry: New `snowflake.ml.registry.Registry` class providing APIs similar to the old one, but working with the
+  new MODEL object in Snowflake SQL. We also provide `snowflake.ml.model.Model` and
+  `snowflake.ml.model.ModelVersion` to represent a model and a specific version of a model.
+- Model Development: Add support for the `fit_predict` method in the `AgglomerativeClustering`, `DBSCAN`, and `OPTICS` classes.
+- Model Development: Add support for the `fit_transform` method in the `MDS`, `SpectralEmbedding`, and `TSNE` classes.
+
+### Additional Notes
+
+- Model Registry: `snowflake.ml.registry.model_registry.ModelRegistry` is deprecated as of version
+1.2.0 and will remain in Private Preview. Please use `snowflake.ml.registry.Registry` instead,
+except where the old API is specifically required. The old model registry will be removed once all
+of its primary functionality has been integrated into the new registry.
+
 ## 1.1.2
 
 ### Bug Fixes
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 2a4074d0..7df6ff5a 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -117,6 +117,9 @@ You can build an entire sub-tree as:
     )
     ```
 
++ If the visibility of the target is not `//visibility:public`, make sure your target is also visible to
+  `//bazel:snowml_public_common` so that CI type checking works.
+
 ### Type-check
 
 #### mypy
diff --git a/README.md b/README.md
index d4acf9c2..1610d8b8 100644
--- a/README.md
+++ b/README.md
@@ -23,20 +23,17 @@ model development classes based on sklearn, xgboost, and lightgbm.
 1. Framework Connectors: Optimized, secure and performant data provisioning for Pytorch and Tensorflow frameworks in
 their native data loader formats.
 
-### Snowpark ML Ops [Private Preview]
-
-[Snowpark MLOps](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-ops) complements the
-Snowpark ML Development API, and provides model management capabilities along with integrated deployment into Snowflake.
-Currently, the API consists of:
-
 1. FileSet API: FileSet provides a Python fsspec-compliant API for materializing data into a Snowflake internal stage
 from a query or Snowpark Dataframe along with a number of convenience APIs.
 
-1. Model Registry: A python API for managing models within Snowflake which also supports deployment of ML models into
-Snowflake Warehouses as vectorized UDFs.
+### Snowpark Model Management [Public Preview] + +[Snowpark Model Management](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-ops) complements +the Snowpark ML Development API, and provides model management capabilities along with integrated deployment into Snowflake. +Currently, the API consists of: -During PrPr, we are iterating on API without backward compatibility guarantees. It is better to recreate your registry -everytime you update the package. This means, at this time, you cannot use the registry for production use. +1. Registry: A python API for managing models within Snowflake which also supports deployment of ML models into Snowflake +as native MODEL object running with Snowflake Warehouse. ## Getting started diff --git a/bazel/BUILD.bazel b/bazel/BUILD.bazel index 7f299669..5b29739a 100644 --- a/bazel/BUILD.bazel +++ b/bazel/BUILD.bazel @@ -14,3 +14,13 @@ sh_binary( name = "test_wrapper", srcs = ["test_wrapper.sh"], ) + +# Package group for common targets in the repo. +package_group( + name = "snowml_public_common", + packages = [ + "//bazel/...", + "//ci/...", + "//docs/...", + ], +) diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 00b04f5c..3c244f9d 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.1.2 + version: 1.2.0 requirements: build: - python diff --git a/ci/type_check.sh b/ci/type_check.sh index 5617aee6..927db9bb 100755 --- a/ci/type_check.sh +++ b/ci/type_check.sh @@ -1,92 +1,10 @@ #!/bin/bash -# Usage -# type_check.sh [-a] [-b ] -# -# Flags -# -a: check all targets (excluding the exempted ones). -# -b: specify path to bazel. -# -# Inputs -# - ci/skip_type_checking_targets : a list of target patterns against which -# typechecking should be enforced. -# -# Action -# - Performs typechecking against the intersection of -# type checked targets and affected targets. -# Exit code: -# 0 if succeeds. No target to check means success. -# 1 if there is an error in parsing commandline flag. -# Otherwise exits with bazel's exit code. -# -# NOTE: -# 1. Ignores all targets that depends on targets in `skip_type_checking_targets`. -# 2. Affected targets also include raw python files on top of bazel build targets whereas ignored_targets don't. Hence -# we used `kind('py_.* rule')` filter. - -set -o pipefail -set -u -set -e - -bazel="bazel" -affected_targets="" +# Just an alias to avoid break Jenkins PROG=$0 -help() { - local exit_code=$1 - echo "Usage: ${PROG} [-a] [-b ]" - exit "${exit_code}" -} - -while getopts "ab:h" opt; do - case "${opt}" in - a) - affected_targets="//..." - ;; - b) - bazel="${OPTARG}" - ;; - h) - help 0 - ;; - :) - help 1 - ;; - ?) 
- help 1 - ;; - esac -done - -echo "Using bazel: " "${bazel}" -working_dir=$(mktemp -d "/tmp/tmp_XXXXX") -trap 'rm -rf "${working_dir}"' EXIT - -if [[ -z "${affected_targets}" ]]; then - affected_targets_file="${working_dir}/affected_targets" - ./bazel/get_affected_targets.sh -b "${bazel}" -f "${affected_targets_file}" - - affected_targets="$(<"${affected_targets_file}")" -fi - -printf \ - "let skip_type_checking_targets = set(%s) + set(%s) + set(%s) in \ - let affected_targets = kind('py_.* rule', set(%s)) in \ - let rdeps_targets = rdeps(//..., \$skip_type_checking_targets) in \ - \$affected_targets except \$rdeps_targets" \ - "$("${working_dir}/type_checked_targets_query" -"${bazel}" query --query_file="${working_dir}/type_checked_targets_query" >"${working_dir}/type_checked_targets" -echo "Type checking the following targets:" "$(<"${working_dir}/type_checked_targets")" - -set +e -"${bazel}" build \ - --keep_going \ - --config=typecheck \ - --color=yes \ - --target_pattern_file="${working_dir}/type_checked_targets" -bazel_exit_code=$? +SCRIPT=$(readlink -f "$PROG") +# Absolute path this script is in, thus /home/user/bin +SCRIPTPATH=$(dirname "$SCRIPT") -if [[ $bazel_exit_code -eq 0 || $bazel_exit_code -eq 4 ]]; then - exit 0 -fi -exit ${bazel_exit_code} +"${SCRIPTPATH}/type_check/type_check.sh" "$@" diff --git a/ci/type_check/BUILD.bazel b/ci/type_check/BUILD.bazel new file mode 100644 index 00000000..e69de29b diff --git a/ci/type_check/runner/.gitignore b/ci/type_check/runner/.gitignore new file mode 100644 index 00000000..819ae9cc --- /dev/null +++ b/ci/type_check/runner/.gitignore @@ -0,0 +1 @@ +BUILD.bazel diff --git a/ci/type_check/type_check.sh b/ci/type_check/type_check.sh new file mode 100755 index 00000000..1c8b2d1f --- /dev/null +++ b/ci/type_check/type_check.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +# Usage +# type_check.sh [-a] [-b ] +# +# Flags +# -a: check all targets (excluding the exempted ones). +# -b: specify path to bazel. +# +# Inputs +# - ci/skip_type_checking_targets : a list of target patterns against which +# typechecking should be enforced. +# +# Action +# - Create a mypy_test targets to type check all affected targets +# Exit code: +# 0 if succeeds. No target to check means success. +# 1 if there is an error in parsing commandline flag. +# Otherwise exits with bazel's exit code. +# +# NOTE: +# 1. Ignores all targets that depends on targets in `skip_type_checking_targets`. +# 2. Affected targets also include raw python files on top of bazel build targets whereas ignored_targets don't. Hence +# we used `kind('py_.* rule')` filter. + +set -o pipefail +set -u +set -e + +bazel="bazel" +affected_targets="" +PROG=$0 + +SCRIPT=$(readlink -f "$PROG") +# Absolute path this script is in, thus /home/user/bin +SCRIPTPATH=$(dirname "$SCRIPT") + +help() { + local exit_code=$1 + echo "Usage: ${PROG} [-a] [-b ]" + exit "${exit_code}" +} + +while getopts "ab:h" opt; do + case "${opt}" in + a) + affected_targets="//..." + ;; + b) + bazel="${OPTARG}" + ;; + h) + help 0 + ;; + :) + help 1 + ;; + ?) 
+ help 1 + ;; + esac +done + +echo "Using bazel: " "${bazel}" +working_dir=$(mktemp -d "/tmp/tmp_XXXXX") +trap 'rm -rf "${working_dir}"' EXIT +trap 'rm -rf "${SCRIPTPATH}/runner/BUILD.bazel"' EXIT + +if [[ -z "${affected_targets}" ]]; then + affected_targets_file="${working_dir}/affected_targets" + ./bazel/get_affected_targets.sh -b "${bazel}" -f "${affected_targets_file}" + + affected_targets="$(<"${affected_targets_file}")" +fi + +printf \ + "let skip_type_checking_targets = set(%s) + set(%s) + set(%s) in \ + let affected_targets = kind('py_.* rule', set(%s)) in \ + let rdeps_targets = rdeps(//..., \$skip_type_checking_targets) in \ + \$affected_targets except \$rdeps_targets" \ + "$(<"${SCRIPTPATH}/../skip_type_checking_targets")" "$(<"${SCRIPTPATH}/../skip_merge_gate_targets")" "$(<"${SCRIPTPATH}/../skip_continuous_run_targets")" "${affected_targets}" >"${working_dir}/type_checked_targets_query" +type_check_targets=$("${bazel}" query --query_file="${working_dir}/type_checked_targets_query" | awk 'NF { print "\""$0"\","}') + +echo "${type_check_targets}" + +if [[ -z "${type_check_targets}" ]]; then + echo "No target to do the type checking. Bye!" + exit 0 +fi + +cat >"${SCRIPTPATH}/runner/BUILD.bazel" < bool: """ return class_object[1].__module__ == "sklearn.preprocessing._data" + @staticmethod + def _is_manifold_module_obj(class_object: Tuple[str, type]) -> bool: + """Check if the given class belongs to the SKLearn manifold module. + + Args: + class_object: Meta class object which needs to be checked. + + Returns: + True if the class belongs to `sklearn.manifold` module, otherwise False. + """ + return class_object[1].__module__.startswith("sklearn.manifold") + @staticmethod def _is_multioutput_obj(class_object: Tuple[str, type]) -> bool: """Check if the given estimator can learn and predict multiple labels (multi-label not multi-class) @@ -548,6 +581,7 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: self.original_predict_docstring = "" self.predict_docstring = "" self.fit_predict_docstring = "" + self.fit_transform_docstring = "" self.predict_proba_docstring = "" self.score_docstring = "" self.predict_log_proba_docstring = "" @@ -570,6 +604,7 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: # Optional function support self.fit_predict_cluster_function_support = False + self.fit_transform_manifold_function_support = False # Dependencies self.predict_udf_deps = "" @@ -608,6 +643,7 @@ def _format_default_type(self, default_value: Any) -> str: def _populate_flags(self) -> None: self._from_data_py = WrapperGeneratorFactory._is_data_module_obj(self.class_object) + self._is_manifold = WrapperGeneratorFactory._is_manifold_module_obj(self.class_object) self._is_regressor = WrapperGeneratorFactory._is_regressor_obj(self.class_object) self._is_classifier = WrapperGeneratorFactory._is_classifier_obj(self.class_object) self._is_meta_estimator = WrapperGeneratorFactory._is_meta_estimator_obj(self.class_object) @@ -630,6 +666,7 @@ def _populate_flags(self) -> None: self._is_grid_search_cv = WrapperGeneratorFactory._is_grid_search_cv(self.class_object) self._is_randomized_search_cv = WrapperGeneratorFactory._is_randomized_search_cv(self.class_object) self._is_iterative_imputer = WrapperGeneratorFactory._is_iterative_imputer(self.class_object) + self._is_xgboost = WrapperGeneratorFactory._is_xgboost(self.module_name) def _populate_import_statements(self) -> None: self.estimator_imports_list.append("import numpy") @@ -786,6 +823,18 @@ def 
_populate_function_names_and_signatures(self) -> None: signature_lines.append("sample_weight_col: Optional[str] = None") init_member_args.append("self.set_sample_weight_col(sample_weight_col)") + if self._is_xgboost: + signature_lines.append("use_external_memory_version: bool = False") + signature_lines.append("batch_size: int = 10000") + + init_member_args.append("self._use_external_memory_version = use_external_memory_version") + init_member_args.append("self._batch_size = batch_size") + ADDITIONAL_PARAM_DESCRIPTIONS["use_external_memory_version"] = PARAM_DESC_USE_EXTERNAL_MEMORY_VERSION + ADDITIONAL_PARAM_DESCRIPTIONS["batch_size"] = PARAM_DESC_BATCH_SIZE + else: + init_member_args.append("self._use_external_memory_version = False") + init_member_args.append("self._batch_size = -1") + sklearn_init_lines.append("**cleaned_up_init_args") if has_kwargs: signature_lines.append("**kwargs") @@ -934,6 +983,11 @@ def generate(self) -> "SklearnWrapperGenerator": if self._is_cluster: self.fit_predict_cluster_function_support = True + if self._is_manifold: + self.fit_transform_manifold_function_support = True + + if self._is_manifold: + self.fit_transform_manifold_function_support = True if WrapperGeneratorFactory._is_class_of_type(self.class_object[1], "SelectKBest"): # Set the k of SelectKBest features transformer to half the number of columns in the dataset. diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index f1270070..35e09c17 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -48,6 +48,18 @@ _PROJECT = "ModelDevelopment" _SUBPROJECT = "".join([s.capitalize() for s in "{transform.root_module_name}".replace("sklearn.", "").split("_")]) +def _is_fit_predict_method_enabled() -> Callable[[Any], bool]: + def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]: + return {transform.fit_predict_cluster_function_support} and callable(getattr(self._sklearn_object, "fit_predict", None)) + return check + + +def _is_fit_transform_method_enabled() -> Callable[[Any], bool]: + def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]: + return {transform.fit_transform_manifold_function_support} and callable(getattr(self._sklearn_object, "fit_transform", None)) + return check + + class {transform.original_class_name}(BaseTransformer): r"""{transform.estimator_class_docstring} """ @@ -123,11 +135,6 @@ class {transform.original_class_name}(BaseTransformer): if isinstance(dataset, DataFrame): session = dataset._session assert session is not None # keep mypy happy - # Validate that key package version in user workspace are supported in snowflake conda channel - # If customer doesn't have package in conda channel, replace the ones have the closest versions - self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( - pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT) - # Specify input columns so column pruning will be enforced selected_cols = self._get_active_columns() if len(selected_cols) > 0: @@ -155,7 +162,9 @@ class {transform.original_class_name}(BaseTransformer): label_cols=self.label_cols, sample_weight_col=self.sample_weight_col, autogenerated=self._autogenerated, - subproject=_SUBPROJECT + subproject=_SUBPROJECT, + use_external_memory_version=self._use_external_memory_version, + batch_size=self._batch_size, ) self._sklearn_object = model_trainer.train() self._is_fitted = True @@ -424,20 +433,27 @@ 
class {transform.original_class_name}(BaseTransformer): return output_df - @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc] - def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]: + @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc] + def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]: """ {transform.fit_predict_docstring} Returns: Predicted dataset. """ - if {transform.fit_predict_cluster_function_support}: - self.fit(dataset) - assert self._sklearn_object is not None - labels : npt.NDArray[Any] = self._sklearn_object.labels_ - return labels - else: - # TODO(xinyi): support fit_predict for mixture classes - raise NotImplementedError + self.fit(dataset) + assert self._sklearn_object is not None + return self._sklearn_object.labels_ + + + @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc] + def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]: + """ {transform.fit_transform_docstring} + Returns: + Transformed dataset. + """ + self.fit(dataset) + assert self._sklearn_object is not None + return self._sklearn_object.embedding_ + def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]: """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions. diff --git a/docs/source/_templates/autosummary/class.rst b/docs/source/_templates/autosummary/class.rst index 714c1ce2..9332d590 100644 --- a/docs/source/_templates/autosummary/class.rst +++ b/docs/source/_templates/autosummary/class.rst @@ -1,16 +1,15 @@ {% extends "!autosummary/class.rst" %} +.. autoclass:: {{ objname }} {% set methods =(methods| reject("equalto", "__init__") |list) %} {% block methods %} {% if methods %} .. rubric:: Methods - - .. autosummary:: - {% for item in methods %} - ~{{ name }}.{{ item }} - {%- endfor %} +{% for item in methods %} + .. automethod:: {{ name }}.{{ item }} +{%- endfor %} {% endif %} {% endblock %} @@ -18,10 +17,8 @@ {% if attributes %} .. rubric:: Attributes - - .. autosummary:: - {% for item in attributes %} - ~{{ name }}.{{ item }} - {%- endfor %} +{% for item in attributes %} + .. autoattribute:: {{ name }}.{{ item }} +{%- endfor %} {% endif %} {% endblock %} diff --git a/docs/source/index.rst b/docs/source/index.rst index 18365333..5477edb9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -24,3 +24,5 @@ Table of Contents modeling fileset + model + registry diff --git a/docs/source/model.rst b/docs/source/model.rst new file mode 100644 index 00000000..8cecd338 --- /dev/null +++ b/docs/source/model.rst @@ -0,0 +1,64 @@ +=========================== +snowflake.ml.model +=========================== + +.. automodule:: snowflake.ml.model + :noindex: + +snowflake.ml.model +--------------------------------- + +.. currentmodule:: snowflake.ml.model + +.. rubric:: Classes + +.. autosummary:: + :toctree: api/model + + Model + ModelVersion + HuggingFacePipelineModel + LLM + LLMOptions + +snowflake.ml.model.custom_model +--------------------------------- + +.. currentmodule:: snowflake.ml.model.custom_model + +.. rubric:: Classes + +.. autosummary:: + :toctree: api/model + + MethodRef + ModelRef + ModelContext + CustomModel + +snowflake.ml.model.model_signature +--------------------------------- + +.. currentmodule:: snowflake.ml.model.model_signature + +.. rubric:: Classes + +.. 
autosummary:: + :toctree: api/model + + DataType + BaseFeatureSpec + FeatureSpec + ModelSignature + +.. rubric:: Methods + +.. autosummary:: + :toctree: api/model + + infer_signature + + +.. .. rubric:: Attributes + +.. None diff --git a/docs/source/registry.rst b/docs/source/registry.rst index ede781ca..832328ce 100644 --- a/docs/source/registry.rst +++ b/docs/source/registry.rst @@ -1,5 +1,3 @@ -:orphan: - =========================== snowflake.ml.registry =========================== @@ -14,8 +12,7 @@ snowflake.ml.registry .. autosummary:: :toctree: api/registry - model_registry.ModelRegistry - model_registry.ModelReference + Registry .. .. rubric:: Methods diff --git a/docs/sphinxconf/conf.py b/docs/sphinxconf/conf.py index b1423dbb..af73cda3 100644 --- a/docs/sphinxconf/conf.py +++ b/docs/sphinxconf/conf.py @@ -29,7 +29,7 @@ extensions = [ "sphinx.ext.autodoc", "sphinx.ext.autosummary", - # "sphinx.ext.napoleon", + "sphinx.ext.napoleon", # "sphinx.ext.coverage", # "sphinx.ext.linkcode" ] @@ -56,6 +56,8 @@ autosummary_generate = True autosummary_generate_overwrite = True +autoclass_content = "both" + # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for @@ -85,6 +87,18 @@ suppress_warnings = ["ref"] +# Napoleon settings +napoleon_google_docstring = True +napoleon_numpy_docstring = True +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = False +napoleon_use_admonition_for_references = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_preprocess_types = False +napoleon_attr_annotations = True + def setup(app: Any) -> None: app.connect( @@ -166,6 +180,8 @@ def __init__(self, csv_filename: str) -> None: # sphinx expects a function for this, so make instance callable def __call__(self, app: Any, what: str, name: str, obj: ModuleType, skip: bool, options: Dict[str, Any]) -> bool: + if name == "__init__": + return False if name.startswith("_"): return True if what == "method": diff --git a/snowflake/cortex/BUILD.bazel b/snowflake/cortex/BUILD.bazel index da7c8ef7..310c4bb6 100644 --- a/snowflake/cortex/BUILD.bazel +++ b/snowflake/cortex/BUILD.bazel @@ -3,13 +3,15 @@ load("//bazel:py_rules.bzl", "py_library", "py_package", "py_test") package_group( name = "cortex", packages = [ - "//docs/...", "//snowflake/cortex/...", "//snowflake/ml/...", ], ) -package(default_visibility = [":cortex"]) +package(default_visibility = [ + ":cortex", + "//bazel:snowml_public_common", +]) py_library( name = "util", diff --git a/snowflake/ml/_internal/container_services/image_registry/BUILD.bazel b/snowflake/ml/_internal/container_services/image_registry/BUILD.bazel new file mode 100644 index 00000000..2e875839 --- /dev/null +++ b/snowflake/ml/_internal/container_services/image_registry/BUILD.bazel @@ -0,0 +1,56 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "credential", + srcs = ["credential.py"], + deps = ["//snowflake/ml/_internal/utils:query_result_checker"], +) + +py_library( + name = "http_client", + srcs = ["http_client.py"], + deps = [ + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/_internal/utils:retryable_http", + "//snowflake/ml/_internal/utils:session_token_manager", + ], +) + +py_test( + name = "http_client_test", + srcs = ["http_client_test.py"], + deps = [ + ":http_client", + "//snowflake/ml/test_utils:mock_session", + ], 
+) + +py_library( + name = "registry_client", + srcs = ["registry_client.py"], + deps = [ + ":http_client", + ":imagelib", + "//snowflake/ml/_internal/exceptions", + ], +) + +py_library( + name = "imagelib", + srcs = ["imagelib.py"], + deps = [ + ":http_client", + ], +) + +py_test( + name = "registry_client_test", + srcs = ["registry_client_test.py"], + deps = [ + ":registry_client", + "//snowflake/ml/test_utils:exception_utils", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/_internal/utils/spcs_image_registry.py b/snowflake/ml/_internal/container_services/image_registry/credential.py similarity index 100% rename from snowflake/ml/_internal/utils/spcs_image_registry.py rename to snowflake/ml/_internal/container_services/image_registry/credential.py diff --git a/snowflake/ml/_internal/utils/image_registry_http_client.py b/snowflake/ml/_internal/container_services/image_registry/http_client.py similarity index 100% rename from snowflake/ml/_internal/utils/image_registry_http_client.py rename to snowflake/ml/_internal/container_services/image_registry/http_client.py diff --git a/snowflake/ml/_internal/utils/image_registry_http_client_test.py b/snowflake/ml/_internal/container_services/image_registry/http_client_test.py similarity index 98% rename from snowflake/ml/_internal/utils/image_registry_http_client_test.py rename to snowflake/ml/_internal/container_services/image_registry/http_client_test.py index d6bf0f69..b5a55c8e 100644 --- a/snowflake/ml/_internal/utils/image_registry_http_client_test.py +++ b/snowflake/ml/_internal/container_services/image_registry/http_client_test.py @@ -5,8 +5,10 @@ from absl.testing import absltest, parameterized from absl.testing.absltest import mock +from snowflake.ml._internal.container_services.image_registry import ( + http_client as image_registry_http_client, +) from snowflake.ml._internal.exceptions import exceptions as snowml_exceptions -from snowflake.ml._internal.utils import image_registry_http_client from snowflake.ml.test_utils import mock_session from snowflake.snowpark import session diff --git a/snowflake/ml/model/_deploy_client/utils/imagelib.py b/snowflake/ml/_internal/container_services/image_registry/imagelib.py similarity index 99% rename from snowflake/ml/model/_deploy_client/utils/imagelib.py rename to snowflake/ml/_internal/container_services/image_registry/imagelib.py index 8d0d9115..be366779 100644 --- a/snowflake/ml/model/_deploy_client/utils/imagelib.py +++ b/snowflake/ml/_internal/container_services/image_registry/imagelib.py @@ -23,7 +23,9 @@ import requests -from snowflake.ml._internal.utils import image_registry_http_client +from snowflake.ml._internal.container_services.image_registry import ( + http_client as image_registry_http_client, +) # Common HTTP headers _CONTENT_LENGTH_HEADER = "content-length" diff --git a/snowflake/ml/model/_deploy_client/utils/image_registry_client.py b/snowflake/ml/_internal/container_services/image_registry/registry_client.py similarity index 98% rename from snowflake/ml/model/_deploy_client/utils/image_registry_client.py rename to snowflake/ml/_internal/container_services/image_registry/registry_client.py index df8775b6..d07e6cec 100644 --- a/snowflake/ml/model/_deploy_client/utils/image_registry_client.py +++ b/snowflake/ml/_internal/container_services/image_registry/registry_client.py @@ -3,12 +3,14 @@ from typing import Dict, Optional, cast from urllib.parse import urlunparse +from snowflake.ml._internal.container_services.image_registry import ( + http_client 
as image_registry_http_client, + imagelib, +) from snowflake.ml._internal.exceptions import ( error_codes, exceptions as snowml_exceptions, ) -from snowflake.ml._internal.utils import image_registry_http_client -from snowflake.ml.model._deploy_client.utils import imagelib from snowflake.snowpark import Session from snowflake.snowpark._internal import utils as snowpark_utils diff --git a/snowflake/ml/model/_deploy_client/utils/image_registry_client_test.py b/snowflake/ml/_internal/container_services/image_registry/registry_client_test.py similarity index 97% rename from snowflake/ml/model/_deploy_client/utils/image_registry_client_test.py rename to snowflake/ml/_internal/container_services/image_registry/registry_client_test.py index 4a1f3430..7117e1eb 100644 --- a/snowflake/ml/model/_deploy_client/utils/image_registry_client_test.py +++ b/snowflake/ml/_internal/container_services/image_registry/registry_client_test.py @@ -3,7 +3,9 @@ from absl.testing import absltest from absl.testing.absltest import mock -from snowflake.ml.model._deploy_client.utils import image_registry_client +from snowflake.ml._internal.container_services.image_registry import ( + registry_client as image_registry_client, +) from snowflake.ml.test_utils import mock_session from snowflake.snowpark import session diff --git a/snowflake/ml/_internal/env_utils.py b/snowflake/ml/_internal/env_utils.py index 2ba8a16d..1fd225c3 100644 --- a/snowflake/ml/_internal/env_utils.py +++ b/snowflake/ml/_internal/env_utils.py @@ -33,7 +33,6 @@ class CONDA_OS(Enum): _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake" _NODEFAULTS = "nodefaults" -_INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: Optional[bool] = None _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {} _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {} @@ -267,18 +266,6 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req return new_req -def _check_runtime_version_column_existence(session: session.Session) -> bool: - sql = textwrap.dedent( - """ - SHOW COLUMNS - LIKE 'runtime_version' - IN TABLE information_schema.packages; - """ - ) - result = session.sql(sql).count() - return result == 1 - - def get_matched_package_versions_in_snowflake_conda_channel( req: requirements.Requirement, python_version: str = snowml_env.PYTHON_VERSION, @@ -325,9 +312,9 @@ def get_matched_package_versions_in_snowflake_conda_channel( return matched_versions -def validate_requirements_in_information_schema( +def get_matched_package_versions_in_information_schema( session: session.Session, reqs: List[requirements.Requirement], python_version: str -) -> Optional[List[str]]: +) -> Dict[str, List[version.Version]]: """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might exist in the information_schema table but has not yet become available in the Snowflake Conda channel. @@ -338,42 +325,35 @@ def validate_requirements_in_information_schema( python_version: A string of python version where model is run. Returns: - A list of pinned latest version that available in Snowflake anaconda channel and meet the version specifier. + A Dict, whose key is the package name, and value is a list of versions match the requirements. 
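+
+        Example (illustrative only; the exact versions returned depend on what information_schema.packages
+        lists for the current account, and the names below mirror the mocked values used in the unit tests):
+            >>> get_matched_package_versions_in_information_schema(
+            ...     session, reqs=[requirements.Requirement("xgboost==1.7.*")], python_version="3.8"
+            ... )
+            {'xgboost': [<Version('1.7.0')>, <Version('1.7.1')>, <Version('1.7.3')>]}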
""" - global _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION - - if _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION is None: - _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION = _check_runtime_version_column_existence(session) - ret_list = [] - reqs_to_request = [] + ret_dict: Dict[str, List[version.Version]] = {} + reqs_to_request: List[requirements.Requirement] = [] for req in reqs: - if req.name not in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: + if req.name in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: + available_versions = list( + sorted(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, [])))) + ) + ret_dict[req.name] = available_versions + else: reqs_to_request.append(req) + if reqs_to_request: pkg_names_str = " OR ".join( f"package_name = '{req_name}'" for req_name in sorted(req.name for req in reqs_to_request) ) - if _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: - parsed_python_version = version.Version(python_version) - sql = textwrap.dedent( - f""" - SELECT PACKAGE_NAME, VERSION - FROM information_schema.packages - WHERE ({pkg_names_str}) - AND language = 'python' - AND (runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}' - OR runtime_version is null); - """ - ) - else: - sql = textwrap.dedent( - f""" - SELECT PACKAGE_NAME, VERSION - FROM information_schema.packages - WHERE ({pkg_names_str}) - AND language = 'python'; - """ - ) + + parsed_python_version = version.Version(python_version) + sql = textwrap.dedent( + f""" + SELECT PACKAGE_NAME, VERSION + FROM information_schema.packages + WHERE ({pkg_names_str}) + AND language = 'python' + AND (runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}' + OR runtime_version is null); + """ + ) try: result = ( @@ -392,14 +372,13 @@ def validate_requirements_in_information_schema( cached_req_ver_list.append(req_ver) _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE[req_name] = cached_req_ver_list except snowflake.connector.DataError: - return None - for req in reqs: - available_versions = list(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, [])))) - if not available_versions: - return None - else: - ret_list.append(str(req)) - return sorted(ret_list) + return ret_dict + for req in reqs_to_request: + available_versions = list( + sorted(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, [])))) + ) + ret_dict[req.name] = available_versions + return ret_dict def save_conda_env_file( diff --git a/snowflake/ml/_internal/env_utils_test.py b/snowflake/ml/_internal/env_utils_test.py index 176ddb0d..07068605 100644 --- a/snowflake/ml/_internal/env_utils_test.py +++ b/snowflake/ml/_internal/env_utils_test.py @@ -9,7 +9,7 @@ import yaml from absl.testing import absltest -from packaging import requirements, specifiers +from packaging import requirements, specifiers, version from snowflake.ml._internal import env as snowml_env, env_utils from snowflake.ml.test_utils import mock_data_frame, mock_session @@ -294,25 +294,17 @@ def test_relax_requirement_version(self) -> None: self.assertEqual(env_utils.relax_requirement_version(r), requirements.Requirement("python-package")) self.assertIsNot(env_utils.relax_requirement_version(r), r) - def test_validate_requirements_in_information_schema(self) -> None: + def test_get_matched_package_versions_in_information_schema(self) -> None: m_session = mock_session.MockSession(conn=None, test_case=self) - m_session.add_mock_sql( - query=textwrap.dedent( - """ - SHOW COLUMNS - LIKE 'runtime_version' - IN TABLE 
information_schema.packages; - """ - ), - result=mock_data_frame.MockDataFrame(count_result=0), - ) query = textwrap.dedent( - """ + f""" SELECT PACKAGE_NAME, VERSION FROM information_schema.packages WHERE (package_name = 'pytorch' OR package_name = 'xgboost') - AND language = 'python'; + AND language = 'python' + AND (runtime_version = '{platform.python_version_tuple()[0]}.{platform.python_version_tuple()[1]}' + OR runtime_version is null); """ ) sql_result = [ @@ -325,34 +317,42 @@ def test_validate_requirements_in_information_schema(self) -> None: m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, ), - sorted(["xgboost", "pytorch"]), + { + "xgboost": list(map(version.parse, ["1.3.3", "1.5.1", "1.7.3"])), + "pytorch": list(map(version.parse, ["1.12.1"])), + }, ) # Test cache - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, ), - sorted(["xgboost", "pytorch"]), + { + "xgboost": list(map(version.parse, ["1.3.3", "1.5.1", "1.7.3"])), + "pytorch": list(map(version.parse, ["1.12.1"])), + }, ) # clear cache env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} query = textwrap.dedent( - """ + f""" SELECT PACKAGE_NAME, VERSION FROM information_schema.packages WHERE (package_name = 'xgboost') - AND language = 'python'; + AND language = 'python' + AND (runtime_version = '{platform.python_version_tuple()[0]}.{platform.python_version_tuple()[1]}' + OR runtime_version is null); """ ) sql_result = [ @@ -365,31 +365,37 @@ def test_validate_requirements_in_information_schema(self) -> None: m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost")], python_version=snowml_env.PYTHON_VERSION, ), - ["xgboost"], + { + "xgboost": list(map(version.parse, ["1.3.3", "1.5.1", "1.7.3"])), + }, ) # Test cache - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost")], python_version=snowml_env.PYTHON_VERSION, ), - ["xgboost"], + { + "xgboost": list(map(version.parse, ["1.3.3", "1.5.1", "1.7.3"])), + }, ) query = textwrap.dedent( - """ + f""" SELECT PACKAGE_NAME, VERSION FROM information_schema.packages WHERE (package_name = 'pytorch') - AND language = 'python'; + AND language = 'python' + AND (runtime_version = '{platform.python_version_tuple()[0]}.{platform.python_version_tuple()[1]}' + OR runtime_version is null); """ ) sql_result = [ @@ -400,34 +406,42 @@ def test_validate_requirements_in_information_schema(self) -> None: m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) 
c_session = cast(session.Session, m_session) - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, ), - sorted(["xgboost", "pytorch"]), + { + "xgboost": list(map(version.parse, ["1.3.3", "1.5.1", "1.7.3"])), + "pytorch": list(map(version.parse, ["1.12.1"])), + }, ) # Test cache - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, ), - sorted(["xgboost", "pytorch"]), + { + "xgboost": list(map(version.parse, ["1.3.3", "1.5.1", "1.7.3"])), + "pytorch": list(map(version.parse, ["1.12.1"])), + }, ) # clear cache env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} query = textwrap.dedent( - """ + f""" SELECT PACKAGE_NAME, VERSION FROM information_schema.packages WHERE (package_name = 'xgboost') - AND language = 'python'; + AND language = 'python' + AND (runtime_version = '{platform.python_version_tuple()[0]}.{platform.python_version_tuple()[1]}' + OR runtime_version is null); """ ) sql_result = [ @@ -439,40 +453,49 @@ def test_validate_requirements_in_information_schema(self) -> None: m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.3")], python_version=snowml_env.PYTHON_VERSION, ), - ["xgboost==1.7.3"], + { + "xgboost": list(map(version.parse, ["1.7.3"])), + }, ) # Test cache - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.3")], python_version=snowml_env.PYTHON_VERSION, ), - ["xgboost==1.7.3"], + { + "xgboost": list(map(version.parse, ["1.7.3"])), + }, ) - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost>=1.7,<1.8")], python_version=snowml_env.PYTHON_VERSION, ), - ["xgboost<1.8,>=1.7"], + { + "xgboost": list(map(version.parse, ["1.7.0", "1.7.1", "1.7.3"])), + }, ) - self.assertIsNone( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.1, ==1.7.3")], python_version=snowml_env.PYTHON_VERSION, - ) + ), + { + "xgboost": list(map(version.parse, [])), + }, ) # clear cache @@ -481,23 +504,27 @@ def test_validate_requirements_in_information_schema(self) -> None: m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, 
reqs=[requirements.Requirement("xgboost==1.7.*")], python_version=snowml_env.PYTHON_VERSION, ), - ["xgboost==1.7.*"], + { + "xgboost": list(map(version.parse, ["1.7.0", "1.7.1", "1.7.3"])), + }, ) # Test cache - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.7.*")], python_version=snowml_env.PYTHON_VERSION, ), - ["xgboost==1.7.*"], + { + "xgboost": list(map(version.parse, ["1.7.0", "1.7.1", "1.7.3"])), + }, ) # clear cache @@ -506,98 +533,55 @@ def test_validate_requirements_in_information_schema(self) -> None: m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) - self.assertIsNone( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.3.*")], python_version=snowml_env.PYTHON_VERSION, - ) + ), + { + "xgboost": list(map(version.parse, [])), + }, ) # Test cache - self.assertIsNone( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, reqs=[requirements.Requirement("xgboost==1.3.*")], python_version=snowml_env.PYTHON_VERSION, - ) + ), + { + "xgboost": list(map(version.parse, [])), + }, ) # clear cache env_utils._SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE = {} - query = textwrap.dedent( - """ - SELECT PACKAGE_NAME, VERSION - FROM information_schema.packages - WHERE (package_name = 'python-package') - AND language = 'python'; - """ - ) - sql_result = [row.Row()] - - m_session = mock_session.MockSession(conn=None, test_case=self) - m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) - c_session = cast(session.Session, m_session) - - self.assertIsNone( - env_utils.validate_requirements_in_information_schema( - session=c_session, - reqs=[requirements.Requirement("python-package")], - python_version=snowml_env.PYTHON_VERSION, - ) - ) - - env_utils._INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION = None - m_session = mock_session.MockSession(conn=None, test_case=self) - m_session.add_mock_sql( - query=textwrap.dedent( - """ - SHOW COLUMNS - LIKE 'runtime_version' - IN TABLE information_schema.packages; - """ - ), - result=mock_data_frame.MockDataFrame(count_result=1), - ) - query = textwrap.dedent( f""" SELECT PACKAGE_NAME, VERSION FROM information_schema.packages - WHERE (package_name = 'pytorch' OR package_name = 'xgboost') + WHERE (package_name = 'python-package') AND language = 'python' AND (runtime_version = '{platform.python_version_tuple()[0]}.{platform.python_version_tuple()[1]}' OR runtime_version is null); """ ) - sql_result = [ - row.Row(PACKAGE_NAME="xgboost", VERSION="1.3.3"), - row.Row(PACKAGE_NAME="xgboost", VERSION="1.5.1"), - row.Row(PACKAGE_NAME="xgboost", VERSION="1.7.3"), - row.Row(PACKAGE_NAME="pytorch", VERSION="1.12.1"), - ] + sql_result = [row.Row()] + m_session = mock_session.MockSession(conn=None, test_case=self) m_session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) c_session = cast(session.Session, m_session) - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( - session=c_session, - reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], - 
python_version=snowml_env.PYTHON_VERSION, - ), - sorted(["xgboost", "pytorch"]), - ) - - # Test cache - self.assertListEqual( - env_utils.validate_requirements_in_information_schema( + self.assertDictEqual( + env_utils.get_matched_package_versions_in_information_schema( session=c_session, - reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], + reqs=[requirements.Requirement("python-package")], python_version=snowml_env.PYTHON_VERSION, ), - sorted(["xgboost", "pytorch"]), + {}, ) def test_parse_python_version_string(self) -> None: diff --git a/snowflake/ml/_internal/file_utils.py b/snowflake/ml/_internal/file_utils.py index fb7387e9..7c05065b 100644 --- a/snowflake/ml/_internal/file_utils.py +++ b/snowflake/ml/_internal/file_utils.py @@ -362,3 +362,20 @@ def download_directory_from_stage( wait_exponential_multiplier=100, wait_exponential_max=10000, )(file_operation.get)(str(stage_file_path), str(local_file_dir), statement_params=statement_params) + + +def open_file(path: str, *args: Any, **kwargs: Any) -> Any: + """This function is a wrapper on top of the Python built-in "open" function, with a few added default values + to ensure successful execution across different platforms. + + Args: + path: file path + *args: arguments. + **kwargs: key arguments. + + Returns: + Open file and return a stream. + """ + kwargs.setdefault("newline", "\n") + kwargs.setdefault("encoding", "utf-8") + return open(path, *args, **kwargs) diff --git a/snowflake/ml/_internal/utils/BUILD.bazel b/snowflake/ml/_internal/utils/BUILD.bazel index 6aca28de..423ee9ff 100644 --- a/snowflake/ml/_internal/utils/BUILD.bazel +++ b/snowflake/ml/_internal/utils/BUILD.bazel @@ -151,12 +151,6 @@ py_library( srcs = ["result.py"], ) -py_library( - name = "spcs_image_registry", - srcs = ["spcs_image_registry.py"], - deps = [":query_result_checker"], -) - py_library( name = "table_manager", srcs = [ @@ -218,38 +212,37 @@ py_library( ) py_library( - name = "image_registry_http_client", - srcs = ["image_registry_http_client.py"], + name = "spcs_attribution_utils", + srcs = ["spcs_attribution_utils.py"], deps = [ - ":session_token_manager", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:retryable_http", + ":query_result_checker", + "//snowflake/ml/_internal:telemetry", ], ) py_test( - name = "image_registry_http_client_test", - srcs = ["image_registry_http_client_test.py"], + name = "spcs_attribution_utils_test", + srcs = ["spcs_attribution_utils_test.py"], deps = [ - ":image_registry_http_client", + ":spcs_attribution_utils", + "//snowflake/ml/test_utils:mock_data_frame", "//snowflake/ml/test_utils:mock_session", ], ) py_library( - name = "spcs_attribution_utils", - srcs = ["spcs_attribution_utils.py"], + name = "snowflake_env", + srcs = ["snowflake_env.py"], deps = [ ":query_result_checker", - "//snowflake/ml/_internal:telemetry", ], ) py_test( - name = "spcs_attribution_utils_test", - srcs = ["spcs_attribution_utils_test.py"], + name = "snowflake_env_test", + srcs = ["snowflake_env_test.py"], deps = [ - ":spcs_attribution_utils", + ":snowflake_env", "//snowflake/ml/test_utils:mock_data_frame", "//snowflake/ml/test_utils:mock_session", ], diff --git a/snowflake/ml/_internal/utils/query_result_checker.py b/snowflake/ml/_internal/utils/query_result_checker.py index f5244ecf..a51068e2 100644 --- a/snowflake/ml/_internal/utils/query_result_checker.py +++ b/snowflake/ml/_internal/utils/query_result_checker.py @@ -60,9 +60,13 @@ def result_dimension_matcher( return True -def 
column_name_matcher(expected_col_name: str, result: list[snowpark.Row], sql: str | None = None) -> bool: +def column_name_matcher( + expected_col_name: str, allow_empty: bool, result: list[snowpark.Row], sql: str | None = None +) -> bool: """Returns true if `expected_col_name` is found. Raise exception otherwise.""" if not result: + if allow_empty: + return True raise connector.DataError(f"Query Result is empty.{_query_log(sql)}") if expected_col_name not in result[0]: raise connector.DataError( @@ -159,16 +163,17 @@ def has_dimensions(self, expected_rows: int | None = None, expected_cols: int | self._success_matchers.append(partial(result_dimension_matcher, expected_rows, expected_cols)) return self - def has_column(self, expected_col_name: str) -> ResultValidator: + def has_column(self, expected_col_name: str, allow_empty: bool = False) -> ResultValidator: """Validate that the a column with the name `expected_column_name` exists in the result. Args: expected_col_name: Name of the column that is expected to be present in the result (case sensitive). + allow_empty: If the check will fail if the result is empty. Returns: ResultValidator object (self) """ - self._success_matchers.append(partial(column_name_matcher, expected_col_name)) + self._success_matchers.append(partial(column_name_matcher, expected_col_name, allow_empty)) return self def has_named_value_match(self, row_idx: int, col_name: str, expected_value: Any) -> ResultValidator: @@ -224,8 +229,6 @@ def validate(self) -> list[snowpark.Row]: Returns: Query result. """ - if len(self._success_matchers) == 0: - self._success_matchers = _DEFAULT_MATCHERS result = self._get_result() for matcher in self._success_matchers: assert matcher(result, self._query) diff --git a/snowflake/ml/_internal/utils/query_result_checker_test.py b/snowflake/ml/_internal/utils/query_result_checker_test.py index caca3989..d3319038 100644 --- a/snowflake/ml/_internal/utils/query_result_checker_test.py +++ b/snowflake/ml/_internal/utils/query_result_checker_test.py @@ -25,9 +25,10 @@ def test_column_name_matcher(self) -> None: """Test column_name_matcher().""" row1 = Row(name1=1, name2=2) row2 = Row(name1=3, name2=4) - self.assertTrue(query_result_checker.column_name_matcher("name1", [row1, row2])) - self.assertRaises(DataError, query_result_checker.column_name_matcher, "name1", []) - self.assertRaises(DataError, query_result_checker.column_name_matcher, "name3", [row1, row2]) + self.assertTrue(query_result_checker.column_name_matcher("name1", False, [row1, row2])) + self.assertTrue(query_result_checker.column_name_matcher("name1", True, [])) + self.assertRaises(DataError, query_result_checker.column_name_matcher, "name1", False, []) + self.assertRaises(DataError, query_result_checker.column_name_matcher, "name3", False, [row1, row2]) def test_result_validator_dimensions_partial_ok(self) -> None: """Use the base ResultValidator to verify the dimensions of an operation result.""" @@ -112,6 +113,19 @@ def test_sql_result_validator_column_ok(self) -> None: ) self.assertEqual(actual_result, sql_result) + def test_sql_result_validator_column_empty(self) -> None: + """Use SqlResultValidator to check that a specific column exists in the result.""" + session = mock_session.MockSession(conn=None, test_case=self) + query = "CREATE TABLE TEMP" + sql_result: List[Row] = [] + session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) + actual_result = ( + query_result_checker.SqlResultValidator(session=cast(snowpark.Session, session), query=query) + 
.has_column(expected_col_name="status", allow_empty=True) + .validate() + ) + self.assertEqual(actual_result, sql_result) + def test_sql_result_validator_column_fail(self) -> None: """Use SqlResultValidator to check that a specific column exists in the result but we the column is missing.""" session = mock_session.MockSession(conn=None, test_case=self) diff --git a/snowflake/ml/_internal/utils/snowflake_env.py b/snowflake/ml/_internal/utils/snowflake_env.py new file mode 100644 index 00000000..1dc41abe --- /dev/null +++ b/snowflake/ml/_internal/utils/snowflake_env.py @@ -0,0 +1,95 @@ +import enum +from typing import Any, Dict, Optional, TypedDict, cast + +from packaging import version +from typing_extensions import Required + +from snowflake.ml._internal.utils import query_result_checker +from snowflake.snowpark import session + + +def get_current_snowflake_version( + sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None +) -> version.Version: + """Get Snowflake Version as a version.Version object follow PEP way of versioning, that is to say: + "7.44.2 b202312132139364eb71238" to + + Args: + sess: Snowpark Session. + statement_params: Statement params. Defaults to None. + + Returns: + The version of Snowflake Version. + """ + res = ( + query_result_checker.SqlResultValidator( + sess, "SELECT CURRENT_VERSION() AS CURRENT_VERSION", statement_params=statement_params + ) + .has_dimensions(expected_rows=1, expected_cols=1) + .validate()[0] + ) + + version_str = res.CURRENT_VERSION + assert isinstance(version_str, str) + + version_str = "+".join(version_str.split()) + return version.parse(version_str) + + +class SnowflakeCloudType(enum.Enum): + AWS = "aws" + AZURE = "azure" + GCP = "gcp" + + @classmethod + def from_value(cls, value: str) -> "SnowflakeCloudType": + assert value + for k in cls: + if k.value == value.lower(): + return k + else: + raise ValueError(f"'{cls.__name__}' enum not found for '{value}'") + + +class SnowflakeRegion(TypedDict): + region_group: Required[str] + snowflake_region: Required[str] + cloud: Required[SnowflakeCloudType] + region: Required[str] + display_name: Required[str] + + +def get_regions( + sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None +) -> Dict[str, SnowflakeRegion]: + res = ( + query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params) + .has_column("region_group") + .has_column("snowflake_region") + .has_column("cloud") + .has_column("region") + .has_column("display_name") + .validate() + ) + return { + f"{r.region_group}.{r.snowflake_region}": SnowflakeRegion( + region_group=r.region_group, + snowflake_region=r.snowflake_region, + cloud=SnowflakeCloudType.from_value(r.cloud), + region=r.region, + display_name=r.display_name, + ) + for r in res + } + + +def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str: + res = ( + query_result_checker.SqlResultValidator( + sess, "SELECT CURRENT_REGION() AS CURRENT_REGION", statement_params=statement_params + ) + .has_dimensions(expected_rows=1, expected_cols=1) + .validate()[0] + ) + + return cast(str, res.CURRENT_REGION) diff --git a/snowflake/ml/_internal/utils/snowflake_env_test.py b/snowflake/ml/_internal/utils/snowflake_env_test.py new file mode 100644 index 00000000..850cc21f --- /dev/null +++ b/snowflake/ml/_internal/utils/snowflake_env_test.py @@ -0,0 +1,93 @@ +from typing import cast + +from absl.testing import absltest +from packaging import version + +from 
snowflake.ml._internal.utils import snowflake_env +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Row, Session + + +class SnowflakeEnvTest(absltest.TestCase): + def test_current_snowflake_version_1(self) -> None: + session = mock_session.MockSession(conn=None, test_case=self) + query = "SELECT CURRENT_VERSION() AS CURRENT_VERSION" + sql_result = [Row(CURRENT_VERSION="8.0.0")] + session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) + actual_result = snowflake_env.get_current_snowflake_version(cast(Session, session)) + self.assertEqual(actual_result, version.parse("8.0.0")) + + def test_current_snowflake_version_2(self) -> None: + session = mock_session.MockSession(conn=None, test_case=self) + query = "SELECT CURRENT_VERSION() AS CURRENT_VERSION" + sql_result = [Row(CURRENT_VERSION="8.0.0 1234567890ab")] + session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) + actual_result = snowflake_env.get_current_snowflake_version(cast(Session, session)) + self.assertEqual(actual_result, version.parse("8.0.0+1234567890ab")) + + def test_get_regions(self) -> None: + session = mock_session.MockSession(conn=None, test_case=self) + query = "SHOW REGIONS" + sql_result = [ + Row( + region_group="PUBLIC", + snowflake_region="AWS_US_WEST_2", + cloud="aws", + region="us-west-2", + display_name="US West (Oregon)", + ), + Row( + region_group="PUBLIC", + snowflake_region="AZURE_EASTUS2", + cloud="azure", + region="eastus2", + display_name="East US 2 (Virginia)", + ), + Row( + region_group="PUBLIC", + snowflake_region="GCP_EUROPE_WEST2", + cloud="gcp", + region="europe-west2", + display_name="Europe West 2 (London)", + ), + ] + session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) + actual_result = snowflake_env.get_regions(cast(Session, session)) + self.assertDictEqual( + { + "PUBLIC.AWS_US_WEST_2": snowflake_env.SnowflakeRegion( + region_group="PUBLIC", + snowflake_region="AWS_US_WEST_2", + cloud=snowflake_env.SnowflakeCloudType.AWS, + region="us-west-2", + display_name="US West (Oregon)", + ), + "PUBLIC.AZURE_EASTUS2": snowflake_env.SnowflakeRegion( + region_group="PUBLIC", + snowflake_region="AZURE_EASTUS2", + cloud=snowflake_env.SnowflakeCloudType.AZURE, + region="eastus2", + display_name="East US 2 (Virginia)", + ), + "PUBLIC.GCP_EUROPE_WEST2": snowflake_env.SnowflakeRegion( + region_group="PUBLIC", + snowflake_region="GCP_EUROPE_WEST2", + cloud=snowflake_env.SnowflakeCloudType.GCP, + region="europe-west2", + display_name="Europe West 2 (London)", + ), + }, + actual_result, + ) + + def test_get_current_region_id(self) -> None: + session = mock_session.MockSession(conn=None, test_case=self) + query = "SELECT CURRENT_REGION() AS CURRENT_REGION" + sql_result = [Row(CURRENT_REGION="PUBLIC.AWS_US_WEST_2")] + session.add_mock_sql(query=query, result=mock_data_frame.MockDataFrame(sql_result)) + actual_result = snowflake_env.get_current_region_id(cast(Session, session)) + self.assertEqual(actual_result, "PUBLIC.AWS_US_WEST_2") + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/feature_store/BUILD.bazel b/snowflake/ml/feature_store/BUILD.bazel index 88cf42bd..bbd441d6 100644 --- a/snowflake/ml/feature_store/BUILD.bazel +++ b/snowflake/ml/feature_store/BUILD.bazel @@ -7,7 +7,10 @@ package_group( ], ) -package(default_visibility = [":feature_store"]) +package(default_visibility = [ + ":feature_store", + "//bazel:snowml_public_common", +]) py_library( name = 
"init", diff --git a/snowflake/ml/feature_store/_internal/BUILD.bazel b/snowflake/ml/feature_store/_internal/BUILD.bazel index 41437b73..cfd68961 100644 --- a/snowflake/ml/feature_store/_internal/BUILD.bazel +++ b/snowflake/ml/feature_store/_internal/BUILD.bazel @@ -1,6 +1,9 @@ load("//bazel:py_rules.bzl", "py_library") -package(default_visibility = ["//snowflake/ml/feature_store"]) +package(default_visibility = [ + "//bazel:snowml_public_common", + "//snowflake/ml/feature_store", +]) py_library( name = "synthetic_data_generator", diff --git a/snowflake/ml/feature_store/_internal/scripts/BUILD.bazel b/snowflake/ml/feature_store/_internal/scripts/BUILD.bazel index 26a39beb..2ccfc0d7 100644 --- a/snowflake/ml/feature_store/_internal/scripts/BUILD.bazel +++ b/snowflake/ml/feature_store/_internal/scripts/BUILD.bazel @@ -1,5 +1,9 @@ load("//bazel:py_rules.bzl", "py_binary") +package(default_visibility = [ + "//bazel:snowml_public_common", +]) + py_binary( name = "run_synthetic_data_generator", srcs = [ diff --git a/snowflake/ml/feature_store/_internal/scripts/upload_test_datasets.py b/snowflake/ml/feature_store/_internal/scripts/upload_test_datasets.py index 4f815d96..cf4d4ddb 100644 --- a/snowflake/ml/feature_store/_internal/scripts/upload_test_datasets.py +++ b/snowflake/ml/feature_store/_internal/scripts/upload_test_datasets.py @@ -1,5 +1,11 @@ -# A helper script cleans open taxi data (https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page) -# and store into snowflake database. +""" +A helper script cleans open taxi data (https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page) +and store into snowflake database. + +Download yellow trip data(2016 Jan): https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page. +Download wine data: +https://www.google.com/url?q=https://github.com/snowflakedb/snowml/blob/main/snowflake/ml/feature_store/notebooks/customer_demo/winequality-red.csv&sa=D&source=docs&ust=1702084016573738&usg=AOvVaw3r_muH0_LKBDr45C1Gj3cb +""" from absl.logging import logging @@ -56,7 +62,7 @@ def create_winedata(sess: Session, overwrite_mode: str) -> None: full_table_name = f"{FS_INTEG_TEST_DB}.{FS_INTEG_TEST_DATASET_SCHEMA}.{FS_INTEG_TEST_WINE_QUALITY_DATA}" df = ( - sess.read.options({"field_delimiter": ",", "skip_header": 1}) + sess.read.options({"field_delimiter": ";", "skip_header": 1}) .schema(input_schema) .csv(f"{sess.get_session_stage()}/{WINEDATA_NAME}") ) diff --git a/snowflake/ml/feature_store/entity.py b/snowflake/ml/feature_store/entity.py index 44b0ece0..d88546a6 100644 --- a/snowflake/ml/feature_store/entity.py +++ b/snowflake/ml/feature_store/entity.py @@ -10,6 +10,9 @@ _ENTITY_JOIN_KEY_DELIMITER = "," # join key length limit is the length limit of TAG value _ENTITY_JOIN_KEY_LENGTH_LIMIT = 256 +# The maximum number of join keys: +# https://docs.snowflake.com/en/user-guide/object-tagging#specify-tag-values +_ENTITY_MAX_NUM_JOIN_KEYS = 300 class Entity: @@ -39,14 +42,18 @@ def _validate(self, name: str, join_keys: List[str]) -> None: raise ValueError(f"Entity name `{name}` exceeds maximum length: {_ENTITY_NAME_LENGTH_LIMIT}") if _FEATURE_VIEW_ENTITY_TAG_DELIMITER in name: raise ValueError(f"Entity name contains invalid char: `{_FEATURE_VIEW_ENTITY_TAG_DELIMITER}`") + if len(join_keys) > _ENTITY_MAX_NUM_JOIN_KEYS: + raise ValueError( + f"Maximum number of join keys are {_ENTITY_MAX_NUM_JOIN_KEYS}, " "but {len(join_keys)} is provided." 
+ ) if len(set(join_keys)) != len(join_keys): raise ValueError(f"Duplicate join keys detected in: {join_keys}") - if len(_FEATURE_VIEW_ENTITY_TAG_DELIMITER.join(join_keys)) > _ENTITY_JOIN_KEY_LENGTH_LIMIT: - raise ValueError(f"Total length of join keys exceeded maximum length: {_ENTITY_JOIN_KEY_LENGTH_LIMIT}") - for k in join_keys: + # TODO(wezhou) move this logic into SqlIdentifier. if _ENTITY_JOIN_KEY_DELIMITER in k: raise ValueError(f"Invalid char `{_ENTITY_JOIN_KEY_DELIMITER}` detected in join key {k}") + if len(k) > _ENTITY_JOIN_KEY_LENGTH_LIMIT: + raise ValueError(f"Join key: {k} exceeds length limit {_ENTITY_JOIN_KEY_LENGTH_LIMIT}.") def _to_dict(self) -> Dict[str, str]: entity_dict = self.__dict__.copy() diff --git a/snowflake/ml/feature_store/feature_store.py b/snowflake/ml/feature_store/feature_store.py index 24cb1b17..34376081 100644 --- a/snowflake/ml/feature_store/feature_store.py +++ b/snowflake/ml/feature_store/feature_store.py @@ -12,20 +12,18 @@ from pytimeparse.timeparse import timeparse -from snowflake import connector from snowflake.ml._internal import telemetry from snowflake.ml._internal.exceptions import ( error_codes, exceptions as snowml_exceptions, ) -from snowflake.ml._internal.utils import identifier, query_result_checker as qrc +from snowflake.ml._internal.utils import identifier from snowflake.ml._internal.utils.sql_identifier import ( SqlIdentifier, to_sql_identifiers, ) from snowflake.ml.dataset.dataset import Dataset, FeatureStoreMetadata from snowflake.ml.feature_store.entity import ( - _ENTITY_JOIN_KEY_DELIMITER, _ENTITY_NAME_LENGTH_LIMIT, _FEATURE_VIEW_ENTITY_TAG_DELIMITER, Entity, @@ -240,13 +238,15 @@ def register_entity(self, entity: Entity) -> None: suppress_source_trace=True, ) - join_keys_str = _ENTITY_JOIN_KEY_DELIMITER.join(entity.join_keys) + # allowed_values will add double-quotes around each value, thus use resolved str here. + join_keys = [f"'{key.resolved()}'" for key in entity.join_keys] + join_keys_str = ",".join(join_keys) full_tag_name = self._get_fully_qualified_name(tag_name) - self._session.sql(f"CREATE TAG IF NOT EXISTS {full_tag_name} COMMENT = '{entity.desc}'").collect( - statement_params=self._telemetry_stmp - ) self._session.sql( - f"ALTER SCHEMA {self._config.full_schema_path} SET TAG {full_tag_name} = '{join_keys_str}'" + f"""CREATE TAG IF NOT EXISTS {full_tag_name} + ALLOWED_VALUES {join_keys_str} + COMMENT = '{entity.desc}' + """ ).collect(statement_params=self._telemetry_stmp) logger.info(f"Registered Entity {entity}.") @@ -681,30 +681,14 @@ def list_entities(self) -> DataFrame: Snowpark DataFrame containing the results. 
""" prefix_len = len(_ENTITY_TAG_PREFIX) + 1 - tag_values_df = self._session.sql( - f""" - SELECT SUBSTR(TAG_NAME,{prefix_len},{_ENTITY_NAME_LENGTH_LIMIT}) AS NAME, - TAG_VALUE AS JOIN_KEYS - FROM TABLE( - {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES( - '{self._config.full_schema_path}', - 'SCHEMA' - ) - ) - WHERE TAG_NAME LIKE '{_ENTITY_TAG_PREFIX}%' - """ - ) - tag_metadata_df = self._session.sql( - f"SHOW TAGS LIKE '{_ENTITY_TAG_PREFIX}%' IN SCHEMA {self._config.full_schema_path}" - ) return cast( DataFrame, - tag_values_df.join( - right=tag_metadata_df.with_column("NAME", F.substr('"name"', prefix_len, _ENTITY_NAME_LENGTH_LIMIT)) - .with_column_renamed('"comment"', "DESC") - .select("NAME", "DESC"), - on=["NAME"], - how="left", + self._session.sql( + f"SHOW TAGS LIKE '{_ENTITY_TAG_PREFIX}%' IN SCHEMA {self._config.full_schema_path}" + ).select( + F.col('"name"').substr(prefix_len, _ENTITY_NAME_LENGTH_LIMIT).alias("NAME"), + F.col('"allowed_values"').alias("JOIN_KEYS"), + F.col('"comment"').alias("DESC"), ), ) @@ -725,54 +709,25 @@ def get_entity(self, name: str) -> Entity: SnowflakeMLException: [RuntimeError] Failed to find resources. """ name = SqlIdentifier(name) - - full_entity_tag_name = self._get_entity_name(name) - prefix_len = len(_ENTITY_TAG_PREFIX) + 1 - - found_tags = self._find_object("TAGS", full_entity_tag_name) - if len(found_tags) == 0: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.NOT_FOUND, - original_exception=ValueError(f"Cannot find Entity with name {name}."), - ) - try: - physical_name = self._get_entity_name(name) - tag_values = ( - qrc.SqlResultValidator( - self._session, - f""" - SELECT SUBSTR(TAG_NAME,{prefix_len},{_ENTITY_NAME_LENGTH_LIMIT}) AS NAME, - TAG_VALUE AS JOIN_KEYS - FROM TABLE( - {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES( - '{self._config.full_schema_path}', - 'SCHEMA' - ) - ) - WHERE TAG_NAME LIKE '{physical_name.resolved()}' - AND TAG_DATABASE = '{self._config.database.resolved()}' - """, - self._telemetry_stmp, - ) - .has_dimensions(expected_rows=1) - .validate() - ) - except connector.DataError as e: # raised by SqlResultValidator - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.NOT_FOUND, - original_exception=ValueError(f"Cannot find Entity with name {name}."), - ) from e + result = self.list_entities().filter(F.col("NAME") == name.resolved()).collect() except Exception as e: raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INTERNAL_SNOWPARK_ERROR, - original_exception=RuntimeError(f"Failed to retrieve tag reference information: {e}"), + original_exception=RuntimeError(f"Failed to list entities: {e}"), ) from e + if len(result) == 0: + raise snowml_exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=ValueError(f"Cannot find Entity with name: {name}."), + ) + raw_join_keys = result[0]["JOIN_KEYS"] + join_keys = raw_join_keys.strip("[]").split(",") return Entity( - name=tag_values[0]["NAME"], - join_keys=tag_values[0]["JOIN_KEYS"].split(_ENTITY_JOIN_KEY_DELIMITER), - desc=found_tags[0]["comment"], + name=result[0]["NAME"], + join_keys=join_keys, + desc=result[0]["DESC"], ) @dispatch_decorator(prpr_version="1.0.8") @@ -807,9 +762,6 @@ def delete_entity(self, name: str) -> None: tag_name = self._get_fully_qualified_name(self._get_entity_name(name)) try: - self._session.sql(f"ALTER SCHEMA {self._config.full_schema_path} UNSET TAG {tag_name}").collect( - statement_params=self._telemetry_stmp - ) 
self._session.sql(f"DROP TAG IF EXISTS {tag_name}").collect(statement_params=self._telemetry_stmp) except Exception as e: raise snowml_exceptions.SnowflakeMLException( diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb index 6098577e..8c4aec7b 100644 --- a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb @@ -5,9 +5,9 @@ "id": "0bb54abc", "metadata": {}, "source": [ - "- snowflake-ml-python version: 1.1.0\n", - "- Feature Store PrPr Version: 0.3.1\n", - "- Updated date: 12/11/2023" + "- snowflake-ml-python version: 1.2.0\n", + "- Feature Store PrPr Version: 0.4.0\n", + "- Updated date: 1/3/2024" ] }, { diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf index c101abb3..af5b0a63 100644 Binary files a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf and b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf differ diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb index 19b163c8..009d2e10 100644 --- a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb @@ -5,9 +5,9 @@ "id": "4f029c96", "metadata": {}, "source": [ - "- snowflake-ml-python version: 1.1.0\n", - "- Feature Store PrPr version: 0.3.2\n", - "- Updated date: 12/11/2023" + "- snowflake-ml-python version: 1.2.0\n", + "- Feature Store PrPr version: 0.4.0\n", + "- Updated date: 1/3/2024" ] }, { diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf index 8abb91cf..0a0cb99b 100644 Binary files a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf and b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf differ diff --git a/snowflake/ml/feature_store/tests/BUILD.bazel b/snowflake/ml/feature_store/tests/BUILD.bazel index aa0f3eba..c1c43f45 100644 --- a/snowflake/ml/feature_store/tests/BUILD.bazel +++ b/snowflake/ml/feature_store/tests/BUILD.bazel @@ -1,6 +1,9 @@ load("//bazel:py_rules.bzl", "py_library", "py_test") -package(default_visibility = ["//snowflake/ml/feature_store"]) +package(default_visibility = [ + "//bazel:snowml_public_common", + "//snowflake/ml/feature_store", +]) py_library( name = "common_utils", diff --git a/snowflake/ml/feature_store/tests/feature_store_object_test.py b/snowflake/ml/feature_store/tests/feature_store_object_test.py index 425f3c54..63b617be 100644 --- a/snowflake/ml/feature_store/tests/feature_store_object_test.py +++ b/snowflake/ml/feature_store/tests/feature_store_object_test.py @@ -126,10 +126,10 @@ def test_invalid_entity_name(self) -> None: Entity(name="my_entity", join_keys=["foo", "foo"]) def test_join_keys_exceed_limit(self) -> None: - with self.assertRaisesRegex(ValueError, "Total length of join keys exceeded maximum length.*"): + with self.assertRaisesRegex(ValueError, "Join key: .* exceeds length limit 256."): Entity(name="foo", join_keys=["f" * 257]) - with self.assertRaisesRegex(ValueError, "Total length 
of join keys exceeded maximum length.*"): - Entity(name="foo", join_keys=["foo" * 50] + ["bar" * 50]) + with self.assertRaisesRegex(ValueError, "Maximum number of join keys are 300, but .* is provided."): + Entity(name="foo", join_keys=["foo"] * 301) def test_equality_check(self) -> None: self.assertTrue(Entity(name="foo", join_keys=["a"]) == Entity(name="foo", join_keys=["a"])) diff --git a/snowflake/ml/feature_store/tests/feature_store_test.py b/snowflake/ml/feature_store/tests/feature_store_test.py index 63859b01..65246002 100644 --- a/snowflake/ml/feature_store/tests/feature_store_test.py +++ b/snowflake/ml/feature_store/tests/feature_store_test.py @@ -180,7 +180,7 @@ def test_create_and_delete_entities(self) -> None: fs = self._create_feature_store() entities = { - "User": Entity("USER", ["uid"]), + "User": Entity("USER", ['"uid"']), "Ad": Entity('"aD"', ["aid"]), "Product": Entity("Product", ["pid", "cid"]), } @@ -193,7 +193,7 @@ def test_create_and_delete_entities(self) -> None: actual_df=fs.list_entities().to_pandas(), target_data={ "NAME": ["aD", "PRODUCT", "USER"], - "JOIN_KEYS": ["AID", "PID,CID", "UID"], + "JOIN_KEYS": ['["AID"]', '["CID","PID"]', '["uid"]'], "DESC": ["", "", ""], }, sort_cols=["NAME"], @@ -214,7 +214,7 @@ def test_create_and_delete_entities(self) -> None: actual_df=fs.list_entities().to_pandas(), target_data={ "NAME": ["PRODUCT", "USER"], - "JOIN_KEYS": ["PID,CID", "UID"], + "JOIN_KEYS": ['["CID","PID"]', '["uid"]'], "DESC": ["", ""], }, sort_cols=["NAME"], @@ -229,7 +229,7 @@ def test_create_and_delete_entities(self) -> None: # test delete entity failure with active feature views # create a new feature view - sql = f"SELECT name, id AS uid FROM {self._mock_table}" + sql = f'SELECT name, id AS "uid" FROM {self._mock_table}' fv = FeatureView(name="fv", entities=[entities["User"]], feature_df=self._session.sql(sql), refresh_freq="1m") fs.register_feature_view(feature_view=fv, version="FIRST") with self.assertRaisesRegex(ValueError, "Cannot delete Entity .* due to active FeatureViews.*"): @@ -251,7 +251,7 @@ def test_retrieve_entity(self) -> None: actual_df=fs.list_entities().to_pandas(), target_data={ "NAME": ["FOO", "BAR"], - "JOIN_KEYS": ["A,B", "C"], + "JOIN_KEYS": ['["A","B"]', '["C"]'], "DESC": ["my foo", ""], }, sort_cols=["NAME"], @@ -264,7 +264,7 @@ def test_get_entity_system_error(self) -> None: snowpark_exceptions.SnowparkClientException("Intentional Integ Test Error"), ) - with self.assertRaisesRegex(RuntimeError, "Failed to find object .*"): + with self.assertRaisesRegex(RuntimeError, "Failed to list entities: .*"): fs.get_entity("foo") def test_register_entity_system_error(self) -> None: diff --git a/snowflake/ml/fileset/parquet_parser.py b/snowflake/ml/fileset/parquet_parser.py index b3f7d0ca..c851a159 100644 --- a/snowflake/ml/fileset/parquet_parser.py +++ b/snowflake/ml/fileset/parquet_parser.py @@ -1,4 +1,6 @@ import collections +import logging +import time from typing import Any, Deque, Dict, Iterator, List import fsspec @@ -83,7 +85,7 @@ def __iter__(self) -> Iterator[Dict[str, npt.NDArray[Any]]]: np.random.shuffle(files) pa_dataset: ds.Dataset = ds.dataset(files, format="parquet", filesystem=self._fs) - for rb in pa_dataset.to_batches(batch_size=self._dataset_batch_size): + for rb in _retryable_batches(pa_dataset, batch_size=self._dataset_batch_size): if self._shuffle: rb = rb.take(np.random.permutation(rb.num_rows)) self._rb_buffer.append(rb) @@ -138,3 +140,31 @@ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]: 
array = column.to_numpy(zero_copy_only=False) batch_dict[column_schema.name] = array return batch_dict + + +def _retryable_batches( + dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0 +) -> Iterator[pa.RecordBatch]: + """Make the Dataset to_batches retryable.""" + retries = 0 + current_batch_index = 0 + + while True: + try: + for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)): + if batch_index < current_batch_index: + # Skip batches that have already been processed + continue + + yield batch + current_batch_index = batch_index + 1 + # Exit the loop once all batches are processed + break + + except Exception as e: + if retries < max_retries: + retries += 1 + logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...") + time.sleep(delay) + else: + raise e diff --git a/snowflake/ml/model/BUILD.bazel b/snowflake/ml/model/BUILD.bazel index 1ce48b8f..5b0847cb 100644 --- a/snowflake/ml/model/BUILD.bazel +++ b/snowflake/ml/model/BUILD.bazel @@ -64,6 +64,30 @@ py_library( ], ) +py_library( + name = "model", + srcs = ["__init__.py"], + deps = [ + "//snowflake/ml/model/_client/model:model_impl", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/model/models:huggingface_pipeline", + "//snowflake/ml/model/models:llm_model", + ], +) + +py_test( + name = "package_visibility_test", + srcs = ["package_visibility_test.py"], + deps = [ + ":_api", + ":custom_model", + ":deploy_platforms", + ":model", + ":model_signature", + ":type_hints", + ], +) + py_test( name = "custom_model_test", srcs = ["custom_model_test.py"], diff --git a/snowflake/ml/model/__init__.py b/snowflake/ml/model/__init__.py new file mode 100644 index 00000000..bcebb67d --- /dev/null +++ b/snowflake/ml/model/__init__.py @@ -0,0 +1,6 @@ +from snowflake.ml.model._client.model.model_impl import Model +from snowflake.ml.model._client.model.model_version_impl import ModelVersion +from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel +from snowflake.ml.model.models.llm import LLM, LLMOptions + +__all__ = ["Model", "ModelVersion", "HuggingFacePipelineModel", "LLM", "LLMOptions"] diff --git a/snowflake/ml/model/_client/model/BUILD.bazel b/snowflake/ml/model/_client/model/BUILD.bazel index 1fc4f5a7..6fb55009 100644 --- a/snowflake/ml/model/_client/model/BUILD.bazel +++ b/snowflake/ml/model/_client/model/BUILD.bazel @@ -8,6 +8,7 @@ py_library( deps = [ ":model_version_impl", "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model/_client/ops:model_ops", ], @@ -28,11 +29,11 @@ py_library( name = "model_version_impl", srcs = ["model_version_impl.py"], deps = [ - ":model_method_info", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_client/ops:model_ops", + "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", ], ) @@ -45,15 +46,8 @@ py_test( "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_client/ops:metadata_ops", "//snowflake/ml/model/_client/ops:model_ops", + "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", "//snowflake/ml/test_utils:mock_data_frame", "//snowflake/ml/test_utils:mock_session", ], ) - -py_library( - name = "model_method_info", - srcs = ["model_method_info.py"], - deps = [ - "//snowflake/ml/model:model_signature", - ], -) diff --git 
a/snowflake/ml/model/_client/model/model_impl.py b/snowflake/ml/model/_client/model/model_impl.py index cf781d26..f1591305 100644 --- a/snowflake/ml/model/_client/model/model_impl.py +++ b/snowflake/ml/model/_client/model/model_impl.py @@ -1,7 +1,9 @@ -from typing import List, Union +from typing import Dict, List, Optional, Tuple, Union + +import pandas as pd from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml._internal.utils import identifier, sql_identifier from snowflake.ml.model._client.model import model_version_impl from snowflake.ml.model._client.ops import model_ops @@ -37,10 +39,12 @@ def __eq__(self, __value: object) -> bool: @property def name(self) -> str: + """Return the name of the model that can be used to refer to it in SQL.""" return self._model_name.identifier() @property def fully_qualified_name(self) -> str: + """Return the fully qualified name of the model that can be used to refer to it in SQL.""" return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name) @property @@ -49,6 +53,24 @@ def fully_qualified_name(self) -> str: subproject=_TELEMETRY_SUBPROJECT, ) def description(self) -> str: + """The description for the model. This is an alias of `comment`.""" + return self.comment + + @description.setter + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def description(self, description: str) -> None: + self.comment = description + + @property + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def comment(self) -> str: + """The comment to the model.""" statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, @@ -58,18 +80,18 @@ def description(self) -> str: statement_params=statement_params, ) - @description.setter + @comment.setter @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def description(self, description: str) -> None: + def comment(self, comment: str) -> None: statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) return self._model_ops.set_comment( - comment=description, + comment=comment, model_name=self._model_name, statement_params=statement_params, ) @@ -80,12 +102,13 @@ def description(self, description: str) -> None: subproject=_TELEMETRY_SUBPROJECT, ) def default(self) -> model_version_impl.ModelVersion: + """The default version of the model.""" statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, class_name=self.__class__.__name__, ) - default_version_name = self._model_ops._model_version_client.get_default_version( + default_version_name = self._model_ops.get_default_version( model_name=self._model_name, statement_params=statement_params ) return self.version(default_version_name) @@ -105,7 +128,7 @@ def default(self, version: Union[str, model_version_impl.ModelVersion]) -> None: version_name = sql_identifier.SqlIdentifier(version) else: version_name = version._version_name - self._model_ops._model_version_client.set_default_version( + self._model_ops.set_default_version( model_name=self._model_name, version_name=version_name, statement_params=statement_params ) @@ -114,13 +137,14 @@ def default(self, version: Union[str, model_version_impl.ModelVersion]) -> None: subproject=_TELEMETRY_SUBPROJECT, ) def 
version(self, version_name: str) -> model_version_impl.ModelVersion: - """Get a model version object given a version name in the model. + """ + Get a model version object given a version name in the model. Args: - version_name: The name of version + version_name: The name of the version. Raises: - ValueError: Raised when the version requested does not exist. + ValueError: When the requested version does not exist. Returns: The model version object. @@ -149,11 +173,11 @@ def version(self, version_name: str) -> model_version_impl.ModelVersion: project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def list_versions(self) -> List[model_version_impl.ModelVersion]: - """List all versions in the model. + def versions(self) -> List[model_version_impl.ModelVersion]: + """Get all versions in the model. Returns: - A List of ModelVersion object representing all versions in the model. + A list of ModelVersion objects representing all versions in the model. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, @@ -172,5 +196,140 @@ def list_versions(self) -> List[model_version_impl.ModelVersion]: for version_name in version_names ] + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def show_versions(self) -> pd.DataFrame: + """Show information about all versions in the model. + + Returns: + A Pandas DataFrame showing information about all versions in the model. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + rows = self._model_ops.show_models_or_versions( + model_name=self._model_name, + statement_params=statement_params, + ) + return pd.DataFrame([row.as_dict() for row in rows]) + def delete_version(self, version_name: str) -> None: raise NotImplementedError("Deleting version has not been supported yet.") + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def show_tags(self) -> Dict[str, str]: + """Get a dictionary of all the tags and their values attached to the model. + + Returns: + A dictionary mapping each tag name to its value. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + return self._model_ops.show_tags(model_name=self._model_name, statement_params=statement_params) + + def _parse_tag_name( + self, + tag_name: str, + ) -> Tuple[sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier]: + _tag_db, _tag_schema, _tag_name, _ = identifier.parse_schema_level_object_identifier(tag_name) + if _tag_db is None: + tag_db_id = self._model_ops._model_client._database_name + else: + tag_db_id = sql_identifier.SqlIdentifier(_tag_db) + + if _tag_schema is None: + tag_schema_id = self._model_ops._model_client._schema_name + else: + tag_schema_id = sql_identifier.SqlIdentifier(_tag_schema) + + if _tag_name is None: + raise ValueError(f"Unable to parse the tag name `{tag_name}` you input.") + + tag_name_id = sql_identifier.SqlIdentifier(_tag_name) + + return tag_db_id, tag_schema_id, tag_name_id + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def get_tag(self, tag_name: str) -> Optional[str]: + """Get the value of a tag attached to the model. + + Args: + tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of + the model will be used.
+ + Returns: + The tag value as a string if the tag is attached, otherwise None. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name) + return self._model_ops.get_tag_value( + model_name=self._model_name, + tag_database_name=tag_db_id, + tag_schema_name=tag_schema_id, + tag_name=tag_name_id, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def set_tag(self, tag_name: str, tag_value: str) -> None: + """Set the value of a tag, attaching it to the model if not. + + Args: + tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of + the model will be used. + tag_value: The value of the tag + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name) + self._model_ops.set_tag( + model_name=self._model_name, + tag_database_name=tag_db_id, + tag_schema_name=tag_schema_id, + tag_name=tag_name_id, + tag_value=tag_value, + statement_params=statement_params, + ) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def unset_tag(self, tag_name: str) -> None: + """Unset a tag attached to a model. + + Args: + tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of + the model will be used. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name) + self._model_ops.unset_tag( + model_name=self._model_name, + tag_database_name=tag_db_id, + tag_schema_name=tag_schema_id, + tag_name=tag_name_id, + statement_params=statement_params, + ) diff --git a/snowflake/ml/model/_client/model/model_impl_test.py b/snowflake/ml/model/_client/model/model_impl_test.py index 1fca7576..c097e09d 100644 --- a/snowflake/ml/model/_client/model/model_impl_test.py +++ b/snowflake/ml/model/_client/model/model_impl_test.py @@ -1,14 +1,14 @@ from typing import cast from unittest import mock +import pandas as pd from absl.testing import absltest from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model._client.model import model_impl, model_version_impl from snowflake.ml.model._client.ops import model_ops -from snowflake.ml.model._client.sql import model_version from snowflake.ml.test_utils import mock_session -from snowflake.snowpark import Session +from snowflake.snowpark import Row, Session class ModelImplTest(absltest.TestCase): @@ -57,7 +57,7 @@ def test_version_2(self) -> None: statement_params=mock.ANY, ) - def test_list_versions(self) -> None: + def test_versions(self) -> None: m_mv_1 = model_version_impl.ModelVersion._ref( self.m_model._model_ops, model_name=sql_identifier.SqlIdentifier("MODEL"), @@ -73,13 +73,42 @@ def test_list_versions(self) -> None: "list_models_or_versions", return_value=[sql_identifier.SqlIdentifier("V1"), sql_identifier.SqlIdentifier("v1", case_sensitive=True)], ) as mock_list_models_or_versions: - mv_list = self.m_model.list_versions() + mv_list = self.m_model.versions() self.assertListEqual(mv_list, [m_mv_1, m_mv_2]) mock_list_models_or_versions.assert_called_once_with( 
model_name=sql_identifier.SqlIdentifier("MODEL"), statement_params=mock.ANY, ) + def test_show_versions(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + Row( + create_on="06/01", + name="V1", + comment="This is a comment", + model_name="MODEL", + is_default_version=False, + ), + ] + with mock.patch.object( + self.m_model._model_ops, + "show_models_or_versions", + return_value=m_list_res, + ) as mock_show_models_or_versions: + mv_info = self.m_model.show_versions() + pd.testing.assert_frame_equal(mv_info, pd.DataFrame([row.as_dict() for row in m_list_res])) + mock_show_models_or_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + def test_description_getter(self) -> None: with mock.patch.object( self.m_model._model_ops, "get_comment", return_value="this is a comment" @@ -99,34 +128,175 @@ def test_description_setter(self) -> None: statement_params=mock.ANY, ) - def test_default_getter(self) -> None: - mock_model_ops = absltest.mock.MagicMock(spec=model_ops.ModelOperator) - mock_model_version_client = absltest.mock.MagicMock(spec=model_version.ModelVersionSQLClient) - self.m_model._model_ops = mock_model_ops - mock_model_ops._session = self.m_session - mock_model_ops._model_version_client = mock_model_version_client - mock_model_version_client.get_default_version.return_value = "V1" + def test_comment_getter(self) -> None: + with mock.patch.object( + self.m_model._model_ops, "get_comment", return_value="this is a comment" + ) as mock_get_comment: + self.assertEqual("this is a comment", self.m_model.comment) + mock_get_comment.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) - default_model_version = self.m_model.default - self.assertEqual(default_model_version.version_name, "V1") - mock_model_version_client.get_default_version.assert_called() + def test_comment_setter(self) -> None: + with mock.patch.object(self.m_model._model_ops, "set_comment") as mock_set_comment: + self.m_model.comment = "this is a comment" + mock_set_comment.assert_called_once_with( + comment="this is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_default_getter(self) -> None: + with mock.patch.object( + self.m_model._model_ops, + "get_default_version", + return_value=sql_identifier.SqlIdentifier("V1", case_sensitive=True), + ) as mock_get_default_version, mock.patch.object( + self.m_model._model_ops, "validate_existence", return_value=True + ): + self.assertEqual("V1", self.m_model.default.version_name) + mock_get_default_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) def test_default_setter(self) -> None: - mock_model_version_client = absltest.mock.MagicMock(spec=model_version.ModelVersionSQLClient) - self.m_model._model_ops._model_version_client = mock_model_version_client + with mock.patch.object(self.m_model._model_ops, "set_default_version") as mock_set_default_version: + self.m_model.default = "V1" # type: ignore[assignment] + mock_set_default_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) - # str - self.m_model.default = "V1" # type: ignore[assignment] - mock_model_version_client.set_default_version.assert_called() + with 
mock.patch.object(self.m_model._model_ops, "set_default_version") as mock_set_default_version: + mv = model_version_impl.ModelVersion._ref( + self.m_model._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + ) + self.m_model.default = mv + mock_set_default_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + statement_params=mock.ANY, + ) - # ModelVersion - mv = model_version_impl.ModelVersion._ref( - self.m_model._model_ops, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("V2"), - ) - self.m_model.default = mv - mock_model_version_client.set_default_version.assert_called() + def test_show_tags(self) -> None: + m_res = {'DB."schema".MYTAG': "tag content", 'MYDB.SCHEMA."my_another_tag"': "1"} + with mock.patch.object(self.m_model._model_ops, "show_tags", return_value=m_res) as mock_show_tags: + res = self.m_model.show_tags() + self.assertDictEqual(res, m_res) + mock_show_tags.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_get_tag_1(self) -> None: + with mock.patch.object(self.m_model._model_ops, "get_tag_value", return_value="tag content") as mock_get_tag: + res = self.m_model.get_tag(tag_name="MYTAG") + self.assertEqual(res, "tag content") + mock_get_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("TEMP"), + tag_schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=mock.ANY, + ) + + def test_get_tag_2(self) -> None: + with mock.patch.object(self.m_model._model_ops, "get_tag_value", return_value="tag content") as mock_get_tag: + res = self.m_model.get_tag(tag_name='"schema".MYTAG') + self.assertEqual(res, "tag content") + mock_get_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("TEMP"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=mock.ANY, + ) + + def test_get_tag_3(self) -> None: + with mock.patch.object(self.m_model._model_ops, "get_tag_value", return_value="tag content") as mock_get_tag: + res = self.m_model.get_tag(tag_name='DB."schema".MYTAG') + self.assertEqual(res, "tag content") + mock_get_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=mock.ANY, + ) + + def test_set_tag_1(self) -> None: + with mock.patch.object(self.m_model._model_ops, "set_tag") as mock_set_tag: + self.m_model.set_tag(tag_name="MYTAG", tag_value="tag content") + mock_set_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("TEMP"), + tag_schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + tag_value="tag content", + statement_params=mock.ANY, + ) + + def test_set_tag_2(self) -> None: + with mock.patch.object(self.m_model._model_ops, "set_tag") as mock_set_tag: + self.m_model.set_tag(tag_name='"schema".MYTAG', 
tag_value="tag content") + mock_set_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("TEMP"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + tag_value="tag content", + statement_params=mock.ANY, + ) + + def test_set_tag_3(self) -> None: + with mock.patch.object(self.m_model._model_ops, "set_tag") as mock_set_tag: + self.m_model.set_tag(tag_name='DB."schema".MYTAG', tag_value="tag content") + mock_set_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + tag_value="tag content", + statement_params=mock.ANY, + ) + + def test_unset_tag_1(self) -> None: + with mock.patch.object(self.m_model._model_ops, "unset_tag") as mock_unset_tag: + self.m_model.unset_tag(tag_name="MYTAG") + mock_unset_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("TEMP"), + tag_schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=mock.ANY, + ) + + def test_unset_tag_2(self) -> None: + with mock.patch.object(self.m_model._model_ops, "unset_tag") as mock_unset_tag: + self.m_model.unset_tag(tag_name='"schema".MYTAG') + mock_unset_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("TEMP"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=mock.ANY, + ) + + def test_unset_tag_3(self) -> None: + with mock.patch.object(self.m_model._model_ops, "unset_tag") as mock_unset_tag: + self.m_model.unset_tag(tag_name='DB."schema".MYTAG') + mock_unset_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=mock.ANY, + ) if __name__ == "__main__": diff --git a/snowflake/ml/model/_client/model/model_method_info.py b/snowflake/ml/model/_client/model/model_method_info.py deleted file mode 100644 index 013eace5..00000000 --- a/snowflake/ml/model/_client/model/model_method_info.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import TypedDict - -from typing_extensions import Required - -from snowflake.ml.model import model_signature - - -class ModelMethodInfo(TypedDict): - """Method information. - - Attributes: - name: Name of the method to be called via SQL. - target_method: actual target method name to be called. - signature: The signature of the model method. 
- """ - - name: Required[str] - target_method: Required[str] - signature: Required[model_signature.ModelSignature] diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index df353f60..9e38a900 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -3,11 +3,12 @@ import pandas as pd +from snowflake import connector from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import model_signature -from snowflake.ml.model._client.model import model_method_info from snowflake.ml.model._client.ops import metadata_ops, model_ops +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.snowpark import dataframe _TELEMETRY_PROJECT = "MLOps" @@ -49,14 +50,17 @@ def __eq__(self, __value: object) -> bool: @property def model_name(self) -> str: + """Return the name of the model to which the model version belongs, usable as a reference in SQL.""" return self._model_name.identifier() @property def version_name(self) -> str: + """Return the name of the version to which the model version belongs, usable as a reference in SQL.""" return self._version_name.identifier() @property def fully_qualified_model_name(self) -> str: + """Return the fully qualified name of the model to which the model version belongs.""" return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name) @property @@ -65,6 +69,24 @@ def fully_qualified_model_name(self) -> str: subproject=_TELEMETRY_SUBPROJECT, ) def description(self) -> str: + """The description for the model version. This is an alias of `comment`.""" + return self.comment + + @description.setter + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def description(self, description: str) -> None: + self.comment = description + + @property + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def comment(self) -> str: + """The comment to the model version.""" statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, @@ -75,18 +97,18 @@ def description(self) -> str: statement_params=statement_params, ) - @description.setter + @comment.setter @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def description(self, description: str) -> None: + def comment(self, comment: str) -> None: statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) return self._model_ops.set_comment( - comment=description, + comment=comment, model_name=self._model_name, version_name=self._version_name, statement_params=statement_params, @@ -96,11 +118,11 @@ def description(self, description: str) -> None: project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def list_metrics(self) -> Dict[str, Any]: + def show_metrics(self) -> Dict[str, Any]: """Show all metrics logged with the model version. Returns: - A dictionary showing the metrics + A dictionary showing the metrics. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, @@ -118,15 +140,15 @@ def get_metric(self, metric_name: str) -> Any: """Get the value of a specific metric. 
Args: - metric_name: The name of the metric + metric_name: The name of the metric. Raises: - KeyError: Raised when the requested metric name does not exist. + KeyError: When the requested metric name does not exist. Returns: The value of the metric. """ - metrics = self.list_metrics() + metrics = self.show_metrics() if metric_name not in metrics: raise KeyError(f"Cannot find metric with name {metric_name}.") return metrics[metric_name] @@ -136,17 +158,17 @@ def get_metric(self, metric_name: str) -> Any: subproject=_TELEMETRY_SUBPROJECT, ) def set_metric(self, metric_name: str, value: Any) -> None: - """Set the value of a specific metric name + """Set the value of a specific metric. Args: - metric_name: The name of the metric + metric_name: The name of the metric. value: The value of the metric. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - metrics = self.list_metrics() + metrics = self.show_metrics() metrics[metric_name] = value self._model_ops._metadata_ops.save( metadata_ops.ModelVersionMetadataSchema(metrics=metrics), @@ -166,13 +188,13 @@ def delete_metric(self, metric_name: str) -> None: metric_name: The name of the metric to be deleted. Raises: - KeyError: Raised when the requested metric name does not exist. + KeyError: When the requested metric name does not exist. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - metrics = self.list_metrics() + metrics = self.show_metrics() if metric_name not in metrics: raise KeyError(f"Cannot find metric with name {metric_name}.") del metrics[metric_name] @@ -183,24 +205,12 @@ def delete_metric(self, metric_name: str) -> None: statement_params=statement_params, ) - @telemetry.send_api_usage_telemetry( - project=_TELEMETRY_PROJECT, - subproject=_TELEMETRY_SUBPROJECT, - ) - def list_methods(self) -> List[model_method_info.ModelMethodInfo]: - """List all method information in a model version that is callable. - - Returns: - A list of ModelMethodInfo object containing the following information: - - name: The name of the method to be called (both in SQL and in Python SDK). - - target_method: The original method name in the logged Python object. - - Signature: Python signature of the original method. - """ + # Only used when the model does not contains user_data with client SDK information. + def _legacy_show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]: statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - # TODO(SNOW-986673, SNOW-986675): Avoid parsing manifest and meta file and put Python signature into user_data. manifest = self._model_ops.get_model_version_manifest( model_name=self._model_name, version_name=self._version_name, @@ -211,7 +221,7 @@ def list_methods(self) -> List[model_method_info.ModelMethodInfo]: version_name=self._version_name, statement_params=statement_params, ) - return_methods_info: List[model_method_info.ModelMethodInfo] = [] + return_functions_info: List[model_manifest_schema.ModelFunctionInfo] = [] for method in manifest["methods"]: # Method's name is resolved so we need to use case_sensitive as True to get the user-facing identifier. 
method_name = sql_identifier.SqlIdentifier(method["name"], case_sensitive=True).identifier() @@ -221,14 +231,48 @@ def list_methods(self) -> List[model_method_info.ModelMethodInfo]: ), f"Get unexpected handler name {method['handler']}" target_method = method["handler"].split(".")[1] signature_dict = model_meta["signatures"][target_method] - method_info = model_method_info.ModelMethodInfo( + fi = model_manifest_schema.ModelFunctionInfo( name=method_name, target_method=target_method, signature=model_signature.ModelSignature.from_dict(signature_dict), ) - return_methods_info.append(method_info) + return_functions_info.append(fi) + return return_functions_info + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + def show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]: + """Show information about all callable functions in the model version. - Returns: - A list of ModelMethodInfo object containing the following information: + Returns: + A list of ModelFunctionInfo objects containing the following information: + + - name: The name of the function to be called (both in SQL and in Python SDK). + - target_method: The original method name in the logged Python object. + - signature: Python signature of the original method. + """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + try: + client_data = self._model_ops.get_client_data_in_user_data( + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + return [ + model_manifest_schema.ModelFunctionInfo( + name=fi["name"], + target_method=fi["target_method"], + signature=model_signature.ModelSignature.from_dict(fi["signature"]), + ) + for fi in client_data["functions"] + ] + except (NotImplementedError, ValueError, connector.DataError): + return self._legacy_show_functions() @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, @@ -238,52 +282,52 @@ def run( self, X: Union[pd.DataFrame, dataframe.DataFrame], *, - method_name: Optional[str] = None, + function_name: Optional[str] = None, ) -> Union[pd.DataFrame, dataframe.DataFrame]: - """Invoke a method in a model version object + """Invoke a method in a model version object. Args: - X: The input data. Could be pandas DataFrame or Snowpark DataFrame - method_name: The method name to run. It is the name you will use to call a method in SQL. Defaults to None. - It can only be None if there is only 1 method. + X: The input data, which could be a pandas DataFrame or Snowpark DataFrame. + function_name: The function name to run. It is the name used to call a function in SQL. + Defaults to None. It can only be None if there is only one function. Raises: - ValueError: No method with the corresponding name is available. - ValueError: There are more than 1 target methods available in the model but no method name specified. + ValueError: When no method with the corresponding name is available. + ValueError: When there is more than one target method available in the model but no function name is specified. Returns: - The prediction data. + The prediction data. It will be a DataFrame of the same type as the input.
""" statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - methods: List[model_method_info.ModelMethodInfo] = self.list_methods() - if method_name: - req_method_name = sql_identifier.SqlIdentifier(method_name).identifier() - find_method: Callable[[model_method_info.ModelMethodInfo], bool] = ( + functions: List[model_manifest_schema.ModelFunctionInfo] = self.show_functions() + if function_name: + req_method_name = sql_identifier.SqlIdentifier(function_name).identifier() + find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = ( lambda method: method["name"] == req_method_name ) - target_method_info = next( - filter(find_method, methods), + target_function_info = next( + filter(find_method, functions), None, ) - if target_method_info is None: + if target_function_info is None: raise ValueError( - f"There is no method with name {method_name} available in the model" + f"There is no method with name {function_name} available in the model" f" {self.fully_qualified_model_name} version {self.version_name}" ) - elif len(methods) != 1: + elif len(functions) != 1: raise ValueError( f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}" f" version {self.version_name}. Please specify a `method_name` when calling the `run` method." ) else: - target_method_info = methods[0] + target_function_info = functions[0] return self._model_ops.invoke_method( - method_name=sql_identifier.SqlIdentifier(target_method_info["name"]), - signature=target_method_info["signature"], + method_name=sql_identifier.SqlIdentifier(target_function_info["name"]), + signature=target_function_info["signature"], X=X, model_name=self._model_name, version_name=self._version_name, diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index 84f30fcf..100d8921 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -9,6 +9,7 @@ from snowflake.ml.model import model_signature from snowflake.ml.model._client.model import model_version_impl from snowflake.ml.model._client.ops import metadata_ops, model_ops +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.test_utils import mock_data_frame, mock_session from snowflake.snowpark import Session @@ -41,10 +42,10 @@ def test_property(self) -> None: self.assertEqual(self.m_mv.fully_qualified_model_name, 'TEMP."test".MODEL') self.assertEqual(self.m_mv.version_name, '"v1"') - def test_list_metrics(self) -> None: + def test_show_metrics(self) -> None: m_metadata = metadata_ops.ModelVersionMetadataSchema(metrics={}) with mock.patch.object(self.m_mv._model_ops._metadata_ops, "load", return_value=m_metadata) as mock_load: - self.assertDictEqual({}, self.m_mv.list_metrics()) + self.assertDictEqual({}, self.m_mv.show_metrics()) mock_load.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), @@ -140,7 +141,7 @@ def test_delete_metric_2(self) -> None: ) mock_save.assert_not_called() - def test_list_methods(self) -> None: + def test_show_functions_1(self) -> None: m_manifest = { "manifest_version": "1.0", "runtimes": { @@ -212,11 +213,13 @@ def test_list_methods(self) -> None: ) ) with mock.patch.object( + self.m_mv._model_ops, "get_client_data_in_user_data", 
side_effect=NotImplementedError() + ), mock.patch.object( self.m_mv._model_ops, "get_model_version_manifest", return_value=m_manifest ) as mock_get_model_version_manifest, mock.patch.object( self.m_mv._model_ops, "get_model_version_native_packing_meta", return_value=m_meta_yaml ) as mock_get_model_version_native_packing_meta: - methods = self.m_mv.list_methods() + methods = self.m_mv.show_functions() mock_get_model_version_manifest.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), @@ -243,6 +246,57 @@ def test_list_methods(self) -> None: ], ) + def test_show_functions_2(self) -> None: + m_function_info = [ + model_manifest_schema.ModelFunctionInfoDict( + { + "name": '"predict"', + "target_method": "predict", + "signature": _DUMMY_SIG["predict"].to_dict(), + } + ), + model_manifest_schema.ModelFunctionInfoDict( + { + "name": "__CALL__", + "target_method": "__call__", + "signature": _DUMMY_SIG["predict"].to_dict(), + } + ), + ] + m_user_data = model_manifest_schema.SnowparkMLDataDict( + schema_version=model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION, functions=m_function_info + ) + with mock.patch.object( + self.m_mv._model_ops, "get_client_data_in_user_data", return_value=m_user_data + ) as mock_get_client_data_in_user_data, mock.patch.object( + self.m_mv._model_ops, "get_model_version_manifest" + ) as mock_get_model_version_manifest, mock.patch.object( + self.m_mv._model_ops, "get_model_version_native_packing_meta" + ) as mock_get_model_version_native_packing_meta: + methods = self.m_mv.show_functions() + mock_get_client_data_in_user_data.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + self.assertEqual( + methods, + [ + { + "name": '"predict"', + "target_method": "predict", + "signature": _DUMMY_SIG["predict"], + }, + { + "name": "__CALL__", + "target_method": "__call__", + "signature": _DUMMY_SIG["predict"], + }, + ], + ) + mock_get_model_version_manifest.assert_not_called() + mock_get_model_version_native_packing_meta.assert_not_called() + def test_run(self) -> None: m_df = mock_data_frame.MockDataFrame() m_methods = [ @@ -257,22 +311,22 @@ def test_run(self) -> None: "signature": _DUMMY_SIG["predict"], }, ] - with mock.patch.object(self.m_mv, "list_methods", return_value=m_methods) as mock_list_methods: + with mock.patch.object(self.m_mv, "show_functions", return_value=m_methods) as mock_list_methods: with self.assertRaisesRegex(ValueError, "There is no method with name PREDICT available in the model"): - self.m_mv.run(m_df, method_name="PREDICT") + self.m_mv.run(m_df, function_name="PREDICT") mock_list_methods.assert_called_once_with() - with mock.patch.object(self.m_mv, "list_methods", return_value=m_methods) as mock_list_methods: + with mock.patch.object(self.m_mv, "show_functions", return_value=m_methods) as mock_list_methods: with self.assertRaisesRegex(ValueError, "There are more than 1 target methods available in the model"): self.m_mv.run(m_df) mock_list_methods.assert_called_once_with() with mock.patch.object( - self.m_mv, "list_methods", return_value=m_methods + self.m_mv, "show_functions", return_value=m_methods ) as mock_list_methods, mock.patch.object( self.m_mv._model_ops, "invoke_method", return_value=m_df ) as mock_invoke_method: - self.m_mv.run(m_df, method_name='"predict"') + self.m_mv.run(m_df, function_name='"predict"') 
mock_list_methods.assert_called_once_with() mock_invoke_method.assert_called_once_with( method_name='"predict"', @@ -284,11 +338,11 @@ def test_run(self) -> None: ) with mock.patch.object( - self.m_mv, "list_methods", return_value=m_methods + self.m_mv, "show_functions", return_value=m_methods ) as mock_list_methods, mock.patch.object( self.m_mv._model_ops, "invoke_method", return_value=m_df ) as mock_invoke_method: - self.m_mv.run(m_df, method_name="__call__") + self.m_mv.run(m_df, function_name="__call__") mock_list_methods.assert_called_once_with() mock_invoke_method.assert_called_once_with( method_name="__CALL__", @@ -310,7 +364,7 @@ def test_run_without_method_name(self) -> None: ] with mock.patch.object( - self.m_mv, "list_methods", return_value=m_methods + self.m_mv, "show_functions", return_value=m_methods ) as mock_list_methods, mock.patch.object( self.m_mv._model_ops, "invoke_method", return_value=m_df ) as mock_invoke_method: @@ -346,6 +400,27 @@ def test_description_setter(self) -> None: statement_params=mock.ANY, ) + def test_comment_getter(self) -> None: + with mock.patch.object( + self.m_mv._model_ops, "get_comment", return_value="this is a comment" + ) as mock_get_comment: + self.assertEqual("this is a comment", self.m_mv.comment) + mock_get_comment.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + + def test_comment_setter(self) -> None: + with mock.patch.object(self.m_mv._model_ops, "set_comment") as mock_set_comment: + self.m_mv.comment = "this is a comment" + mock_set_comment.assert_called_once_with( + comment="this is a comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=mock.ANY, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/ops/BUILD.bazel b/snowflake/ml/model/_client/ops/BUILD.bazel index 4775fa4d..1e914e93 100644 --- a/snowflake/ml/model/_client/ops/BUILD.bazel +++ b/snowflake/ml/model/_client/ops/BUILD.bazel @@ -1,8 +1,9 @@ load("//bazel:py_rules.bzl", "py_library", "py_test") package(default_visibility = [ + "//bazel:snowml_public_common", "//snowflake/ml/model/_client/model:__pkg__", - "//snowflake/ml/registry:__pkg__", + "//snowflake/ml/registry/_manager:__pkg__", ]) py_library( @@ -10,12 +11,15 @@ py_library( srcs = ["model_ops.py"], deps = [ ":metadata_ops", + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_client/sql:model", "//snowflake/ml/model/_client/sql:model_version", "//snowflake/ml/model/_client/sql:stage", + "//snowflake/ml/model/_client/sql:tag", "//snowflake/ml/model/_model_composer:model_composer", "//snowflake/ml/model/_model_composer/model_manifest", "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", @@ -30,6 +34,7 @@ py_test( srcs = ["model_ops_test.py"], deps = [ ":model_ops", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:model_signature", "//snowflake/ml/model/_signatures:snowpark_handler", diff --git a/snowflake/ml/model/_client/ops/metadata_ops.py b/snowflake/ml/model/_client/ops/metadata_ops.py index 4ba7c11d..5f0a5818 100644 --- 
a/snowflake/ml/model/_client/ops/metadata_ops.py +++ b/snowflake/ml/model/_client/ops/metadata_ops.py @@ -68,9 +68,7 @@ def _get_current_metadata_dict( version_info_list = self._model_client.show_versions( model_name=model_name, version_name=version_name, statement_params=statement_params ) - assert len(version_info_list) == 1 - version_info = version_info_list[0] - metadata_str = version_info.metadata + metadata_str = version_info_list[0][self._model_client.MODEL_VERSION_METADATA_COL_NAME] if not metadata_str: return {} res = json.loads(metadata_str) diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index 40ce8914..9794807e 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -1,16 +1,19 @@ +import json import pathlib import tempfile from typing import Any, Dict, List, Optional, Union, cast import yaml +from packaging import version -from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml._internal.utils import identifier, snowflake_env, sql_identifier from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._client.ops import metadata_ops from snowflake.ml.model._client.sql import ( model as model_sql, model_version as model_version_sql, stage as stage_sql, + tag as tag_sql, ) from snowflake.ml.model._model_composer import model_composer from snowflake.ml.model._model_composer.model_manifest import ( @@ -19,9 +22,11 @@ ) from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema from snowflake.ml.model._signatures import snowpark_handler -from snowflake.snowpark import dataframe, session +from snowflake.snowpark import dataframe, row, session from snowflake.snowpark._internal import utils as snowpark_utils +_TAG_ON_MODEL_AVAILABLE_VERSION = version.parse("8.2.0") + class ModelOperator: def __init__( @@ -50,6 +55,11 @@ def __init__( database_name=database_name, schema_name=schema_name, ) + self._tag_client = tag_sql.ModuleTagSQLClient( + session, + database_name=database_name, + schema_name=schema_name, + ) self._metadata_ops = metadata_ops.MetadataOperator( session, database_name=database_name, @@ -109,22 +119,39 @@ def create_from_stage( statement_params=statement_params, ) - def list_models_or_versions( + def show_models_or_versions( self, *, model_name: Optional[sql_identifier.SqlIdentifier] = None, statement_params: Optional[Dict[str, Any]] = None, - ) -> List[sql_identifier.SqlIdentifier]: + ) -> List[row.Row]: if model_name: - res = self._model_client.show_versions( + return self._model_client.show_versions( model_name=model_name, + validate_result=False, statement_params=statement_params, ) else: - res = self._model_client.show_models( + return self._model_client.show_models( + validate_result=False, statement_params=statement_params, ) - return [sql_identifier.SqlIdentifier(row.name, case_sensitive=True) for row in res] + + def list_models_or_versions( + self, + *, + model_name: Optional[sql_identifier.SqlIdentifier] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[sql_identifier.SqlIdentifier]: + res = self.show_models_or_versions( + model_name=model_name, + statement_params=statement_params, + ) + if model_name: + col_name = self._model_client.MODEL_VERSION_NAME_COL_NAME + else: + col_name = self._model_client.MODEL_NAME_COL_NAME + return [sql_identifier.SqlIdentifier(row[col_name], case_sensitive=True) for row in res] def validate_existence( self, @@ -137,11 +164,13 @@ def 
validate_existence( res = self._model_client.show_versions( model_name=model_name, version_name=version_name, + validate_result=False, statement_params=statement_params, ) else: res = self._model_client.show_models( model_name=model_name, + validate_result=False, statement_params=statement_params, ) return len(res) == 1 @@ -159,13 +188,14 @@ def get_comment( version_name=version_name, statement_params=statement_params, ) + col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME else: res = self._model_client.show_models( model_name=model_name, statement_params=statement_params, ) - assert len(res) == 1 - return cast(str, res[0].comment) + col_name = self._model_client.MODEL_COMMENT_COL_NAME + return cast(str, res[0][col_name]) def set_comment( self, @@ -189,6 +219,123 @@ def set_comment( statement_params=statement_params, ) + def set_default_version( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + if not self.validate_existence( + model_name=model_name, version_name=version_name, statement_params=statement_params + ): + raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.") + self._model_version_client.set_default_version( + model_name=model_name, version_name=version_name, statement_params=statement_params + ) + + def get_default_version( + self, + *, + model_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> sql_identifier.SqlIdentifier: + res = self._model_client.show_models(model_name=model_name, statement_params=statement_params)[0] + return sql_identifier.SqlIdentifier( + res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True + ) + + def get_tag_value( + self, + *, + model_name: sql_identifier.SqlIdentifier, + tag_database_name: sql_identifier.SqlIdentifier, + tag_schema_name: sql_identifier.SqlIdentifier, + tag_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> Optional[str]: + r = self._tag_client.get_tag_value( + module_name=model_name, + tag_database_name=tag_database_name, + tag_schema_name=tag_schema_name, + tag_name=tag_name, + statement_params=statement_params, + ) + value = r.TAG_VALUE + if value is None: + return value + return str(value) + + def show_tags( + self, + *, + model_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + tags_info = self._tag_client.get_tag_list( + module_name=model_name, + statement_params=statement_params, + ) + res: Dict[str, str] = { + identifier.get_schema_level_object_identifier( + sql_identifier.SqlIdentifier(r.TAG_DATABASE, case_sensitive=True), + sql_identifier.SqlIdentifier(r.TAG_SCHEMA, case_sensitive=True), + sql_identifier.SqlIdentifier(r.TAG_NAME, case_sensitive=True), + ): str(r.TAG_VALUE) + for r in tags_info + } + return res + + def set_tag( + self, + *, + model_name: sql_identifier.SqlIdentifier, + tag_database_name: sql_identifier.SqlIdentifier, + tag_schema_name: sql_identifier.SqlIdentifier, + tag_name: sql_identifier.SqlIdentifier, + tag_value: str, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + sf_version = snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params) + if sf_version >= _TAG_ON_MODEL_AVAILABLE_VERSION: + self._tag_client.set_tag_on_model( + model_name=model_name, + tag_database_name=tag_database_name, + 
tag_schema_name=tag_schema_name, + tag_name=tag_name, + tag_value=tag_value, + statement_params=statement_params, + ) + else: + raise NotImplementedError( + f"`set_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION}," + f" currently is {sf_version}" + ) + + def unset_tag( + self, + *, + model_name: sql_identifier.SqlIdentifier, + tag_database_name: sql_identifier.SqlIdentifier, + tag_schema_name: sql_identifier.SqlIdentifier, + tag_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + sf_version = snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params) + if sf_version >= _TAG_ON_MODEL_AVAILABLE_VERSION: + self._tag_client.unset_tag_on_model( + model_name=model_name, + tag_database_name=tag_database_name, + tag_schema_name=tag_schema_name, + tag_name=tag_name, + statement_params=statement_params, + ) + else: + raise NotImplementedError( + f"`unset_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION}," + f" currently is {sf_version}" + ) + def get_model_version_manifest( self, *, @@ -228,6 +375,27 @@ def get_model_version_native_packing_meta( raw_model_meta = yaml.safe_load(f) return model_meta.ModelMetadata._validate_model_metadata(raw_model_meta) + def get_client_data_in_user_data( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> model_manifest_schema.SnowparkMLDataDict: + if ( + snowflake_env.get_current_snowflake_version(self._session) + < model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION + ): + raise NotImplementedError("User_data has not been supported yet.") + raw_user_data_json_string = self._model_client.show_versions( + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + )[0][self._model_client.MODEL_VERSION_USER_DATA_COL_NAME] + raw_user_data = json.loads(raw_user_data_json_string) + assert isinstance(raw_user_data, dict), "user data should be a dictionary" + return model_manifest.ModelManifest.parse_client_data_from_user_data(raw_user_data) + def invoke_method( self, *, diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index 317f39fd..dd041e3c 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -1,3 +1,4 @@ +import json import os import pathlib import tempfile @@ -9,10 +10,12 @@ import pandas as pd import yaml from absl.testing import absltest +from packaging import version -from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml._internal.utils import snowflake_env, sql_identifier from snowflake.ml.model import model_signature from snowflake.ml.model._client.ops import model_ops +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._signatures import snowpark_handler from snowflake.ml.test_utils import mock_data_frame, mock_session from snowflake.snowpark import DataFrame, Row, Session, types as spt @@ -38,6 +41,9 @@ def setUp(self) -> None: database_name=sql_identifier.SqlIdentifier("TEMP"), schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), ) + snowflake_env.get_current_snowflake_version = mock.MagicMock( + return_value=model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION + ) def test_prepare_model_stage_path(self) -> None: with 
mock.patch.object(self.m_ops._stage_client, "create_tmp_stage",) as mock_create_stage, mock.patch.object( @@ -53,6 +59,74 @@ def test_prepare_model_stage_path(self) -> None: statement_params=self.m_statement_params, ) + def test_show_models_or_versions_1(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="MODEL", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + default_version_name="V1", + ), + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + default_version_name="v1", + ), + ] + with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: + res = self.m_ops.show_models_or_versions( + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + m_list_res, + ) + mock_show_models.assert_called_once_with( + validate_result=False, + statement_params=self.m_statement_params, + ) + + def test_show_models_or_versions_2(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + Row( + create_on="06/01", + name="V1", + comment="This is a comment", + model_name="MODEL", + is_default_version=False, + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.show_models_or_versions( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + m_list_res, + ) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + validate_result=False, + statement_params=self.m_statement_params, + ) + def test_list_models_or_versions_1(self) -> None: m_list_res = [ Row( @@ -62,6 +136,7 @@ def test_list_models_or_versions_1(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="V1", ), Row( create_on="06/01", @@ -70,6 +145,7 @@ def test_list_models_or_versions_1(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="v1", ), ] with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: @@ -84,6 +160,7 @@ def test_list_models_or_versions_1(self) -> None: ], ) mock_show_models.assert_called_once_with( + validate_result=False, statement_params=self.m_statement_params, ) @@ -120,6 +197,7 @@ def test_list_models_or_versions_2(self) -> None: ) mock_show_versions.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("MODEL"), + validate_result=False, statement_params=self.m_statement_params, ) @@ -132,6 +210,7 @@ def test_validate_existence_1(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="V1", ), ] with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: @@ -142,6 +221,7 @@ def test_validate_existence_1(self) -> None: self.assertTrue(res) mock_show_models.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + validate_result=False, statement_params=self.m_statement_params, ) @@ -155,6 +235,7 @@ def test_validate_existence_2(self) -> None: self.assertFalse(res) mock_show_models.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + validate_result=False, 
statement_params=self.m_statement_params, ) @@ -180,6 +261,7 @@ def test_validate_existence_3(self) -> None: mock_show_versions.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + validate_result=False, statement_params=self.m_statement_params, ) @@ -197,6 +279,157 @@ def test_validate_existence_4(self) -> None: mock_show_versions.assert_called_once_with( model_name=sql_identifier.SqlIdentifier("MODEL"), version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + validate_result=False, + statement_params=self.m_statement_params, + ) + + def test_get_tag_value_1(self) -> None: + m_list_res: Row = Row(TAG_VALUE="a") + with mock.patch.object(self.m_ops._tag_client, "get_tag_value", return_value=m_list_res) as mock_get_tag_value: + res = self.m_ops.get_tag_value( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + self.assertEqual(res, "a") + mock_get_tag_value.assert_called_once_with( + module_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + + def test_get_tag_value_2(self) -> None: + m_list_res: Row = Row(TAG_VALUE=1) + with mock.patch.object(self.m_ops._tag_client, "get_tag_value", return_value=m_list_res) as mock_get_tag_value: + res = self.m_ops.get_tag_value( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + self.assertEqual(res, "1") + mock_get_tag_value.assert_called_once_with( + module_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + + def test_get_tag_value_3(self) -> None: + m_list_res: Row = Row(TAG_VALUE=None) + with mock.patch.object(self.m_ops._tag_client, "get_tag_value", return_value=m_list_res) as mock_get_tag_value: + res = self.m_ops.get_tag_value( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + self.assertIsNone(res) + mock_get_tag_value.assert_called_once_with( + module_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + + def test_show_tags(self) -> None: + m_list_res: List[Row] = [ + Row(TAG_DATABASE="DB", TAG_SCHEMA="schema", TAG_NAME="MYTAG", TAG_VALUE="tag content"), + Row(TAG_DATABASE="MYDB", TAG_SCHEMA="SCHEMA", TAG_NAME="my_another_tag", TAG_VALUE=1), + ] + with 
mock.patch.object(self.m_ops._tag_client, "get_tag_list", return_value=m_list_res) as mock_get_tag_list: + res = self.m_ops.show_tags( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + self.assertDictEqual(res, {'DB."schema".MYTAG': "tag content", 'MYDB.SCHEMA."my_another_tag"': "1"}) + mock_get_tag_list.assert_called_once_with( + module_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + + def test_set_tag_fail(self) -> None: + with mock.patch.object( + snowflake_env, + "get_current_snowflake_version", + return_value=version.parse("8.1.0+23d9c914e5"), + ), mock.patch.object(self.m_ops._tag_client, "set_tag_on_model") as mock_set_tag: + with self.assertRaisesRegex(NotImplementedError, "`set_tag` won't work before Snowflake version"): + self.m_ops.set_tag( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + tag_value="tag content", + statement_params=self.m_statement_params, + ) + mock_set_tag.assert_not_called() + + def test_set_tag(self) -> None: + with mock.patch.object( + snowflake_env, + "get_current_snowflake_version", + return_value=version.parse("8.2.0+23d9c914e5"), + ), mock.patch.object(self.m_ops._tag_client, "set_tag_on_model") as mock_set_tag: + self.m_ops.set_tag( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + tag_value="tag content", + statement_params=self.m_statement_params, + ) + mock_set_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + tag_value="tag content", + statement_params=self.m_statement_params, + ) + + def test_unset_tag_fail(self) -> None: + with mock.patch.object( + snowflake_env, + "get_current_snowflake_version", + return_value=version.parse("8.1.0+23d9c914e5"), + ), mock.patch.object(self.m_ops._tag_client, "unset_tag_on_model") as mock_unset_tag: + with self.assertRaisesRegex(NotImplementedError, "`unset_tag` won't work before Snowflake version"): + self.m_ops.unset_tag( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + mock_unset_tag.assert_not_called() + + def test_unset_tag(self) -> None: + with mock.patch.object( + snowflake_env, + "get_current_snowflake_version", + return_value=version.parse("8.2.0+23d9c914e5"), + ), mock.patch.object(self.m_ops._tag_client, "unset_tag_on_model") as mock_unset_tag: + self.m_ops.unset_tag( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=self.m_statement_params, + ) + mock_unset_tag.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + 
tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), statement_params=self.m_statement_params, ) @@ -341,6 +574,7 @@ def test_create_from_stage_2(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="V1", ), ] with mock.patch.object( @@ -377,6 +611,7 @@ def test_create_from_stage_3(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="V1", ), ) m_list_res_versions = [ @@ -407,6 +642,45 @@ def test_create_from_stage_3(self) -> None: mock_create_from_stage.assert_not_called() mock_add_version_from_stagel.assert_not_called() + def test_get_client_data_in_user_data_1(self) -> None: + m_client_data = { + "schema_version": model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION, + "functions": [ + model_manifest_schema.ModelFunctionInfoDict( + name="PREDICT", + target_method="predict", + signature=_DUMMY_SIG["predict"].to_dict(), + ) + ], + } + m_list_res = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + user_data=json.dumps({model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: m_client_data}), + is_default_version=True, + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions: + res = self.m_ops.get_client_data_in_user_data( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + statement_params=self.m_statement_params, + ) + self.assertDictEqual( + res, + m_client_data, + ) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + statement_params=self.m_statement_params, + ) + def test_invoke_method_1(self) -> None: pd_df = pd.DataFrame([["1.0"]], columns=["input"], dtype=np.float32) m_sig = _DUMMY_SIG["predict"] @@ -540,6 +814,7 @@ def test_get_comment_2(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="V1", ), ] with mock.patch.object( @@ -585,6 +860,83 @@ def test_set_comment_2(self) -> None: statement_params=self.m_statement_params, ) + def test_get_default_version(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="MODEL", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + default_version_name="v1", + ), + ] + with mock.patch.object(self.m_ops._model_client, "show_models", return_value=m_list_res) as mock_show_models: + res = self.m_ops.get_default_version( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + self.assertEqual(res, sql_identifier.SqlIdentifier("v1", case_sensitive=True)) + mock_show_models.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=self.m_statement_params, + ) + + def test_set_default_version_1(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="v1", + comment="This is a comment", + model_name="MODEL", + is_default_version=True, + ), + ] + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._model_version_client, "set_default_version" + ) as mock_set_default_version: + self.m_ops.set_default_version( + model_name=sql_identifier.SqlIdentifier("MODEL"), + 
version_name=sql_identifier.SqlIdentifier('"v1"'), + statement_params=self.m_statement_params, + ) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + validate_result=False, + statement_params=self.m_statement_params, + ) + mock_set_default_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier('"v1"'), + statement_params=self.m_statement_params, + ) + + def test_set_default_version_2(self) -> None: + with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=[] + ) as mock_show_versions, mock.patch.object( + self.m_ops._model_version_client, "set_default_version" + ) as mock_set_default_version: + with self.assertRaisesRegex( + ValueError, "You cannot set version V1 as default version as it does not exist." + ): + self.m_ops.set_default_version( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=self.m_statement_params, + ) + mock_show_versions.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + validate_result=False, + statement_params=self.m_statement_params, + ) + mock_set_default_version.assert_not_called() + def test_delete_model_or_version(self) -> None: with mock.patch.object( self.m_ops._model_client, diff --git a/snowflake/ml/model/_client/sql/BUILD.bazel b/snowflake/ml/model/_client/sql/BUILD.bazel index 465e2f24..8857ea8b 100644 --- a/snowflake/ml/model/_client/sql/BUILD.bazel +++ b/snowflake/ml/model/_client/sql/BUILD.bazel @@ -1,13 +1,19 @@ load("//bazel:py_rules.bzl", "py_library", "py_test") -package(default_visibility = ["//snowflake/ml/model/_client/ops:__pkg__"]) +package(default_visibility = [ + "//bazel:snowml_public_common", + "//snowflake/ml/model/_client/ops:__pkg__", +]) py_library( name = "model", srcs = ["model.py"], deps = [ "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", ], ) @@ -16,7 +22,9 @@ py_test( srcs = ["model_test.py"], deps = [ ":model", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", "//snowflake/ml/test_utils:mock_data_frame", "//snowflake/ml/test_utils:mock_session", ], @@ -27,6 +35,7 @@ py_library( srcs = ["model_version.py"], deps = [ "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:query_result_checker", "//snowflake/ml/_internal/utils:sql_identifier", ], ) @@ -47,6 +56,7 @@ py_library( srcs = ["stage.py"], deps = [ "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:query_result_checker", "//snowflake/ml/_internal/utils:sql_identifier", ], ) @@ -61,3 +71,24 @@ py_test( "//snowflake/ml/test_utils:mock_session", ], ) + +py_library( + name = "tag", + srcs = ["tag.py"], + deps = [ + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:sql_identifier", + ], +) + +py_test( + name = "tag_test", + srcs = ["tag_test.py"], + deps = [ + ":tag", + "//snowflake/ml/_internal/utils:sql_identifier", + 
"//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/model/_client/sql/model.py b/snowflake/ml/model/_client/sql/model.py index 040b5dea..f07c02dd 100644 --- a/snowflake/ml/model/_client/sql/model.py +++ b/snowflake/ml/model/_client/sql/model.py @@ -1,10 +1,25 @@ from typing import Any, Dict, List, Optional -from snowflake.ml._internal.utils import identifier, sql_identifier +from snowflake.ml._internal.utils import ( + identifier, + query_result_checker, + snowflake_env, + sql_identifier, +) +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.snowpark import row, session class ModelSQLClient: + MODEL_NAME_COL_NAME = "name" + MODEL_COMMENT_COL_NAME = "comment" + MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name" + + MODEL_VERSION_NAME_COL_NAME = "name" + MODEL_VERSION_COMMENT_COL_NAME = "comment" + MODEL_VERSION_METADATA_COL_NAME = "metadata" + MODEL_VERSION_USER_DATA_COL_NAME = "user_data" + def __init__( self, session: session.Session, @@ -30,29 +45,60 @@ def show_models( self, *, model_name: Optional[sql_identifier.SqlIdentifier] = None, + validate_result: bool = True, statement_params: Optional[Dict[str, Any]] = None, ) -> List[row.Row]: fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()]) like_sql = "" if model_name: like_sql = f" LIKE '{model_name.resolved()}'" - res = self._session.sql(f"SHOW MODELS{like_sql} IN SCHEMA {fully_qualified_schema_name}") - return res.collect(statement_params=statement_params) + res = ( + query_result_checker.SqlResultValidator( + self._session, + f"SHOW MODELS{like_sql} IN SCHEMA {fully_qualified_schema_name}", + statement_params=statement_params, + ) + .has_column(ModelSQLClient.MODEL_NAME_COL_NAME, allow_empty=True) + .has_column(ModelSQLClient.MODEL_COMMENT_COL_NAME, allow_empty=True) + .has_column(ModelSQLClient.MODEL_DEFAULT_VERSION_NAME_COL_NAME, allow_empty=True) + ) + if validate_result and model_name: + res = res.has_dimensions(expected_rows=1) + + return res.validate() def show_versions( self, *, model_name: sql_identifier.SqlIdentifier, version_name: Optional[sql_identifier.SqlIdentifier] = None, + validate_result: bool = True, statement_params: Optional[Dict[str, Any]] = None, ) -> List[row.Row]: like_sql = "" if version_name: like_sql = f" LIKE '{version_name.resolved()}'" - res = self._session.sql(f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}") - return res.collect(statement_params=statement_params) + res = ( + query_result_checker.SqlResultValidator( + self._session, + f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}", + statement_params=statement_params, + ) + .has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True) + .has_column(ModelSQLClient.MODEL_VERSION_COMMENT_COL_NAME, allow_empty=True) + .has_column(ModelSQLClient.MODEL_VERSION_METADATA_COL_NAME, allow_empty=True) + ) + if ( + snowflake_env.get_current_snowflake_version(self._session) + >= model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION + ): + res = res.has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True) + if validate_result and version_name: + res = res.has_dimensions(expected_rows=1) + + return res.validate() def set_comment( self, @@ -61,8 +107,11 @@ def set_comment( model_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - comment_sql 
= f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$" - self._session.sql(comment_sql).collect(statement_params=statement_params) + query_result_checker.SqlResultValidator( + self._session, + f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$", + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() def drop_model( self, @@ -70,6 +119,8 @@ def drop_model( model_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - self._session.sql(f"DROP MODEL {self.fully_qualified_model_name(model_name)}").collect( - statement_params=statement_params - ) + query_result_checker.SqlResultValidator( + self._session, + f"DROP MODEL {self.fully_qualified_model_name(model_name)}", + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() diff --git a/snowflake/ml/model/_client/sql/model_test.py b/snowflake/ml/model/_client/sql/model_test.py index 2d0c133a..ba6fab90 100644 --- a/snowflake/ml/model/_client/sql/model_test.py +++ b/snowflake/ml/model/_client/sql/model_test.py @@ -1,9 +1,11 @@ from typing import cast +from unittest import mock from absl.testing import absltest -from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml._internal.utils import snowflake_env, sql_identifier from snowflake.ml.model._client.sql import model as model_sql +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.test_utils import mock_data_frame, mock_session from snowflake.snowpark import Row, Session @@ -11,6 +13,9 @@ class ModelSQLTest(absltest.TestCase): def setUp(self) -> None: self.m_session = mock_session.MockSession(conn=None, test_case=self) + snowflake_env.get_current_snowflake_version = mock.MagicMock( + return_value=model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION + ) def test_show_models_1(self) -> None: m_statement_params = {"test": "1"} @@ -23,6 +28,7 @@ def test_show_models_1(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="V1", ), Row( create_on="06/01", @@ -31,6 +37,7 @@ def test_show_models_1(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="v1", ), ], collect_statement_params=m_statement_params, @@ -56,6 +63,7 @@ def test_show_models_2(self) -> None: model_name="MODEL", database_name="TEMP", schema_name="test", + default_version_name="V1", ), ], collect_statement_params=m_statement_params, @@ -80,6 +88,8 @@ def test_show_versions_1(self) -> None: name="v1", comment="This is a comment", model_name="MODEL", + metadata="{}", + user_data="{}", is_default_version=True, ), Row( @@ -87,6 +97,8 @@ def test_show_versions_1(self) -> None: name="V1", comment="This is a comment", model_name="MODEL", + metadata="{}", + user_data="{}", is_default_version=False, ), ], @@ -112,6 +124,8 @@ def test_show_versions_2(self) -> None: name="v1", comment="This is a comment", model_name="MODEL", + metadata="{}", + user_data="{}", is_default_version=True, ), ], diff --git a/snowflake/ml/model/_client/sql/model_version.py b/snowflake/ml/model/_client/sql/model_version.py index 7ffc7d22..18ba0a55 100644 --- a/snowflake/ml/model/_client/sql/model_version.py +++ b/snowflake/ml/model/_client/sql/model_version.py @@ -4,7 +4,11 @@ from typing import Any, Dict, List, Optional, Tuple from urllib.parse import ParseResult -from snowflake.ml._internal.utils import identifier, 
sql_identifier +from snowflake.ml._internal.utils import ( + identifier, + query_result_checker, + sql_identifier, +) from snowflake.snowpark import dataframe, functions as F, session, types as spt from snowflake.snowpark._internal import utils as snowpark_utils @@ -46,11 +50,14 @@ def create_from_stage( stage_path: str, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - self._version_name = version_name - self._session.sql( - f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}" - f" FROM {stage_path}" - ).collect(statement_params=statement_params) + query_result_checker.SqlResultValidator( + self._session, + ( + f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}" + f" FROM {stage_path}" + ), + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...` def add_version_from_stage( @@ -61,11 +68,14 @@ def add_version_from_stage( stage_path: str, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - self._version_name = version_name - self._session.sql( - f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}" - f" FROM {stage_path}" - ).collect(statement_params=statement_params) + query_result_checker.SqlResultValidator( + self._session, + ( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}" + f" FROM {stage_path}" + ), + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() def set_default_version( self, @@ -74,24 +84,14 @@ def set_default_version( version_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - self._session.sql( - f"ALTER MODEL {self.fully_qualified_model_name(model_name)} " - f"SET DEFAULT_VERSION = {version_name.identifier()}" - ).collect(statement_params=statement_params) - - def get_default_version( - self, - *, - model_name: sql_identifier.SqlIdentifier, - statement_params: Optional[Dict[str, Any]] = None, - ) -> str: - # TODO: Replace SHOW with DESC when available. 
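Across these SQL clients the raw `session.sql(...).collect(...)` calls are being replaced by the `SqlResultValidator` chain from `snowflake.ml._internal.utils.query_result_checker`. The condensed sketch below shows the pattern in isolation; the `session` variable and the SHOW statement are illustrative assumptions, not text from the patch.

```python
from snowflake.ml._internal.utils import query_result_checker

# Illustrative only: validate a SHOW-style result before returning it.
rows = (
    query_result_checker.SqlResultValidator(
        session,                                  # an existing snowpark Session (assumed)
        "SHOW MODELS IN SCHEMA TEMP.PUBLIC",      # example statement for the sketch
        statement_params=None,
    )
    .has_column("name", allow_empty=True)         # required column; zero rows still allowed
    .has_dimensions(expected_rows=1)              # enforce shape when one row is expected
    .validate()                                   # executes the query and returns List[Row]
)
```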
- default_version: str = ( - self._session.sql(f"SHOW VERSIONS IN MODEL {self.fully_qualified_model_name(model_name)}") - .filter('"is_default_version" = TRUE')[['"name"']] - .collect(statement_params=statement_params)[0][0] - ) - return default_version + query_result_checker.SqlResultValidator( + self._session, + ( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} " + f"SET DEFAULT_VERSION = {version_name.identifier()}" + ), + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() def get_file( self, @@ -108,14 +108,14 @@ def get_file( stage_location_url = ParseResult( scheme="snow", netloc="model", path=stage_location, params="", query="", fragment="" ).geturl() - local_location = target_path.absolute().as_posix() - local_location_url = ParseResult( - scheme="file", netloc="", path=local_location, params="", query="", fragment="" - ).geturl() + local_location = target_path.resolve().as_posix() + local_location_url = f"file://{local_location}" - self._session.sql( - f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}" - ).collect(statement_params=statement_params) + query_result_checker.SqlResultValidator( + self._session, + f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}", + statement_params=statement_params, + ).has_dimensions(expected_rows=1).validate() return target_path / file_path.name def set_comment( @@ -126,11 +126,14 @@ def set_comment( version_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - comment_sql = ( - f"ALTER MODEL {self.fully_qualified_model_name(model_name)} " - f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$" - ) - self._session.sql(comment_sql).collect(statement_params=statement_params) + query_result_checker.SqlResultValidator( + self._session, + ( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} " + f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$" + ), + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() def invoke_method( self, @@ -206,8 +209,11 @@ def set_metadata( statement_params: Optional[Dict[str, Any]] = None, ) -> None: json_metadata = json.dumps(metadata_dict) - sql = ( - f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}" - f" SET METADATA=$${json_metadata}$$" - ) - self._session.sql(sql).collect(statement_params=statement_params) + query_result_checker.SqlResultValidator( + self._session, + ( + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}" + f" SET METADATA=$${json_metadata}$$" + ), + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() diff --git a/snowflake/ml/model/_client/sql/model_version_test.py b/snowflake/ml/model/_client/sql/model_version_test.py index f732f714..0557fe77 100644 --- a/snowflake/ml/model/_client/sql/model_version_test.py +++ b/snowflake/ml/model/_client/sql/model_version_test.py @@ -53,6 +53,23 @@ def test_add_version_from_stage(self) -> None: statement_params=m_statement_params, ) + def test_set_default_version(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Model MODEL successfully altered.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql("""ALTER MODEL TEMP."test".MODEL SET 
DEFAULT_VERSION = V2""", m_df) + c_session = cast(Session, self.m_session) + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).set_default_version( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + statement_params=m_statement_params, + ) + def test_set_comment(self) -> None: m_statement_params = {"test": "1"} m_df = mock_data_frame.MockDataFrame(collect_result=[Row("")], collect_statement_params=m_statement_params) @@ -74,7 +91,10 @@ def test_set_comment(self) -> None: def test_get_file(self) -> None: m_statement_params = {"test": "1"} - m_df = mock_data_frame.MockDataFrame(collect_result=[Row()], collect_statement_params=m_statement_params) + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row(file="946964364/MANIFEST.yml", size=419, status="DOWNLOADED", message="")], + collect_statement_params=m_statement_params, + ) self.m_session.add_mock_sql( """GET 'snow://model/TEMP."test".MODEL/versions/v1/model.yaml' 'file:///tmp'""", m_df ) diff --git a/snowflake/ml/model/_client/sql/stage.py b/snowflake/ml/model/_client/sql/stage.py index 8b9750a6..b40de375 100644 --- a/snowflake/ml/model/_client/sql/stage.py +++ b/snowflake/ml/model/_client/sql/stage.py @@ -1,6 +1,10 @@ from typing import Any, Dict, Optional -from snowflake.ml._internal.utils import identifier, sql_identifier +from snowflake.ml._internal.utils import ( + identifier, + query_result_checker, + sql_identifier, +) from snowflake.snowpark import session @@ -35,6 +39,8 @@ def create_tmp_stage( stage_name: sql_identifier.SqlIdentifier, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - self._session.sql(f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}").collect( - statement_params=statement_params - ) + query_result_checker.SqlResultValidator( + self._session, + f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}", + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() diff --git a/snowflake/ml/model/_client/sql/tag.py b/snowflake/ml/model/_client/sql/tag.py new file mode 100644 index 00000000..ac015a1a --- /dev/null +++ b/snowflake/ml/model/_client/sql/tag.py @@ -0,0 +1,118 @@ +from typing import Any, Dict, List, Optional + +from snowflake.ml._internal.utils import ( + identifier, + query_result_checker, + sql_identifier, +) +from snowflake.snowpark import row, session + + +class ModuleTagSQLClient: + def __init__( + self, + session: session.Session, + *, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> None: + self._session = session + self._database_name = database_name + self._schema_name = schema_name + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, ModuleTagSQLClient): + return False + return self._database_name == __value._database_name and self._schema_name == __value._schema_name + + def fully_qualified_module_name( + self, + module_name: sql_identifier.SqlIdentifier, + ) -> str: + return identifier.get_schema_level_object_identifier( + self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier() + ) + + def set_tag_on_model( + self, + model_name: sql_identifier.SqlIdentifier, + *, + tag_database_name: sql_identifier.SqlIdentifier, + tag_schema_name: sql_identifier.SqlIdentifier, + tag_name: sql_identifier.SqlIdentifier, + 
tag_value: str, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + fq_model_name = self.fully_qualified_module_name(model_name) + fq_tag_name = identifier.get_schema_level_object_identifier( + tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier() + ) + query_result_checker.SqlResultValidator( + self._session, + f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$", + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() + + def unset_tag_on_model( + self, + model_name: sql_identifier.SqlIdentifier, + *, + tag_database_name: sql_identifier.SqlIdentifier, + tag_schema_name: sql_identifier.SqlIdentifier, + tag_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + fq_model_name = self.fully_qualified_module_name(model_name) + fq_tag_name = identifier.get_schema_level_object_identifier( + tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier() + ) + query_result_checker.SqlResultValidator( + self._session, + f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}", + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() + + def get_tag_value( + self, + module_name: sql_identifier.SqlIdentifier, + *, + tag_database_name: sql_identifier.SqlIdentifier, + tag_schema_name: sql_identifier.SqlIdentifier, + tag_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> row.Row: + fq_module_name = self.fully_qualified_module_name(module_name) + fq_tag_name = identifier.get_schema_level_object_identifier( + tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier() + ) + return ( + query_result_checker.SqlResultValidator( + self._session, + f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_module_name}$$, 'MODULE') AS TAG_VALUE", + statement_params=statement_params, + ) + .has_dimensions(expected_rows=1, expected_cols=1) + .has_column("TAG_VALUE") + .validate()[0] + ) + + def get_tag_list( + self, + module_name: sql_identifier.SqlIdentifier, + *, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[row.Row]: + fq_module_name = self.fully_qualified_module_name(module_name) + return ( + query_result_checker.SqlResultValidator( + self._session, + f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE +FROM TABLE({self._database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_module_name}$$, 'MODULE'))""", + statement_params=statement_params, + ) + .has_column("TAG_DATABASE", allow_empty=True) + .has_column("TAG_SCHEMA", allow_empty=True) + .has_column("TAG_NAME", allow_empty=True) + .has_column("TAG_VALUE", allow_empty=True) + .validate() + ) diff --git a/snowflake/ml/model/_client/sql/tag_test.py b/snowflake/ml/model/_client/sql/tag_test.py new file mode 100644 index 00000000..d2512268 --- /dev/null +++ b/snowflake/ml/model/_client/sql/tag_test.py @@ -0,0 +1,104 @@ +from typing import cast + +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.sql import tag as tag_sql +from snowflake.ml.test_utils import mock_data_frame, mock_session +from snowflake.snowpark import Row, Session + + +class ModuleTagSQLTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + + def test_set_tag_on_model(self) -> None: + m_statement_params = {"test": "1"} + m_df = 
mock_data_frame.MockDataFrame( + collect_result=[Row("Tag MYTAG successfully set.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql( + """ALTER MODEL TEMP."test".MODEL SET TAG DB."schema".MYTAG = $$tag content$$""", m_df + ) + c_session = cast(Session, self.m_session) + tag_sql.ModuleTagSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).set_tag_on_model( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + tag_value="tag content", + statement_params=m_statement_params, + ) + + def test_unset_tag_on_model(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Tag MYTAG successfully unset.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql("""ALTER MODEL TEMP."test".MODEL UNSET TAG DB."schema".MYTAG""", m_df) + c_session = cast(Session, self.m_session) + tag_sql.ModuleTagSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).unset_tag_on_model( + model_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=m_statement_params, + ) + + def test_get_tag_value(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row(TAG_VALUE="tag content")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql( + """SELECT SYSTEM$GET_TAG($$DB."schema".MYTAG$$, $$TEMP."test".MODEL$$, 'MODULE') AS TAG_VALUE""", m_df + ) + c_session = cast(Session, self.m_session) + res = tag_sql.ModuleTagSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).get_tag_value( + module_name=sql_identifier.SqlIdentifier("MODEL"), + tag_database_name=sql_identifier.SqlIdentifier("DB"), + tag_schema_name=sql_identifier.SqlIdentifier("schema", case_sensitive=True), + tag_name=sql_identifier.SqlIdentifier("MYTAG"), + statement_params=m_statement_params, + ) + self.assertEqual(res, Row(TAG_VALUE="tag content")) + + def test_list_tags(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row(TAG_DATABASE="DB", TAG_SCHEMA="schema", TAG_NAME="MYTAG", TAG_VALUE="tag content")], + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql( + """SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE +FROM TABLE(TEMP.INFORMATION_SCHEMA.TAG_REFERENCES($$TEMP."test".MODEL$$, 'MODULE'))""", + m_df, + ) + c_session = cast(Session, self.m_session) + res = tag_sql.ModuleTagSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).get_tag_list( + module_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=m_statement_params, + ) + self.assertListEqual( + res, [Row(TAG_DATABASE="DB", TAG_SCHEMA="schema", TAG_NAME="MYTAG", TAG_VALUE="tag content")] + ) + + +if __name__ == "__main__": + absltest.main() diff 
--git a/snowflake/ml/model/_deploy_client/image_builds/BUILD.bazel b/snowflake/ml/model/_deploy_client/image_builds/BUILD.bazel index 37fb5129..e1b3e968 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/BUILD.bazel +++ b/snowflake/ml/model/_deploy_client/image_builds/BUILD.bazel @@ -13,9 +13,9 @@ py_library( deps = [ ":base_image_builder", ":docker_context", + "//snowflake/ml/_internal/container_services/image_registry:credential", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/_internal/utils:spcs_image_registry", "//snowflake/ml/model/_packager/model_meta", ], ) @@ -31,10 +31,10 @@ py_library( ":base_image_builder", ":docker_context", "//snowflake/ml/_internal:file_utils", + "//snowflake/ml/_internal/container_services/image_registry:registry_client", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/model/_deploy_client/utils:constants", - "//snowflake/ml/model/_deploy_client/utils:image_registry_client", "//snowflake/ml/model/_deploy_client/utils:snowservice_client", ], ) diff --git a/snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py index 2f0ef72c..11878bc2 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +++ b/snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py @@ -9,11 +9,11 @@ from typing import List from snowflake import snowpark +from snowflake.ml._internal.container_services.image_registry import credential from snowflake.ml._internal.exceptions import ( error_codes, exceptions as snowml_exceptions, ) -from snowflake.ml._internal.utils import spcs_image_registry from snowflake.ml.model._deploy_client.image_builds import base_image_builder logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ def _cleanup_local_image(docker_config_dir: str) -> None: self._run_docker_commands(commands) self.validate_docker_client_env() - with spcs_image_registry.generate_image_registry_credential( + with credential.generate_image_registry_credential( self.session ) as registry_cred, tempfile.TemporaryDirectory() as docker_config_dir: try: diff --git a/snowflake/ml/model/_deploy_client/image_builds/docker_context.py b/snowflake/ml/model/_deploy_client/image_builds/docker_context.py index 56aa0b57..0d732045 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/docker_context.py +++ b/snowflake/ml/model/_deploy_client/image_builds/docker_context.py @@ -2,7 +2,6 @@ import posixpath import shutil import string -from abc import ABC from typing import Optional import importlib_resources @@ -15,7 +14,7 @@ from snowflake.snowpark import FileOperation, Session -class DockerContext(ABC): +class DockerContext: """ Constructs the Docker context directory required for image building. 
""" @@ -53,12 +52,13 @@ def build(self) -> None: def _copy_entrypoint_script_to_docker_context(self) -> None: """Copy gunicorn_run.sh entrypoint to docker context directory.""" - with importlib_resources.as_file( - importlib_resources.files(image_builds).joinpath( # type: ignore[no-untyped-call] - constants.ENTRYPOINT_SCRIPT - ) - ) as path: - shutil.copy(path, os.path.join(self.context_dir, constants.ENTRYPOINT_SCRIPT)) + script_path = importlib_resources.files(image_builds).joinpath( # type: ignore[no-untyped-call] + constants.ENTRYPOINT_SCRIPT + ) + target_path = os.path.join(self.context_dir, constants.ENTRYPOINT_SCRIPT) + + with open(script_path, encoding="utf-8") as source_file, file_utils.open_file(target_path, "w") as target_file: + target_file.write(source_file.read()) def _copy_model_env_dependency_to_docker_context(self) -> None: """ diff --git a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py b/snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py index 77bdc5eb..0626ec73 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +++ b/snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py @@ -105,6 +105,8 @@ def _run_setup() -> None: # TODO (Server-side Model Rollout): # Keep try block only + # SPCS spec will convert all environment variables as strings. + use_gpu = os.environ.get("SNOWML_USE_GPU", "False").lower() == "true" try: from snowflake.ml.model._packager import model_packager @@ -112,9 +114,7 @@ def _run_setup() -> None: pk.load( as_custom_model=True, meta_only=False, - options=model_types.ModelLoadOption( - {"use_gpu": cast(bool, os.environ.get("SNOWML_USE_GPU", False))} - ), + options=model_types.ModelLoadOption({"use_gpu": use_gpu}), ) _LOADED_MODEL = pk.model _LOADED_META = pk.meta @@ -132,9 +132,7 @@ def _run_setup() -> None: _LOADED_MODEL, meta_LOADED_META = model_api._load( local_dir_path=extracted_dir, as_custom_model=True, - options=model_types.ModelLoadOption( - {"use_gpu": cast(bool, os.environ.get("SNOWML_USE_GPU", False))} - ), + options=model_types.ModelLoadOption({"use_gpu": use_gpu}), ) _MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED logger.info("Successfully loaded model into memory") diff --git a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py index 963a3c0f..0713911d 100644 --- a/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +++ b/snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py @@ -7,6 +7,9 @@ from snowflake import snowpark from snowflake.ml._internal import file_utils +from snowflake.ml._internal.container_services.image_registry import ( + registry_client as image_registry_client, +) from snowflake.ml._internal.exceptions import ( error_codes, exceptions as snowml_exceptions, @@ -14,11 +17,7 @@ from snowflake.ml._internal.utils import identifier from snowflake.ml.model._deploy_client import image_builds from snowflake.ml.model._deploy_client.image_builds import base_image_builder -from snowflake.ml.model._deploy_client.utils import ( - constants, - image_registry_client, - snowservice_client, -) +from snowflake.ml.model._deploy_client.utils import constants, snowservice_client logger = logging.getLogger(__name__) @@ -117,7 +116,7 @@ def _construct_and_upload_docker_entrypoint_script(self, context_tarball_stage_l kaniko_shell_file = os.path.join(self.context_dir, constants.KANIKO_SHELL_SCRIPT_NAME) - with 
open(kaniko_shell_file, "w+", encoding="utf-8") as script_file: + with file_utils.open_file(kaniko_shell_file, "w+") as script_file: normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@")) params = { # Remove @ in the beginning, append "/" to denote root directory. @@ -175,7 +174,7 @@ def _construct_and_upload_job_spec(self, base_image: str, kaniko_shell_script_st os.path.dirname(self.context_dir), f"{constants.IMAGE_BUILD_JOB_SPEC_TEMPLATE}.yaml" ) - with open(spec_file_path, "w+", encoding="utf-8") as spec_file: + with file_utils.open_file(spec_file_path, "w+") as spec_file: assert self.artifact_stage_location.startswith("@") normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@")) (db, schema, stage, path) = identifier.parse_schema_level_object_identifier(normed_artifact_stage_path) diff --git a/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel b/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel index 2cc37a5c..7579d010 100644 --- a/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel +++ b/snowflake/ml/model/_deploy_client/snowservice/BUILD.bazel @@ -22,6 +22,7 @@ py_library( ":deploy_options", ":instance_types", "//snowflake/ml/_internal:env_utils", + "//snowflake/ml/_internal/container_services/image_registry:registry_client", "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:spcs_attribution_utils", @@ -29,7 +30,6 @@ py_library( "//snowflake/ml/model/_deploy_client/image_builds:base_image_builder", "//snowflake/ml/model/_deploy_client/image_builds:client_image_builder", "//snowflake/ml/model/_deploy_client/image_builds:server_image_builder", - "//snowflake/ml/model/_deploy_client/utils:image_registry_client", "//snowflake/ml/model/_deploy_client/utils:snowservice_client", "//snowflake/ml/model/_packager/model_meta", ], diff --git a/snowflake/ml/model/_deploy_client/snowservice/deploy.py b/snowflake/ml/model/_deploy_client/snowservice/deploy.py index 748ce594..182c417b 100644 --- a/snowflake/ml/model/_deploy_client/snowservice/deploy.py +++ b/snowflake/ml/model/_deploy_client/snowservice/deploy.py @@ -14,6 +14,9 @@ from typing_extensions import Unpack from snowflake.ml._internal import env_utils, file_utils +from snowflake.ml._internal.container_services.image_registry import ( + registry_client as image_registry_client, +) from snowflake.ml._internal.exceptions import ( error_codes, exceptions as snowml_exceptions, @@ -32,11 +35,7 @@ server_image_builder, ) from snowflake.ml.model._deploy_client.snowservice import deploy_options, instance_types -from snowflake.ml.model._deploy_client.utils import ( - constants, - image_registry_client, - snowservice_client, -) +from snowflake.ml.model._deploy_client.utils import constants, snowservice_client from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema from snowflake.snowpark import Session diff --git a/snowflake/ml/model/_deploy_client/snowservice/instance_types.py b/snowflake/ml/model/_deploy_client/snowservice/instance_types.py index 11a9e09a..ab27b15b 100644 --- a/snowflake/ml/model/_deploy_client/snowservice/instance_types.py +++ b/snowflake/ml/model/_deploy_client/snowservice/instance_types.py @@ -1,2 +1,10 @@ # Snowpark Container Service GPU instance type and corresponding GPU counts. 
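# Illustrative sketch (not from this patch): the mapping revised in the hunk below ties each
# SPCS GPU instance type to its GPU count. A hypothetical helper (the name is invented here
# for illustration) could use it to default num_gpus to everything the instance offers:
from typing import Optional

INSTANCE_TYPE_TO_GPU_COUNT = {
    "GPU_3": 1, "GPU_5": 1, "GPU_7": 4, "GPU_10": 8,
    "GPU_NV_S": 1, "GPU_NV_M": 4, "GPU_NV_L": 8,
}

def default_num_gpus(instance_type: str, requested: Optional[int] = None) -> int:
    # Honor an explicit request; otherwise use every GPU on the instance (0 if unknown type).
    return requested if requested is not None else INSTANCE_TYPE_TO_GPU_COUNT.get(instance_type, 0)

assert default_num_gpus("GPU_NV_M") == 4
assert default_num_gpus("GPU_3", requested=1) == 1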
-INSTANCE_TYPE_TO_GPU_COUNT = {"GPU_3": 1, "GPU_5": 1, "GPU_7": 4, "GPU_10": 8} +INSTANCE_TYPE_TO_GPU_COUNT = { + "GPU_3": 1, + "GPU_5": 1, + "GPU_7": 4, + "GPU_10": 8, + "GPU_NV_S": 1, + "GPU_NV_M": 4, + "GPU_NV_L": 8, +} diff --git a/snowflake/ml/model/_deploy_client/utils/BUILD.bazel b/snowflake/ml/model/_deploy_client/utils/BUILD.bazel index d5fa0a02..31a123fd 100644 --- a/snowflake/ml/model/_deploy_client/utils/BUILD.bazel +++ b/snowflake/ml/model/_deploy_client/utils/BUILD.bazel @@ -18,24 +18,6 @@ py_library( ], ) -py_library( - name = "image_registry_client", - srcs = ["image_registry_client.py"], - deps = [ - ":imagelib", - "//snowflake/ml/_internal/exceptions", - "//snowflake/ml/_internal/utils:image_registry_http_client", - ], -) - -py_library( - name = "imagelib", - srcs = ["imagelib.py"], - deps = [ - "//snowflake/ml/_internal/utils:image_registry_http_client", - ], -) - py_test( name = "snowservice_client_test", srcs = ["snowservice_client_test.py"], @@ -45,13 +27,3 @@ py_test( "//snowflake/ml/test_utils:mock_session", ], ) - -py_test( - name = "image_registry_client_test", - srcs = ["image_registry_client_test.py"], - deps = [ - ":image_registry_client", - "//snowflake/ml/test_utils:exception_utils", - "//snowflake/ml/test_utils:mock_session", - ], -) diff --git a/snowflake/ml/model/_deploy_client/warehouse/deploy.py b/snowflake/ml/model/_deploy_client/warehouse/deploy.py index 0f20fc4a..e89b49e6 100644 --- a/snowflake/ml/model/_deploy_client/warehouse/deploy.py +++ b/snowflake/ml/model/_deploy_client/warehouse/deploy.py @@ -2,6 +2,7 @@ import logging import posixpath import tempfile +import textwrap from types import ModuleType from typing import IO, List, Optional, Tuple, TypedDict, Union @@ -154,7 +155,7 @@ def _get_model_final_packages( Returns: List of final packages string that is accepted by Snowpark register UDF call. """ - final_packages = None + if ( any(channel.lower() not in [env_utils.DEFAULT_CHANNEL_NAME] for channel in meta.env._conda_dependencies.keys()) or meta.env.pip_requirements @@ -173,21 +174,29 @@ def _get_model_final_packages( else: required_packages = meta.env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME] - final_packages = env_utils.validate_requirements_in_information_schema( + package_availability_dict = env_utils.get_matched_package_versions_in_information_schema( session, required_packages, python_version=meta.env.python_version ) - - if final_packages is None: + no_version_available_packages = [ + req_name for req_name, ver_list in package_availability_dict.items() if len(ver_list) < 1 + ] + unavailable_packages = [req.name for req in required_packages if req.name not in package_availability_dict] + if no_version_available_packages or unavailable_packages: relax_version_info_str = "" if relax_version else "Try to set relax_version as True in the options. " + required_package_str = " ".join(map(lambda x: f'"{x}"', required_packages)) raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.DEPENDENCY_VERSION_ERROR, original_exception=RuntimeError( - "The model's dependencies are not available in Snowflake Anaconda Channel. " - + relax_version_info_str - + "Required packages are:\n" - + " ".join(map(lambda x: f'"{x}"', required_packages)) - + "\n Required Python version is: " - + meta.env.python_version + textwrap.dedent( + f""" + The model's dependencies are not available in Snowflake Anaconda Channel. 
{relax_version_info_str} + Required packages are: {required_package_str} + Required Python version is: {meta.env.python_version} + Packages that are not available are: {unavailable_packages} + Packages that cannot meet your requirements are: {no_version_available_packages} + Package availability information of those you requested is: {package_availability_dict} + """ + ), ), ) - return final_packages + return list(sorted(map(str, required_packages))) diff --git a/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py b/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py index 40250df4..4bbfa6df 100644 --- a/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py +++ b/snowflake/ml/model/_deploy_client/warehouse/deploy_test.py @@ -1,3 +1,4 @@ +import platform import tempfile import textwrap from importlib import metadata as importlib_metadata @@ -40,18 +41,7 @@ class TestFinalPackagesWithoutConda(absltest.TestCase): @classmethod def setUpClass(cls) -> None: - env_utils._INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION = None cls.m_session = mock_session.MockSession(conn=None, test_case=None) - cls.m_session.add_mock_sql( - query=textwrap.dedent( - """ - SHOW COLUMNS - LIKE 'runtime_version' - IN TABLE information_schema.packages; - """ - ), - result=mock_data_frame.MockDataFrame(count_result=0), - ) def setUp(self) -> None: self.add_packages( @@ -76,7 +66,9 @@ def add_packages(self, packages_dicts: Dict[str, List[str]]) -> None: SELECT PACKAGE_NAME, VERSION FROM information_schema.packages WHERE ({pkg_names_str}) - AND language = 'python'; + AND language = 'python' + AND (runtime_version = '{platform.python_version_tuple()[0]}.{platform.python_version_tuple()[1]}' + OR runtime_version is null); """ ) sql_result = [ diff --git a/snowflake/ml/model/_model_composer/model_composer_test.py b/snowflake/ml/model/_model_composer/model_composer_test.py index 6a46977a..9611995f 100644 --- a/snowflake/ml/model/_model_composer/model_composer_test.py +++ b/snowflake/ml/model/_model_composer/model_composer_test.py @@ -39,7 +39,11 @@ def test_save_interface(self) -> None: with mock.patch.object( file_utils, "upload_directory_to_stage", return_value=None ) as mock_upload_directory_to_stage: - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): m.save( name="model1", model=LinearRegression(), @@ -59,7 +63,11 @@ def test_save_interface(self) -> None: with mock.patch.object( file_utils, "upload_directory_to_stage", return_value=None ) as mock_upload_directory_to_stage: - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): m.save( name="model1", model=linear_model.LinearRegression(), diff --git a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel index 0b0d9706..0c6b509c 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel @@ -2,11 +2,22 @@ load("//bazel:py_rules.bzl", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) +filegroup( + name = "manifest_fixtures", + srcs = [ + "fixtures/MANIFEST_0.yml", + 
"fixtures/MANIFEST_1.yml", + "fixtures/MANIFEST_2.yml", + "fixtures/MANIFEST_3.yml", + ], +) + py_library( name = "model_manifest", srcs = ["model_manifest.py"], deps = [ ":model_manifest_schema", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/model/_model_composer/model_method", "//snowflake/ml/model/_model_composer/model_method:function_generator", "//snowflake/ml/model/_model_composer/model_runtime", @@ -17,15 +28,22 @@ py_library( py_library( name = "model_manifest_schema", srcs = ["model_manifest_schema.py"], + deps = [ + "//snowflake/ml/model:model_signature", + ], ) py_test( name = "model_manifest_test", srcs = ["model_manifest_test.py"], - data = ["//snowflake/ml/model/_model_composer/model_method:function_fixtures"], + data = [ + ":manifest_fixtures", + "//snowflake/ml/model/_model_composer/model_method:function_fixtures", + ], deps = [ ":model_manifest", "//snowflake/ml/_internal:env_utils", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_packager/model_meta", diff --git a/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_0.yml b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_0.yml new file mode 100644 index 00000000..d1914caf --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_0.yml @@ -0,0 +1,41 @@ +manifest_version: '1.0' +methods: +- handler: functions.predict.infer + inputs: + - name: input_1 + type: FLOAT + - name: input_2 + type: ARRAY + - name: input_3 + type: ARRAY + - name: input_4 + type: ARRAY + name: predict + outputs: + - type: OBJECT + runtime: python_runtime + type: FUNCTION +- handler: functions.__call__.infer + inputs: + - name: INPUT_1 + type: FLOAT + - name: INPUT_2 + type: ARRAY + - name: INPUT_3 + type: ARRAY + - name: INPUT_4 + type: ARRAY + name: __CALL__ + outputs: + - type: OBJECT + runtime: python_runtime + type: FUNCTION +runtimes: + python_runtime: + dependencies: + conda: runtimes/python_runtime/env/conda.yml + imports: + - model.zip + - runtimes/python_runtime/snowflake-ml-python.zip + language: PYTHON + version: '3.8' diff --git a/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_1.yml b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_1.yml new file mode 100644 index 00000000..bc4f4434 --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_1.yml @@ -0,0 +1,65 @@ +manifest_version: '1.0' +methods: +- handler: functions.predict.infer + inputs: + - name: INPUT_1 + type: FLOAT + - name: INPUT_2 + type: ARRAY + - name: INPUT_3 + type: ARRAY + - name: INPUT_4 + type: ARRAY + name: PREDICT + outputs: + - type: OBJECT + runtime: python_runtime + type: FUNCTION +runtimes: + python_runtime: + dependencies: + conda: runtimes/python_runtime/env/conda.yml + imports: + - model.zip + language: PYTHON + version: '3.8' +user_data: + snowpark_ml_data: + functions: + - name: PREDICT + signature: + inputs: + - name: input_1 + type: FLOAT + - name: input_2 + shape: + - -1 + type: FLOAT + - name: input_3 + shape: + - -1 + type: FLOAT + - name: input_4 + shape: + - -1 + type: FLOAT + outputs: + - name: output_1 + type: FLOAT + - name: output_2 + shape: + - 2 + - 2 + type: FLOAT + - name: output_3 + shape: + - 2 + - 2 + type: FLOAT + - name: output_4 + shape: + - 2 + - 2 + type: FLOAT + target_method: predict + schema_version: '2024-02-01' diff --git 
a/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_2.yml b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_2.yml new file mode 100644 index 00000000..73d76bda --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_2.yml @@ -0,0 +1,66 @@ +manifest_version: '1.0' +methods: +- handler: functions.__call__.infer + inputs: + - name: INPUT_1 + type: FLOAT + - name: INPUT_2 + type: ARRAY + - name: INPUT_3 + type: ARRAY + - name: INPUT_4 + type: ARRAY + name: __CALL__ + outputs: + - type: OBJECT + runtime: python_runtime + type: FUNCTION +runtimes: + python_runtime: + dependencies: + conda: runtimes/python_runtime/env/conda.yml + imports: + - model.zip + - runtimes/python_runtime/snowflake-ml-python.zip + language: PYTHON + version: '3.8' +user_data: + snowpark_ml_data: + functions: + - name: __CALL__ + signature: + inputs: + - name: input_1 + type: FLOAT + - name: input_2 + shape: + - -1 + type: FLOAT + - name: input_3 + shape: + - -1 + type: FLOAT + - name: input_4 + shape: + - -1 + type: FLOAT + outputs: + - name: output_1 + type: FLOAT + - name: output_2 + shape: + - 2 + - 2 + type: FLOAT + - name: output_3 + shape: + - 2 + - 2 + type: FLOAT + - name: output_4 + shape: + - 2 + - 2 + type: FLOAT + target_method: __call__ + schema_version: '2024-02-01' diff --git a/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_3.yml b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_3.yml new file mode 100644 index 00000000..08b5a91c --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_manifest/fixtures/MANIFEST_3.yml @@ -0,0 +1,117 @@ +manifest_version: '1.0' +methods: +- handler: functions.predict.infer + inputs: + - name: input_1 + type: FLOAT + - name: input_2 + type: ARRAY + - name: input_3 + type: ARRAY + - name: input_4 + type: ARRAY + name: predict + outputs: + - type: OBJECT + runtime: python_runtime + type: FUNCTION +- handler: functions.__call__.infer + inputs: + - name: INPUT_1 + type: FLOAT + - name: INPUT_2 + type: ARRAY + - name: INPUT_3 + type: ARRAY + - name: INPUT_4 + type: ARRAY + name: __CALL__ + outputs: + - type: OBJECT + runtime: python_runtime + type: FUNCTION +runtimes: + python_runtime: + dependencies: + conda: runtimes/python_runtime/env/conda.yml + imports: + - model.zip + - runtimes/python_runtime/snowflake-ml-python.zip + language: PYTHON + version: '3.8' +user_data: + snowpark_ml_data: + functions: + - name: '"predict"' + signature: + inputs: + - name: input_1 + type: FLOAT + - name: input_2 + shape: + - -1 + type: FLOAT + - name: input_3 + shape: + - -1 + type: FLOAT + - name: input_4 + shape: + - -1 + type: FLOAT + outputs: + - name: output_1 + type: FLOAT + - name: output_2 + shape: + - 2 + - 2 + type: FLOAT + - name: output_3 + shape: + - 2 + - 2 + type: FLOAT + - name: output_4 + shape: + - 2 + - 2 + type: FLOAT + target_method: predict + - name: __CALL__ + signature: + inputs: + - name: input_1 + type: FLOAT + - name: input_2 + shape: + - -1 + type: FLOAT + - name: input_3 + shape: + - -1 + type: FLOAT + - name: input_4 + shape: + - -1 + type: FLOAT + outputs: + - name: output_1 + type: FLOAT + - name: output_2 + shape: + - 2 + - 2 + type: FLOAT + - name: output_3 + shape: + - 2 + - 2 + type: FLOAT + - name: output_4 + shape: + - 2 + - 2 + type: FLOAT + target_method: __call__ + schema_version: '2024-02-01' diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py 
index 3d4d4f70..78a782cf 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py @@ -1,9 +1,10 @@ import collections import pathlib -from typing import List, Optional, cast +from typing import Any, Dict, List, Optional, cast import yaml +from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.model import type_hints from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._model_composer.model_method import ( @@ -83,7 +84,15 @@ def save( ], ) + if ( + snowflake_env.get_current_snowflake_version(session) + >= model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION + ): + manifest_dict["user_data"] = self.generate_user_data_with_client_data(model_meta) + with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f: + # Anchors are not supported in the server, avoid that. + yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign] yaml.safe_dump(manifest_dict, f) def load(self) -> model_manifest_schema.ModelManifestDict: @@ -99,3 +108,43 @@ def load(self) -> model_manifest_schema.ModelManifestDict: res = cast(model_manifest_schema.ModelManifestDict, raw_input) return res + + def generate_user_data_with_client_data(self, model_meta: model_meta_api.ModelMetadata) -> Dict[str, Any]: + client_data = model_manifest_schema.SnowparkMLDataDict( + schema_version=model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION, + functions=[ + model_manifest_schema.ModelFunctionInfoDict( + name=method.method_name.identifier(), + target_method=method.target_method, + signature=model_meta.signatures[method.target_method].to_dict(), + ) + for method in self.methods + ], + ) + return {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: client_data} + + @staticmethod + def parse_client_data_from_user_data(raw_user_data: Dict[str, Any]) -> model_manifest_schema.SnowparkMLDataDict: + raw_client_data = raw_user_data.get(model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME, {}) + if not isinstance(raw_client_data, dict) or "schema_version" not in raw_client_data: + raise ValueError(f"Ill-formatted client data {raw_client_data} in user data found.") + loaded_client_data_schema_version = raw_client_data["schema_version"] + if ( + not isinstance(loaded_client_data_schema_version, str) + or loaded_client_data_schema_version != model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION + ): + raise ValueError(f"Unsupported client data schema version {loaded_client_data_schema_version} confronted.") + + return_functions_info: List[model_manifest_schema.ModelFunctionInfoDict] = [] + loaded_functions_info = raw_client_data.get("functions", []) + for func in loaded_functions_info: + fi = model_manifest_schema.ModelFunctionInfoDict( + name=func["name"], + target_method=func["target_method"], + signature=func["signature"], + ) + return_functions_info.append(fi) + + return model_manifest_schema.SnowparkMLDataDict( + schema_version=loaded_client_data_schema_version, functions=return_functions_info + ) diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py index 2df33b9b..bf2d5473 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py @@ -2,10 +2,17 @@ from typing import Any, Dict, List, 
Literal, TypedDict +from packaging import version from typing_extensions import NotRequired, Required +from snowflake.ml.model import model_signature + MODEL_MANIFEST_VERSION = "1.0" +MANIFEST_USER_DATA_ENABLE_VERSION = version.parse("8.2.0") +MANIFEST_CLIENT_DATA_KEY_NAME = "snowpark_ml_data" +MANIFEST_CLIENT_DATA_SCHEMA_VERSION = "2024-02-01" + class ModelRuntimeDependenciesDict(TypedDict): conda: Required[str] @@ -38,6 +45,31 @@ class ModelFunctionMethodDict(TypedDict): ModelMethodDict = ModelFunctionMethodDict +class ModelFunctionInfo(TypedDict): + """Function information. + + Attributes: + name: Name of the function to be called via SQL. + target_method: actual target method name to be called. + signature: The signature of the model method. + """ + + name: Required[str] + target_method: Required[str] + signature: Required[model_signature.ModelSignature] + + +class ModelFunctionInfoDict(TypedDict): + name: Required[str] + target_method: Required[str] + signature: Required[Dict[str, Any]] + + +class SnowparkMLDataDict(TypedDict): + schema_version: Required[str] + functions: Required[List[ModelFunctionInfoDict]] + + class ModelManifestDict(TypedDict): manifest_version: Required[str] runtimes: Required[Dict[str, ModelRuntimeDict]] diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py index 50bfacd1..acc159cb 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py @@ -1,23 +1,37 @@ import os import pathlib import tempfile +from typing import Any, Dict from unittest import mock import importlib_resources import yaml from absl.testing import absltest +from packaging import version from snowflake.ml._internal import env_utils +from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.model import model_signature, type_hints -from snowflake.ml.model._model_composer.model_manifest import model_manifest +from snowflake.ml.model._model_composer.model_manifest import ( + model_manifest, + model_manifest_schema, +) from snowflake.ml.model._packager.model_meta import model_blob_meta, model_meta _DUMMY_SIG = { "predict": model_signature.ModelSignature( inputs=[ - model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input"), + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input_1"), + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input_2", shape=(-1,)), + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input_3", shape=(-1,)), + model_signature.FeatureSpec(dtype=model_signature.DataType.FLOAT, name="input_4", shape=(-1,)), + ], + outputs=[ + model_signature.FeatureSpec(name="output_1", dtype=model_signature.DataType.FLOAT), + model_signature.FeatureSpec(name="output_2", dtype=model_signature.DataType.FLOAT, shape=(2, 2)), + model_signature.FeatureSpec(name="output_3", dtype=model_signature.DataType.FLOAT, shape=(2, 2)), + model_signature.FeatureSpec(name="output_4", dtype=model_signature.DataType.FLOAT, shape=(2, 2)), ], - outputs=[model_signature.FeatureSpec(name="output", dtype=model_signature.DataType.FLOAT)], ) } @@ -30,48 +44,102 @@ class ModelManifestTest(absltest.TestCase): def setUp(self) -> None: self.m_session = mock.MagicMock() + snowflake_env.get_current_snowflake_version = mock.MagicMock( + return_value=model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION + ) + + def 
test_model_manifest_old(self) -> None: + snowflake_env.get_current_snowflake_version = mock.MagicMock(return_value=version.parse("8.0.0")) + with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: + mm = model_manifest.ModelManifest(pathlib.Path(workspace)) + with model_meta.create_model_metadata( + model_dir_path=tmpdir, + name="model1", + model_type="custom", + signatures={"predict": _DUMMY_SIG["predict"], "__call__": _DUMMY_SIG["predict"]}, + python_version="3.8", + ) as meta: + meta.models["model1"] = _DUMMY_BLOB + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): + mm.save( + self.m_session, + meta, + pathlib.PurePosixPath("model.zip"), + options=type_hints.BaseModelSaveOption( + method_options={ + "predict": type_hints.ModelMethodSaveOptions(case_sensitive=True), + "__call__": type_hints.ModelMethodSaveOptions(max_batch_size=10), + } + ), + ) + with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_manifest") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("MANIFEST_0.yml") + .read_text() + ), + f.read(), + ) + with open(pathlib.Path(workspace, "functions", "predict.py"), encoding="utf-8") as f: + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_method") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("function_1.py") + .read_text() + ), + f.read(), + ) + with open(pathlib.Path(workspace, "functions", "__call__.py"), encoding="utf-8") as f: + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_method") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("function_2.py") + .read_text() + ), + f.read(), + ) def test_model_manifest_1(self) -> None: with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: mm = model_manifest.ModelManifest(pathlib.Path(workspace)) with model_meta.create_model_metadata( - model_dir_path=tmpdir, name="model1", model_type="custom", signatures=_DUMMY_SIG + model_dir_path=tmpdir, + name="model1", + model_type="custom", + signatures=_DUMMY_SIG, + python_version="3.8", ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mm.save(self.m_session, meta, pathlib.PurePosixPath("model.zip")) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: - loaded_manifest = yaml.safe_load(f) - self.assertDictEqual( - loaded_manifest, - { - "manifest_version": "1.0", - "runtimes": { - "python_runtime": { - "language": "PYTHON", - "version": meta.env.python_version, - "imports": ["model.zip"], - "dependencies": {"conda": "runtimes/python_runtime/env/conda.yml"}, - } - }, - "methods": [ - { - "name": "PREDICT", - "runtime": "python_runtime", - "type": "FUNCTION", - "handler": "functions.predict.infer", - "inputs": [{"name": "INPUT", "type": "FLOAT"}], - "outputs": [{"type": "OBJECT"}], - } - ], - }, - ) + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_manifest") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("MANIFEST_1.yml") + 
.read_text() + ), + f.read(), + ) with open(pathlib.Path(workspace, "functions", "predict.py"), encoding="utf-8") as f: self.assertEqual( ( importlib_resources.files("snowflake.ml.model._model_composer.model_method") .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_1.py_fixture") + .joinpath("function_1.py") .read_text() ), f.read(), @@ -85,9 +153,14 @@ def test_model_manifest_2(self) -> None: name="model1", model_type="custom", signatures={"__call__": _DUMMY_SIG["predict"]}, + python_version="3.8", ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): mm.save( self.m_session, meta, @@ -97,37 +170,21 @@ def test_model_manifest_2(self) -> None: ), ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: - loaded_manifest = yaml.safe_load(f) - self.assertDictEqual( - loaded_manifest, - { - "manifest_version": "1.0", - "runtimes": { - "python_runtime": { - "language": "PYTHON", - "version": meta.env.python_version, - "imports": ["model.zip"], - "dependencies": {"conda": "runtimes/python_runtime/env/conda.yml"}, - } - }, - "methods": [ - { - "name": "__CALL__", - "runtime": "python_runtime", - "type": "FUNCTION", - "handler": "functions.__call__.infer", - "inputs": [{"name": "INPUT", "type": "FLOAT"}], - "outputs": [{"type": "OBJECT"}], - } - ], - }, - ) + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_manifest") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("MANIFEST_2.yml") + .read_text() + ), + f.read(), + ) with open(pathlib.Path(workspace, "functions", "__call__.py"), encoding="utf-8") as f: self.assertEqual( ( importlib_resources.files("snowflake.ml.model._model_composer.model_method") .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_2.py_fixture") + .joinpath("function_2.py") .read_text() ), f.read(), @@ -141,9 +198,14 @@ def test_model_manifest_mix(self) -> None: name="model1", model_type="custom", signatures={"predict": _DUMMY_SIG["predict"], "__call__": _DUMMY_SIG["predict"]}, + python_version="3.8", ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=None): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): mm.save( self.m_session, meta, @@ -156,45 +218,21 @@ def test_model_manifest_mix(self) -> None: ), ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: - loaded_manifest = yaml.safe_load(f) - self.assertDictEqual( - loaded_manifest, - { - "manifest_version": "1.0", - "runtimes": { - "python_runtime": { - "language": "PYTHON", - "version": meta.env.python_version, - "imports": ["model.zip", "runtimes/python_runtime/snowflake-ml-python.zip"], - "dependencies": {"conda": "runtimes/python_runtime/env/conda.yml"}, - } - }, - "methods": [ - { - "name": "predict", - "runtime": "python_runtime", - "type": "FUNCTION", - "handler": "functions.predict.infer", - "inputs": [{"name": "input", "type": "FLOAT"}], - "outputs": [{"type": "OBJECT"}], - }, - { - "name": "__CALL__", - "runtime": "python_runtime", - "type": "FUNCTION", - "handler": "functions.__call__.infer", - "inputs": 
[{"name": "INPUT", "type": "FLOAT"}], - "outputs": [{"type": "OBJECT"}], - }, - ], - }, - ) + self.assertEqual( + ( + importlib_resources.files("snowflake.ml.model._model_composer.model_manifest") + .joinpath("fixtures") # type: ignore[no-untyped-call] + .joinpath("MANIFEST_3.yml") + .read_text() + ), + f.read(), + ) with open(pathlib.Path(workspace, "functions", "predict.py"), encoding="utf-8") as f: self.assertEqual( ( importlib_resources.files("snowflake.ml.model._model_composer.model_method") .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_1.py_fixture") + .joinpath("function_1.py") .read_text() ), f.read(), @@ -204,7 +242,7 @@ def test_model_manifest_mix(self) -> None: ( importlib_resources.files("snowflake.ml.model._model_composer.model_method") .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_2.py_fixture") + .joinpath("function_2.py") .read_text() ), f.read(), @@ -220,7 +258,11 @@ def test_model_manifest_bad(self) -> None: signatures={"predict": _DUMMY_SIG["predict"], "PREDICT": _DUMMY_SIG["predict"]}, ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): with self.assertRaisesRegex( ValueError, "Found duplicate method named resolved as PREDICT in the model." ): @@ -297,6 +339,59 @@ def test_load(self) -> None: self.assertDictEqual(raw_input, mm.load()) + def test_generate_user_data_with_client_data_1(self) -> None: + m_user_data: Dict[str, Any] = {"description": "a"} + with self.assertRaisesRegex(ValueError, "Ill-formatted client data .* in user data found."): + model_manifest.ModelManifest.parse_client_data_from_user_data(m_user_data) + + m_user_data = {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: "a"} + with self.assertRaisesRegex(ValueError, "Ill-formatted client data .* in user data found."): + model_manifest.ModelManifest.parse_client_data_from_user_data(m_user_data) + + m_user_data = {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: {"description": "a"}} + with self.assertRaisesRegex(ValueError, "Ill-formatted client data .* in user data found."): + model_manifest.ModelManifest.parse_client_data_from_user_data(m_user_data) + + m_user_data = {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: {"schema_version": 1}} + with self.assertRaisesRegex(ValueError, "Unsupported client data schema version .* confronted."): + model_manifest.ModelManifest.parse_client_data_from_user_data(m_user_data) + + m_user_data = {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: {"schema_version": "2023-12-01"}} + with self.assertRaisesRegex(ValueError, "Unsupported client data schema version .* confronted."): + model_manifest.ModelManifest.parse_client_data_from_user_data(m_user_data) + + m_user_data = { + model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: { + "schema_version": model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION + } + } + self.assertDictEqual( + model_manifest.ModelManifest.parse_client_data_from_user_data(m_user_data), + {"schema_version": model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION, "functions": []}, + ) + + def test_generate_user_data_with_client_data_2(self) -> None: + m_client_data = { + "schema_version": model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION, + "functions": [ + { + "name": '"predict"', + 
"target_method": "predict", + "signature": _DUMMY_SIG["predict"].to_dict(), + }, + { + "name": "__CALL__", + "target_method": "__call__", + "signature": _DUMMY_SIG["predict"].to_dict(), + }, + ], + } + m_user_data = {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: m_client_data} + self.assertDictEqual( + model_manifest.ModelManifest.parse_client_data_from_user_data(m_user_data), + m_client_data, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel index 27924817..e5bcc0b2 100644 --- a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel @@ -5,8 +5,8 @@ package(default_visibility = ["//visibility:public"]) filegroup( name = "function_fixtures", srcs = [ - "fixtures/function_fixture_1.py_fixture", - "fixtures/function_fixture_2.py_fixture", + "fixtures/function_1.py", + "fixtures/function_2.py", ], ) diff --git a/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_1.py_fixture b/snowflake/ml/model/_model_composer/model_method/fixtures/function_1.py similarity index 100% rename from snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_1.py_fixture rename to snowflake/ml/model/_model_composer/model_method/fixtures/function_1.py diff --git a/snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_2.py_fixture b/snowflake/ml/model/_model_composer/model_method/fixtures/function_2.py similarity index 100% rename from snowflake/ml/model/_model_composer/model_method/fixtures/function_fixture_2.py_fixture rename to snowflake/ml/model/_model_composer/model_method/fixtures/function_2.py diff --git a/snowflake/ml/model/_model_composer/model_method/function_generator.py b/snowflake/ml/model/_model_composer/model_method/function_generator.py index 192480fa..1cc23e22 100644 --- a/snowflake/ml/model/_model_composer/model_method/function_generator.py +++ b/snowflake/ml/model/_model_composer/model_method/function_generator.py @@ -1,7 +1,6 @@ import pathlib from typing import Optional, TypedDict -import importlib_resources from typing_extensions import NotRequired from snowflake.ml.model import type_hints @@ -33,6 +32,8 @@ def generate( target_method: str, options: Optional[FunctionGenerateOptions] = None, ) -> None: + import importlib_resources + if options is None: options = {} function_template = ( diff --git a/snowflake/ml/model/_model_composer/model_method/function_generator_test.py b/snowflake/ml/model/_model_composer/model_method/function_generator_test.py index c10b06d0..1776963b 100644 --- a/snowflake/ml/model/_model_composer/model_method/function_generator_test.py +++ b/snowflake/ml/model/_model_composer/model_method/function_generator_test.py @@ -20,7 +20,7 @@ def test_function_generator(self) -> None: ( importlib_resources.files("snowflake.ml.model._model_composer.model_method") .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_1.py_fixture") + .joinpath("function_1.py") .read_text() ), f.read(), @@ -35,7 +35,7 @@ def test_function_generator(self) -> None: ( importlib_resources.files("snowflake.ml.model._model_composer.model_method") .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_2.py_fixture") + .joinpath("function_2.py") .read_text() ), f.read(), diff --git a/snowflake/ml/model/_model_composer/model_method/model_method_test.py 
b/snowflake/ml/model/_model_composer/model_method/model_method_test.py index 0594641d..c6b6e45c 100644 --- a/snowflake/ml/model/_model_composer/model_method/model_method_test.py +++ b/snowflake/ml/model/_model_composer/model_method/model_method_test.py @@ -48,7 +48,7 @@ def test_model_method(self) -> None: ( importlib_resources.files(model_method_pkg) .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_1.py_fixture") + .joinpath("function_1.py") .read_text() ), f.read(), @@ -87,7 +87,7 @@ def test_model_method(self) -> None: ( importlib_resources.files(model_method_pkg) .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_2.py_fixture") + .joinpath("function_2.py") .read_text() ), f.read(), @@ -152,7 +152,7 @@ def test_model_method(self) -> None: ( importlib_resources.files(model_method_pkg) .joinpath("fixtures") # type: ignore[no-untyped-call] - .joinpath("function_fixture_1.py_fixture") + .joinpath("function_1.py") .read_text() ), f.read(), diff --git a/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py b/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py index aa90d3f3..a68a0f87 100644 --- a/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +++ b/snowflake/ml/model/_model_composer/model_runtime/model_runtime.py @@ -44,12 +44,17 @@ def __init__( if self.runtime_env._snowpark_ml_version.local: self.embed_local_ml_library = True else: - snowml_server_availability = env_utils.validate_requirements_in_information_schema( - session=session, - reqs=[requirements.Requirement(snowml_pkg_spec)], - python_version=snowml_env.PYTHON_VERSION, + snowml_server_availability = ( + len( + env_utils.get_matched_package_versions_in_information_schema( + session=session, + reqs=[requirements.Requirement(snowml_pkg_spec)], + python_version=snowml_env.PYTHON_VERSION, + ).get(env_utils.SNOWPARK_ML_PKG_NAME, []) + ) + >= 1 ) - self.embed_local_ml_library = snowml_server_availability is None + self.embed_local_ml_library = not snowml_server_availability if self.embed_local_ml_library: self.runtime_env.include_if_absent( diff --git a/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py b/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py index 1a10e220..9ad8b4c3 100644 --- a/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py +++ b/snowflake/ml/model/_model_composer/model_runtime/model_runtime_test.py @@ -56,7 +56,11 @@ def test_model_runtime(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -83,7 +87,11 @@ def test_model_runtime_local_snowml(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=None): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -118,7 +126,11 @@ def test_model_runtime_dup_basic_dep(self) -> None: dep_target.append("pandas") 
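# Illustrative sketch (not from this patch): the model_runtime.py hunk above replaces the old
# boolean-style availability check with get_matched_package_versions_in_information_schema,
# which returns {package_name: [matching versions]}. The decision it encodes, shown here with
# made-up data, is: embed the local snowflake-ml-python library only when the server-side
# channel offers no matching version.
matched = {"snowflake-ml-python": []}                   # assumed example of the helper's output
embed_local_ml_library = len(matched.get("snowflake-ml-python", [])) == 0
assert embed_local_ml_library is True                   # nothing matched -> ship the local copy

matched = {"snowflake-ml-python": ["1.2.0"]}
embed_local_ml_library = len(matched.get("snowflake-ml-python", [])) == 0
assert embed_local_ml_library is False                  # server has a match -> no need to embed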
dep_target.sort() - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -144,7 +156,11 @@ def test_model_runtime_dup_basic_dep_other_channel(self) -> None: dep_target.append("conda-forge::pandas") dep_target.sort() - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -169,7 +185,11 @@ def test_model_runtime_dup_basic_dep_pip(self) -> None: dep_target.remove(f"pandas=={importlib_metadata.version('pandas')}") dep_target.sort() - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -194,7 +214,11 @@ def test_model_runtime_additional_conda_dep(self) -> None: dep_target.append("pytorch") dep_target.sort() - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -218,7 +242,11 @@ def test_model_runtime_additional_pip_dep(self) -> None: dep_target = _BASIC_DEPENDENCIES_TARGET_WITH_SNOWML[:] dep_target.sort() - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) @@ -244,7 +272,11 @@ def test_model_runtime_additional_dep_both(self) -> None: dep_target.append("pytorch") dep_target.sort() - with mock.patch.object(env_utils, "validate_requirements_in_information_schema", return_value=[""]): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: [""]}, + ): mr = model_runtime.ModelRuntime( self.m_session, "python_runtime", meta, [pathlib.PurePosixPath("model.zip")] ) diff --git a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py index 47459cfc..be5ca718 100644 --- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py @@ -59,7 +59,7 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model return ( [model_env.ModelDependency(requirement="tokenizers>=0.13.3", pip_name="tokenizers")] if spcs_only - else 
[model_env.ModelDependency(requirement="tokenizers<=0.13.2", pip_name="tokenizers")] + else [model_env.ModelDependency(requirement="tokenizers", pip_name="tokenizers")] ) return [] diff --git a/snowflake/ml/model/_packager/model_handlers/xgboost.py b/snowflake/ml/model/_packager/model_handlers/xgboost.py index 478cdac8..d92d3729 100644 --- a/snowflake/ml/model/_packager/model_handlers/xgboost.py +++ b/snowflake/ml/model/_packager/model_handlers/xgboost.py @@ -1,6 +1,16 @@ # mypy: disable-error-code="import" import os -from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + Type, + Union, + cast, + final, +) import numpy as np import pandas as pd @@ -150,6 +160,7 @@ def load_model( m.load_model(os.path.join(model_blob_path, model_blob_filename)) if kwargs.get("use_gpu", False): + assert type(kwargs.get("use_gpu", False)) == bool gpu_params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"} if isinstance(m, xgboost.Booster): m.set_param(gpu_params) @@ -197,7 +208,7 @@ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: return fn - type_method_dict = {} + type_method_dict: Dict[str, Any] = {"_raw_model": raw_model} for target_method_name, sig in model_meta.signatures.items(): type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name) diff --git a/snowflake/ml/model/_signatures/core.py b/snowflake/ml/model/_signatures/core.py index 1df982fc..a245b181 100644 --- a/snowflake/ml/model/_signatures/core.py +++ b/snowflake/ml/model/_signatures/core.py @@ -146,7 +146,8 @@ def from_snowpark_type(cls, snowpark_type: spt.DataType) -> "DataType": " is being automatically converted to INT64 in the Snowpark DataFrame. " "This automatic conversion may lead to potential precision loss and rounding errors. " "If you wish to prevent this conversion, you should manually perform " - "the necessary data type conversion." + "the necessary data type conversion.", + stacklevel=2, ) return DataType.INT64 else: @@ -155,7 +156,8 @@ def from_snowpark_type(cls, snowpark_type: spt.DataType) -> "DataType": " is being automatically converted to DOUBLE in the Snowpark DataFrame. " "This automatic conversion may lead to potential precision loss and rounding errors. " "If you wish to prevent this conversion, you should manually perform " - "the necessary data type conversion." + "the necessary data type conversion.", + stacklevel=2, ) return DataType.DOUBLE raise snowml_exceptions.SnowflakeMLException( @@ -202,23 +204,24 @@ def __init__( dtype: DataType, shape: Optional[Tuple[int, ...]] = None, ) -> None: - """Initialize a feature. + """ + Initialize a feature. Args: name: Name of the feature. dtype: Type of the elements in the feature. - shape: Used to represent scalar feature, 1-d feature list or n-d tensor. - -1 is used to represent variable length.Defaults to None. + shape: Used to represent scalar feature, 1-d feature list, + or n-d tensor. Use -1 to represent variable length. Defaults to None. - E.g. - None: scalar - (2,): 1d list with fixed len of 2. - (-1,): 1d list with variable length. Used for ragged tensor representation. - (d1, d2, d3): 3d tensor. + Examples: + - None: scalar + - (2,): 1d list with a fixed length of 2. + - (-1,): 1d list with variable length, used for ragged tensor representation. + - (d1, d2, d3): 3d tensor. Raises: - SnowflakeMLException: TypeError: Raised when the dtype input type is incorrect. 
- SnowflakeMLException: TypeError: Raised when the shape input type is incorrect. + SnowflakeMLException: TypeError: When the dtype input type is incorrect. + SnowflakeMLException: TypeError: When the shape input type is incorrect. """ super().__init__(name=name) @@ -408,13 +411,13 @@ class ModelSignature: """Signature of a model that specifies the input and output of a model.""" def __init__(self, inputs: Sequence[BaseFeatureSpec], outputs: Sequence[BaseFeatureSpec]) -> None: - """Initialize a model signature + """Initialize a model signature. Args: - inputs: A sequence of feature specifications and feature group specifications that will compose the - input of the model. - outputs: A sequence of feature specifications and feature group specifications that will compose the - output of the model. + inputs: A sequence of feature specifications and feature group specifications that will compose + the input of the model. + outputs: A sequence of feature specifications and feature group specifications that will compose + the output of the model. """ self._inputs = inputs self._outputs = outputs diff --git a/snowflake/ml/model/custom_model.py b/snowflake/ml/model/custom_model.py index d1f576ce..f6d37e74 100644 --- a/snowflake/ml/model/custom_model.py +++ b/snowflake/ml/model/custom_model.py @@ -9,15 +9,16 @@ class MethodRef: - """Represents an method invocation of an instance of `ModelRef`. + """Represents a method invocation of an instance of `ModelRef`. + + This allows us to: + 1) Customize the place of actual execution of the method (inline, thread/process pool, or remote). + 2) Enrich the way of execution (sync versus async). - This allows us to - 1) Customize the place of actual execution of the method(inline, thread/process pool or remote). - 2) Enrich the way of execution(sync versus async). Example: - If you have a SKL model, you would normally invoke by `skl_ref.predict(df)` which has sync API. - Within inference graph, you could invoke `await skl_ref.predict.async_run(df)` which automatically - will be run on thread with async interface. + If you have an SKL model, you would normally invoke it by `skl_ref.predict(df)`, which has a synchronous API. + Within the inference graph, you could invoke `await skl_ref.predict.async_run(df)`, which will automatically + run on a thread with an asynchronous interface. """ def __init__(self, model_ref: "ModelRef", method_name: str) -> None: @@ -27,11 +28,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self._func(*args, **kwargs) async def async_run(self, *args: Any, **kwargs: Any) -> Any: - """Run the method in a async way. If the method is defined as async, this will simply run it. If not, this will - be run in a separate thread. + """Run the method in an asynchronous way. If the method is defined as async, this will simply run it. + If not, this will be run in a separate thread. Args: - *args: Arguments of the original method, + *args: Arguments of the original method. **kwargs: Keyword arguments of the original method. Returns: @@ -43,19 +44,20 @@ async def async_run(self, *args: Any, **kwargs: Any) -> Any: class ModelRef: - """Represents an model in the inference graph. Method could be directly called using this reference object as if - with the original model object. + """ + Represents a model in the inference graph. Methods can be directly called using this reference object + as if with the original model object. 
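# Illustrative sketch (not from this patch): a minimal custom model wiring together the pieces
# documented in this file -- ModelContext for sub-models, ModelRef for calling them, and
# MethodRef.async_run for the asynchronous path. The sklearn estimator and column names are
# assumptions made for this sketch only.
import asyncio

import pandas as pd
from sklearn.linear_model import LinearRegression

from snowflake.ml.model import custom_model


class WrapperModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        # Synchronous call through the ModelRef, as if calling the estimator directly.
        preds = self.context.model_ref("base").predict(X)
        return pd.DataFrame({"output": preds})


base = LinearRegression().fit(pd.DataFrame({"x": [0.0, 1.0]}), [0.0, 1.0])
ctx = custom_model.ModelContext(models={"base": base})
print(WrapperModel(ctx).predict(pd.DataFrame({"x": [0.5]})))


async def run_async() -> None:
    # async_run executes a synchronous predict on a worker thread.
    preds = await ctx.model_ref("base").predict.async_run(pd.DataFrame({"x": [0.5]}))
    print(preds)


asyncio.run(run_async())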
- This enables us to separate physical and logical representation of a model which - will allows us to deeply understand the graph and perform optimization at entire - graph level. + This enables us to separate the physical and logical representation of a model, allowing for a deep understanding + of the graph and enabling optimization at the entire graph level. """ def __init__(self, name: str, model: model_types.SupportedModelType) -> None: - """Initialize the ModelRef. + """ + Initialize the ModelRef. Args: - name: The name of a model to refer it. + name: The name of the model to refer to. model: The model object. """ self._model = model @@ -91,11 +93,12 @@ def __setstate__(self, state: Any) -> None: class ModelContext: - """Context for a custom model showing path to artifacts and mapping between model name and object reference. + """ + Context for a custom model showing paths to artifacts and mapping between model name and object reference. Attributes: - artifacts: A dict mapping name of the artifact to its path. - model_refs: A dict mapping name of the sub-model to its ModelRef object. + artifacts: A dictionary mapping the name of the artifact to its path. + model_refs: A dictionary mapping the name of the sub-model to its ModelRef object. """ def __init__( @@ -104,11 +107,11 @@ def __init__( artifacts: Optional[Dict[str, str]] = None, models: Optional[Dict[str, model_types.SupportedModelType]] = None, ) -> None: - """Initialize the model context + """Initialize the model context. Args: - artifacts: A dict mapping name of the artifact to its currently available path. Defaults to None. - models: A dict mapping name of the sub-model to the corresponding model object. Defaults to None. + artifacts: A dictionary mapping the name of the artifact to its currently available path. Defaults to None. + models: A dictionary mapping the name of the sub-model to the corresponding model object. Defaults to None. """ self.artifacts: Dict[str, str] = artifacts if artifacts else dict() self.model_refs: Dict[str, ModelRef] = ( @@ -116,7 +119,8 @@ def __init__( ) def path(self, key: str) -> str: - """Get the actual path to a specific artifact. + """Get the actual path to a specific artifact. This could be used when defining a Custom Model to retrieve + artifacts. Args: key: The name of the artifact. @@ -127,14 +131,13 @@ def path(self, key: str) -> str: return self.artifacts[key] def model_ref(self, name: str) -> ModelRef: - """Get a ModelRef object of a sub-model containing the name and model object, while able to call its method - directly as well. + """Get a ModelRef object of a sub-model containing the name and model object, allowing direct method calls. Args: name: The name of the sub-model. Returns: - The ModelRef object to the sub-model. + The ModelRef object representing the sub-model. """ return self.model_refs[name] diff --git a/snowflake/ml/model/model_signature.py b/snowflake/ml/model/model_signature.py index 7abfa9f3..4309ecce 100644 --- a/snowflake/ml/model/model_signature.py +++ b/snowflake/ml/model/model_signature.py @@ -570,32 +570,31 @@ def infer_signature( input_feature_names: Optional[List[str]] = None, output_feature_names: Optional[List[str]] = None, ) -> core.ModelSignature: - """Infer model signature from given input and output sample data. + """ + Infer model signature from given input and output sample data. 
+ + Currently supports inferring model signatures from the following data types: - Currently, we support infer the model signature from example input/output data in the following cases: - - Pandas data frame whose column could have types of supported data types, - list (including list of supported data types, list of numpy array of supported data types, and nested list), - and numpy array of supported data types. + - Pandas DataFrame with columns of supported data types, lists (including nested lists) of supported data types, + and NumPy arrays of supported data types. - Does not support DataFrame with CategoricalIndex column index. - - Does not support DataFrame with column of variant length list or numpy array. - - Numpy array of supported data types. - - List of Numpy array of supported data types. - - List of supported data types, or nested list of supported data types. - - Does not support list of list of variant length list. + - NumPy arrays of supported data types. + - Lists of NumPy arrays of supported data types. + - Lists of supported data types or nested lists of supported data types. + + When inferring the signature, a ValueError indicates that the data is insufficient or invalid. - When a ValueError is raised when inferring the signature, it indicates that the data is ill and it is impossible to - create a signature reflecting that. - When a NotImplementedError is raised, it indicates that it might be possible to create a signature reflecting the - provided data, however, we could not infer it. + When it might be possible to create a signature reflecting the provided data, but it could not be inferred, + a NotImplementedError is raised Args: input_data: Sample input data for the model. output_data: Sample output data for the model. - input_feature_names: Name for input features. Defaults to None. - output_feature_names: Name for output features. Defaults to None. + input_feature_names: Names for input features. Defaults to None. + output_feature_names: Names for output features. Defaults to None. Returns: - A model signature. + A model signature inferred from the given input and output sample data. 
""" inputs = _infer_signature(input_data, role="input") inputs = utils.rename_features(inputs, input_feature_names) diff --git a/snowflake/ml/model/package_visibility_test.py b/snowflake/ml/model/package_visibility_test.py new file mode 100644 index 00000000..fe74e208 --- /dev/null +++ b/snowflake/ml/model/package_visibility_test.py @@ -0,0 +1,34 @@ +from types import ModuleType + +from absl.testing import absltest + +from snowflake.ml import model +from snowflake.ml.model import ( + _api, + custom_model, + deploy_platforms, + model_signature, + type_hints, +) + + +class PackageVisibilityTest(absltest.TestCase): + """Ensure that the functions in this package are visible externally.""" + + def test_class_visible(self) -> None: + self.assertIsInstance(model.Model, type) + self.assertIsInstance(model.ModelVersion, type) + self.assertIsInstance(model.HuggingFacePipelineModel, type) + self.assertIsInstance(model.LLM, type) + self.assertIsInstance(model.LLMOptions, type) + + def test_module_visible(self) -> None: + self.assertIsInstance(_api, ModuleType) + self.assertIsInstance(custom_model, ModuleType) + self.assertIsInstance(model_signature, ModuleType) + self.assertIsInstance(deploy_platforms, ModuleType) + self.assertIsInstance(type_hints, ModuleType) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/_internal/BUILD.bazel b/snowflake/ml/modeling/_internal/BUILD.bazel index c4d34c58..1ab6e5c2 100644 --- a/snowflake/ml/modeling/_internal/BUILD.bazel +++ b/snowflake/ml/modeling/_internal/BUILD.bazel @@ -51,6 +51,7 @@ py_library( name = "model_specifications", srcs = ["model_specifications.py"], deps = [ + ":estimator_utils", "//snowflake/ml/_internal/exceptions", ], ) @@ -59,7 +60,9 @@ py_test( name = "model_specifications_test", srcs = ["model_specifications_test.py"], deps = [ + ":distributed_hpo_trainer", ":model_specifications", + "//snowflake/ml/utils:connection_params", ], ) @@ -88,6 +91,7 @@ py_library( "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/exceptions:modeling_error_messages", "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:pkg_version_utils", "//snowflake/ml/_internal/utils:query_result_checker", "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", "//snowflake/ml/_internal/utils:temp_file_utils", @@ -105,6 +109,23 @@ py_library( "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/exceptions:modeling_error_messages", "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:pkg_version_utils", + "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", + "//snowflake/ml/_internal/utils:temp_file_utils", + ], +) + +py_library( + name = "xgboost_external_memory_trainer", + srcs = ["xgboost_external_memory_trainer.py"], + deps = [ + ":model_specifications", + ":snowpark_trainer", + "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/exceptions", + "//snowflake/ml/_internal/exceptions:modeling_error_messages", + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:pkg_version_utils", "//snowflake/ml/_internal/utils:snowpark_dataframe_utils", "//snowflake/ml/_internal/utils:temp_file_utils", ], @@ -115,8 +136,33 @@ py_library( srcs = ["model_trainer_builder.py"], deps = [ ":distributed_hpo_trainer", + ":estimator_utils", ":model_trainer", ":pandas_trainer", ":snowpark_trainer", + ":xgboost_external_memory_trainer", + ], +) + +py_test( + name = "model_trainer_builder_test", + srcs = ["model_trainer_builder_test.py"], + 
deps = [ + ":distributed_hpo_trainer", + ":model_trainer", + ":model_trainer_builder", + ":pandas_trainer", + ":snowpark_trainer", + ":xgboost_external_memory_trainer", + "//snowflake/ml/utils:connection_params", + ], +) + +py_test( + name = "xgboost_external_memory_trainer_test", + srcs = ["xgboost_external_memory_trainer_test.py"], + deps = [ + ":xgboost_external_memory_trainer", + "//snowflake/ml/_internal/utils:temp_file_utils", ], ) diff --git a/snowflake/ml/modeling/_internal/distributed_hpo_trainer.py b/snowflake/ml/modeling/_internal/distributed_hpo_trainer.py index 9dc57ee2..8928f61b 100644 --- a/snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +++ b/snowflake/ml/modeling/_internal/distributed_hpo_trainer.py @@ -4,15 +4,18 @@ import os import posixpath import sys -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import cloudpickle as cp import numpy as np -from scipy.stats import rankdata from sklearn import model_selection from snowflake.ml._internal import telemetry -from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils +from snowflake.ml._internal.utils import ( + identifier, + pkg_version_utils, + snowpark_dataframe_utils, +) from snowflake.ml._internal.utils.temp_file_utils import ( cleanup_temp_files, get_temp_file_path, @@ -26,7 +29,8 @@ TempObjectType, random_name_for_temp_object, ) -from snowflake.snowpark.functions import col, sproc, udtf +from snowflake.snowpark.functions import sproc, udtf +from snowflake.snowpark.row import Row from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path)) @@ -36,6 +40,153 @@ DEFAULT_UDTF_NJOBS = 3 +def construct_cv_results( + cv_results_raw_hex: List[Row], + cross_validator_indices_length: int, + parameter_grid_length: int, + search_cv_kwargs: Dict[str, Any], +) -> Tuple[bool, Dict[str, Any], int, Set[str]]: + """Construct the cross validation result from the UDF. Because we accelerate the process + by the number of cross validation number, and the combination of parameter grids. + Therefore, we need to stick them back together instead of returning the raw result + to align with original sklearn result. + + Args: + cv_results_raw_hex (List[Row]): the list of cv_results from each cv and parameter grid combination. + Because UDxF can only return string, and numpy array/masked arrays cannot be encoded in a + json format. Each cv_result is encoded into hex string. + cross_validator_indices_length (int): the length of cross validator indices + parameter_grid_length (int): the length of parameter grid combination + search_cv_kwargs (Dict[str, Any]): the kwargs for GridSearchCV/RandomSearchCV. + + Raises: + ValueError: Retrieved empty cross validation results + ValueError: Cross validator index length is 0 + ValueError: Parameter index length is 0 + ValueError: Retrieved incorrect dataframe dimension from Snowpark's UDTF. + RuntimeError: Cross validation results are unexpectedly empty for one fold. + + Returns: + Tuple[bool, Dict[str, Any], int, Set[str]]: returns multimetric, cv_results_, best_param_index, scorers + """ + # Filter corner cases: either the snowpark dataframe result is empty; or index length is empty + if len(cv_results_raw_hex) == 0: + raise ValueError( + "Retrieved empty cross validation results from snowpark. Please retry or contact snowflake support." 
+ ) + if cross_validator_indices_length == 0: + raise ValueError("Cross validator index length is 0. Was the CV iterator empty? ") + if parameter_grid_length == 0: + raise ValueError("Parameter index length is 0. Were there no candidates?") + + from scipy.stats import rankdata + + # cv_result maintains the original order + multimetric = False + cv_results_ = dict() + scorers = set() + # retrieve the cv_results from udtf table; results are encoded by hex and cloudpickle; + # We are constructing the raw information back to original form + if len(cv_results_raw_hex) != cross_validator_indices_length * parameter_grid_length: + raise ValueError( + "Retrieved incorrect dataframe dimension from Snowpark's UDTF." + f"Expected {cross_validator_indices_length * parameter_grid_length}, got {len(cv_results_raw_hex)}. " + "Please retry or contact snowflake support." + ) + + for param_cv_indices, each_cv_result_hex in enumerate(cv_results_raw_hex): + # convert the hex string back to cv_results_ + hex_str = bytes.fromhex(each_cv_result_hex[0]) + with io.BytesIO(hex_str) as f_reload: + each_cv_result = cp.load(f_reload) + if not each_cv_result: + raise RuntimeError( + "Cross validation response is empty. This issue may be temporary - please try again." + ) + for k, v in each_cv_result.items(): + cur_cv_idx = param_cv_indices % cross_validator_indices_length + key = k + if "split0_test_" in k: + # For multi-metric evaluation, the scores for all the scorers are available in the + # cv_results_ dict at the keys ending with that scorer’s name ('_') + # instead of '_score'. + scorers.add(k[len("split0_test_") :]) + key = k.replace("split0_test", f"split{cur_cv_idx}_test") + if search_cv_kwargs.get("return_train_score", None) and "split0_train_" in k: + key = k.replace("split0_train", f"split{cur_cv_idx}_train") + elif k.startswith("param"): + if cur_cv_idx != 0: + continue + if key: + if key not in cv_results_: + cv_results_[key] = v + else: + cv_results_[key] = np.concatenate([cv_results_[key], v]) + + multimetric = len(scorers) > 1 + # Use numpy to re-calculate all the information in cv_results_ again + # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape, + # and average them by the idx_length; + # idx_length is the number of cv folds; params_length is the number of parameter combinations + scores_test = [ + np.reshape( + np.concatenate( + [cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(cross_validator_indices_length)] + ), + (cross_validator_indices_length, -1), + ) + for score in scorers + ] + + fit_score_test_matrix = np.stack( + [ + np.reshape(cv_results_["mean_fit_time"], (cross_validator_indices_length, -1)), + np.reshape(cv_results_["mean_score_time"], (cross_validator_indices_length, -1)), + ] + + scores_test + ) + mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1) + std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1) + + if search_cv_kwargs.get("return_train_score", None): + scores_train = [ + np.reshape( + np.concatenate( + [cv_results_[f"split{cur_cv}_train_{score}"] for cur_cv in range(cross_validator_indices_length)] + ), + (cross_validator_indices_length, -1), + ) + for score in scorers + ] + mean_fit_score_train_matrix = np.mean(scores_train, axis=1) + std_fit_score_train_matrix = np.std(scores_train, axis=1) + + cv_results_["std_fit_time"] = std_fit_score_test_matrix[0] + cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0] + cv_results_["std_score_time"] = std_fit_score_test_matrix[1] + 
cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1] + for idx, score in enumerate(scorers): + cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2] + cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2] + if search_cv_kwargs.get("return_train_score", None): + cv_results_[f"std_train_{score}"] = std_fit_score_train_matrix[idx] + cv_results_[f"mean_train_{score}"] = mean_fit_score_train_matrix[idx] + # re-compute the ranking again with mean_test_. + cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min") + # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared. + # If all scores are `nan`, `rankdata` will also produce an array of `nan` values. + # In that case, default to first index. + best_param_index = ( + np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0] + if not np.isnan(cv_results_[f"rank_test_{score}"]).all() + else 0 + ) + return multimetric, cv_results_, best_param_index, scorers + + +cp.register_pickle_by_value(inspect.getmodule(construct_cv_results)) + + class DistributedHPOTrainer(SnowparkModelTrainer): """ A class for performing distributed hyperparameter optimization (HPO) using Snowpark. @@ -105,7 +256,7 @@ def fit_search_snowpark( temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};" session.sql(temp_stage_creation_query).collect() - # Stage data. + # Stage data as parquet file dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset) remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet" dataset.write.copy_into_location( # type:ignore[call-overload] @@ -114,6 +265,7 @@ def fit_search_snowpark( imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()] # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again + # refit variable can be boolean, string or callable original_refit = estimator.refit # Create a temp file and dump the estimator to that file. @@ -208,7 +360,7 @@ def _distributed_search( for file_name in data_files ] df = pd.concat(partial_df, ignore_index=True) - df.columns = [identifier.get_inferred_name(col) for col in df.columns] + df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns] X = df[input_cols] y = df[label_cols].squeeze() if label_cols else None @@ -222,11 +374,12 @@ def _distributed_search( with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj: estimator = cp.load(local_estimator_file_obj)["estimator"] - cv_orig = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator)) - indices = [test for _, test in cv_orig.split(X, y)] + build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator)) + # store the cross_validator's test indices only to save space + cross_validator_indices = [test for _, test in build_cross_validator.split(X, y)] local_indices_file_name = get_temp_file_path() with open(local_indices_file_name, mode="w+b") as local_indices_file_obj: - cp.dump(indices, local_indices_file_obj) + cp.dump(cross_validator_indices, local_indices_file_obj) # Put locally serialized indices on stage. 
put_result = session.file.put( @@ -237,7 +390,8 @@ def _distributed_search( ) indices_location = put_result[0].target imports.append(f"@{temp_stage_name}/{indices_location}") - indices_len = len(indices) + cross_validator_indices_length = int(len(cross_validator_indices)) + parameter_grid_length = len(param_grid) assert estimator is not None @@ -261,7 +415,7 @@ def _load_data_into_udf() -> Tuple[ for file_name in data_files ] df = pd.concat(partial_df, ignore_index=True) - df.columns = [identifier.get_inferred_name(col) for col in df.columns] + df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns] # load estimator local_estimator_file_path = os.path.join( @@ -299,16 +453,30 @@ def __init__(self) -> None: self.data_length = data_length self.params_to_evaluate = params_to_evaluate - def process(self, params_idx: int, idx: int) -> Iterator[Tuple[str]]: + def process(self, params_idx: int, cv_idx: int) -> Iterator[Tuple[str]]: + # Assign parameter to GridSearchCV if hasattr(estimator, "param_grid"): self.estimator.param_grid = self.params_to_evaluate[params_idx] + # Assign parameter to RandomizedSearchCV else: self.estimator.param_distributions = self.params_to_evaluate[params_idx] + # cross validator's indices: we stored test indices only (to save space); + # use the full indices to re-construct the train indices back. full_indices = np.array([i for i in range(self.data_length)]) - test_indice = self.indices[idx] + test_indice = self.indices[cv_idx] train_indice = np.setdiff1d(full_indices, test_indice) + # assign the tuple of train and test indices to estimator's original cross validator self.estimator.cv = [(train_indice, test_indice)] self.estimator.fit(**self.args) + # If the cv_results_ is empty, then the udtf table will have different number of output rows + # from the input rows. Raise ValueError. + if not self.estimator.cv_results_: + raise RuntimeError( + """Cross validation results are unexpectedly empty for one fold. 
+ This issue may be temporary - please try again.""" + ) + # Encode the dictionary of cv_results_ as binary (in hex format) to send it back + # because udtf doesn't allow numpy within json file binary_cv_results = None with io.BytesIO() as f: cp.dump(self.estimator.cv_results_, f) @@ -333,96 +501,44 @@ def end_partition(self) -> None: HP_TUNING = F.table_function(random_udtf_name) - idx_length = int(indices_len) - params_length = len(param_grid) - idxs = [i for i in range(idx_length)] - param_indices, training_indices = [], [] - for param_idx, cv_idx in product([param_index for param_index in range(params_length)], idxs): + # param_indices is for the index for each parameter grid; + # cv_indices is for the index for each cross_validator's fold; + # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices)) + param_indices, cv_indices = [], [] + for param_idx, cv_idx in product( + [param_index for param_index in range(parameter_grid_length)], + [cv_index for cv_index in range(cross_validator_indices_length)], + ): param_indices.append(param_idx) - training_indices.append(cv_idx) + cv_indices.append(cv_idx) - pd_df = pd.DataFrame( + indices_info_pandas = pd.DataFrame( { - "PARAMS": param_indices, - "TRAIN_IND": training_indices, - "PARAM_INDEX": [i for i in range(idx_length * params_length)], + "PARAM_IND": param_indices, + "CV_IND": cv_indices, + "PARAM_CV_IND": [i for i in range(cross_validator_indices_length * parameter_grid_length)], } ) - df = session.create_dataframe(pd_df) - results = df.select( - F.cast(df["PARAM_INDEX"], IntegerType()).as_("PARAM_INDEX"), - (HP_TUNING(df["PARAMS"], df["TRAIN_IND"]).over(partition_by=df["PARAM_INDEX"])), + indices_info_sp = session.create_dataframe(indices_info_pandas) + # execute udtf by querying HP_TUNING table + HP_raw_results = indices_info_sp.select( + F.cast(indices_info_sp["PARAM_CV_IND"], IntegerType()).as_("PARAM_CV_IND"), + ( + HP_TUNING(indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over( + partition_by=indices_info_sp["PARAM_CV_IND"] + ) + ), ) - # cv_result maintains the original order - multimetric = False - cv_results_ = dict() - scorers = set() - for i, val in enumerate(results.select("CV_RESULTS").sort(col("PARAM_INDEX")).collect()): - # retrieved string had one more double quote in the front and end of the string. - # use [1:-1] to remove the extra double quotes - hex_str = bytes.fromhex(val[0]) - with io.BytesIO(hex_str) as f_reload: - each_cv_result = cp.load(f_reload) - for k, v in each_cv_result.items(): - cur_cv = i % idx_length - key = k - if "split0_test_" in k: - # For multi-metric evaluation, the scores for all the scorers are available in the - # cv_results_ dict at the keys ending with that scorer’s name ('_') - # instead of '_score'. 
- scorers.add(k[len("split0_test_") :]) - key = k.replace("split0_test", f"split{cur_cv}_test") - elif k.startswith("param"): - if cur_cv != 0: - key = False - if key: - if key not in cv_results_: - cv_results_[key] = v - else: - cv_results_[key] = np.concatenate([cv_results_[key], v]) - - multimetric = len(scorers) > 1 - # Use numpy to re-calculate all the information in cv_results_ again - # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape, - # and average them by the idx_length; - # idx_length is the number of cv folds; params_length is the number of parameter combinations - scores = [ - np.reshape( - np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]), - (idx_length, -1), - ) - for score in scorers - ] - - fit_score_test_matrix = np.stack( - [ - np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)), - np.reshape(cv_results_["mean_score_time"], (idx_length, -1)), - ] - + scores + multimetric, cv_results_, best_param_index, scorers = construct_cv_results( + HP_raw_results.select("CV_RESULTS").sort(F.col("PARAM_CV_IND")).collect(), + cross_validator_indices_length, + parameter_grid_length, + { + "return_train_score": estimator.return_train_score, + }, # TODO(xjiang): support more kwargs in here ) - mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1) - std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1) - cv_results_["std_fit_time"] = std_fit_score_test_matrix[0] - cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0] - cv_results_["std_score_time"] = std_fit_score_test_matrix[1] - cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1] - for idx, score in enumerate(scorers): - cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2] - cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2] - # re-compute the ranking again with mean_test_. - cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min") - # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared. - # If all scores are `nan`, `rankdata` will also produce an array of `nan` values. - # In that case, default to first index. 
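Stripped of the Snowpark plumbing, the index bookkeeping used above (one UDTF partition per parameter-candidate and CV-fold pair, with train indices recovered from the stored test indices) reduces to roughly the following sketch; the sizes are made up for illustration.

from itertools import product

import numpy as np
import pandas as pd

parameter_grid_length = 2           # e.g. two parameter candidates
cross_validator_indices_length = 3  # e.g. three CV folds
data_length = 10
test_indices = [np.array([0, 1, 2]), np.array([3, 4, 5]), np.array([6, 7, 8, 9])]

# One UDTF partition per (param, fold) combination, keyed by PARAM_CV_IND.
pairs = list(product(range(parameter_grid_length), range(cross_validator_indices_length)))
indices_info = pd.DataFrame(
    {
        "PARAM_IND": [p for p, _ in pairs],
        "CV_IND": [c for _, c in pairs],
        "PARAM_CV_IND": range(len(pairs)),
    }
)

# Only test indices are shipped to the UDF; train indices are reconstructed there.
cv_idx = 1
test_indice = test_indices[cv_idx]
train_indice = np.setdiff1d(np.arange(data_length), test_indice)
single_fold_cv = [(train_indice, test_indice)]  # what gets assigned to estimator.cv before fit()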
- best_param_index = ( - np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0] - if not np.isnan(cv_results_[f"rank_test_{score}"]).all() - else 0 - ) - estimator.cv_results_ = cv_results_ estimator.multimetric_ = multimetric @@ -541,12 +657,15 @@ def train(self) -> object: n_iter=self.estimator.n_iter, random_state=self.estimator.random_state, ) + relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=model_spec.pkgDependencies, session=self.session + ) return self.fit_search_snowpark( param_grid=param_grid, dataset=self.dataset, session=self.session, estimator=self.estimator, - dependencies=model_spec.pkgDependencies, + dependencies=relaxed_dependencies, udf_imports=["sklearn"], input_cols=self.input_cols, label_cols=self.label_cols, diff --git a/snowflake/ml/modeling/_internal/estimator_utils.py b/snowflake/ml/modeling/_internal/estimator_utils.py index e08306d5..74edf8a2 100644 --- a/snowflake/ml/modeling/_internal/estimator_utils.py +++ b/snowflake/ml/modeling/_internal/estimator_utils.py @@ -132,3 +132,24 @@ def is_single_node(session: Session) -> bool: # If current session cannot retrieve the warehouse name back, # Default as True; Let HPO fall back to stored procedure implementation return True + + +def get_module_name(model: object) -> str: + """Returns the source module of the given object. + + Args: + model: Object to inspect. + + Returns: + Source module of the given object. + + Raises: + SnowflakeMLException: If the source module of the given object is not found. + """ + module = inspect.getmodule(model) + if module is None: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_TYPE, + original_exception=ValueError(f"Unable to infer the source module of the given object {model}."), + ) + return module.__name__ diff --git a/snowflake/ml/modeling/_internal/model_specifications.py b/snowflake/ml/modeling/_internal/model_specifications.py index e6f375c5..d9d1cebd 100644 --- a/snowflake/ml/modeling/_internal/model_specifications.py +++ b/snowflake/ml/modeling/_internal/model_specifications.py @@ -1,10 +1,9 @@ -import inspect from typing import List import cloudpickle as cp import numpy as np -from snowflake.ml._internal.exceptions import error_codes, exceptions +from snowflake.ml.modeling._internal.estimator_utils import get_module_name class ModelSpecifications: @@ -120,16 +119,10 @@ def build(cls, model: object) -> ModelSpecifications: Appropriate ModelSpecification object Raises: - SnowflakeMLException: Raises an exception the module of given model can't be determined. TypeError: Raises the exception for unsupported modules. 
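get_module_name above is a thin wrapper over inspect.getmodule that lets callers branch on the root package of an estimator, for example:

import inspect

from sklearn.linear_model import LinearRegression
from xgboost import XGBRegressor

for est in (LinearRegression(), XGBRegressor()):
    module = inspect.getmodule(est)        # what get_module_name wraps, plus error handling
    root = module.__name__.split(".")[0]
    print(type(est).__name__, "->", root)  # LinearRegression -> sklearn, XGBRegressor -> xgboost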
""" - module = inspect.getmodule(model) - if module is None: - raise exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_TYPE, - original_exception=ValueError("Unable to infer model type of the given native model object."), - ) - root_module_name = module.__name__.split(".")[0] + module_name = get_module_name(model=model) + root_module_name = module_name.split(".")[0] if root_module_name == "sklearn": from sklearn.model_selection import GridSearchCV, RandomizedSearchCV diff --git a/snowflake/ml/modeling/_internal/model_specifications_test.py b/snowflake/ml/modeling/_internal/model_specifications_test.py index 26671eb2..a7ec2a26 100644 --- a/snowflake/ml/modeling/_internal/model_specifications_test.py +++ b/snowflake/ml/modeling/_internal/model_specifications_test.py @@ -1,18 +1,214 @@ -from typing import Any +import io +from typing import Any, Dict from unittest import mock +import cloudpickle as cp +import numpy as np from absl.testing import absltest, parameterized from lightgbm import LGBMRegressor from sklearn.linear_model import LinearRegression from sklearn.model_selection import GridSearchCV from xgboost import XGBRegressor +from snowflake.ml.modeling._internal.distributed_hpo_trainer import construct_cv_results from snowflake.ml.modeling._internal.model_specifications import ( ModelSpecificationsBuilder, ) +from snowflake.snowpark import Row + +each_cv_result_basic_sample = [ + { + "mean_fit_time": np.array([0.00315547]), + "std_fit_time": np.array([0.0]), + "mean_score_time": np.array([0.00176454]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[2], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 2}], + "split0_test_score": np.array([-13.61564833]), + "mean_test_score": np.array([-13.61564833]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, + { + "mean_fit_time": np.array([0.00257707]), + "std_fit_time": np.array([0.0]), + "mean_score_time": np.array([0.00151849]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[2], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 2}], + "split0_test_score": np.array([-8.57012999]), + "mean_test_score": np.array([-8.57012999]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, + { + "mean_fit_time": np.array([0.00270677]), + "std_fit_time": np.array([0.0]), + "mean_score_time": np.array([0.00146675]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[1], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 1}], + "split0_test_score": np.array([-12.50893109]), + "mean_test_score": np.array([-12.50893109]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, + { + "mean_fit_time": np.array([0.00293922]), + "std_fit_time": np.array([0.0]), + "mean_score_time": np.array([0.00342846]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[1], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 1}], + "split0_test_score": np.array([-21.4394793]), + "mean_test_score": np.array([-21.4394793]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, + { + "mean_fit_time": np.array([0.00297642]), + "std_fit_time": 
np.array([0.0]), + "mean_score_time": np.array([0.00161123]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[1], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 1}], + "split0_test_score": np.array([-9.62685757]), + "mean_test_score": np.array([-9.62685757]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, + { + "mean_fit_time": np.array([0.00596809]), + "std_fit_time": np.array([0.0]), + "mean_score_time": np.array([0.00264239]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[2], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 2}], + "split0_test_score": np.array([-29.95119419]), + "mean_test_score": np.array([-29.95119419]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, +] + +each_cv_result_return_train = [ + { + "mean_fit_time": np.array([0.00315547]), + "std_fit_time": np.array([0.0]), + "mean_score_time": np.array([0.00176454]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[2], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 2}], + "split0_train_score": np.array([-13.61564833]), + "split0_test_score": np.array([-13.61564833]), + "mean_train_score": np.array([-13.61564833]), + "std_train_score": np.array([0.0]), + "mean_test_score": np.array([-13.61564833]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, + { + "mean_fit_time": np.array([0.00257707]), + "std_fit_time": np.array([0.0]), + "mean_score_time": np.array([0.00151849]), + "std_score_time": np.array([0.0]), + "param_n_components": np.ma.array( + data=[2], mask=[False], fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": [{"n_components": 2}], + "split0_train_score": np.array([-8.57012999]), + "split0_test_score": np.array([-8.57012999]), + "mean_train_score": np.array([-8.57012999]), + "std_train_score": np.array([0.0]), + "mean_test_score": np.array([-8.57012999]), + "std_test_score": np.array([0.0]), + "rank_test_score": np.array([1], dtype=np.int32), + }, +] + +SAMPLES: Dict[str, Dict[str, Any]] = { + "basic": { + "each_cv_result": each_cv_result_basic_sample, + "IDX_LENGTH": 3, + "PARAM_LENGTH": 2, + "CV_RESULT_": { + "mean_fit_time": np.array([0.00770839, 0.00551335]), + "std_fit_time": np.array([0.00061078, 0.00179875]), + "mean_score_time": np.array([0.00173187, 0.00182652]), + "std_score_time": np.array([0.00016869, 0.00014979]), + "param_n_components": np.ma.masked_array( + data=[1, 2], mask=False, fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": np.array([{"n_components": 1}, {"n_components": 2}], dtype=object), + "split0_test_score": np.array([-21.4394793, -29.95119419]), + "mean_test_score": np.array([-14.52508932, -17.37899084]), + "std_test_score": np.array([5.02879565, 9.12540544]), + "rank_test_score": np.array([1, 2]), + "split1_test_score": np.array([-9.62685757, -8.57012999]), + "split2_test_score": np.array([-12.50893109, -13.61564833]), + }, + }, + "return_train_score": { + "each_cv_result": each_cv_result_return_train, + "IDX_LENGTH": 2, + "PARAM_LENGTH": 1, + "CV_RESULT_": { + "mean_fit_time": np.array( + [ + 0.00286627, + ] + ), + "std_fit_time": np.array([0.0002892]), + "mean_score_time": np.array([0.00164152]), + 
"std_score_time": np.array([0.00012303]), + "param_n_components": np.ma.masked_array( + data=[2], mask=False, fill_value="?", dtype=object + ), # type: ignore[no-untyped-call] + "params": np.array([{"n_components": 2}], dtype=object), + "mean_train_score": np.array([-11.09288916]), + "std_train_score": np.array([2.52275917]), + "mean_test_score": np.array([-11.09288916]), + "std_test_score": np.array([2.52275917]), + "rank_test_score": np.array([1]), + "split0_test_score": np.array([-13.61564833]), + "split1_test_score": np.array([-8.57012999]), + "split0_train_score": np.array([-13.61564833]), + "split1_train_score": np.array([-8.57012999]), + }, + }, +} + +for key, val in SAMPLES.items(): + combine_hex_cv_result = [] + for each_array in val["each_cv_result"]: + with io.BytesIO() as f: + cp.dump(each_array, f) + f.seek(0) + binary_cv_results = f.getvalue().hex() + combine_hex_cv_result.append(binary_cv_results) + SAMPLES[key]["combine_hex_cv_result"] = combine_hex_cv_result class SnowparkHandlersUnitTest(parameterized.TestCase): + def setUp(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + zipped = sorted(zip([5, 4, 2, 0, 1, 3], SAMPLES["basic"]["combine_hex_cv_result"]), key=lambda x: x[0]) + self.RAW_DATA_SP = [Row(val) for _, val in zipped] + def test_sklearn_model_selection_wrapper_provider_lightgbm_installed(self) -> None: orig_import = __import__ @@ -65,6 +261,57 @@ def import_mock(name: str, *args: Any, **kwargs: Any) -> Any: provider = ModelSpecificationsBuilder.build(model=LGBMRegressor()) self.assertEqual(provider.imports, ["lightgbm"]) + def _compare_cv_results(self, cv_result_1: Dict[str, Any], cv_result_2: Dict[str, Any]) -> None: + # compare the keys + self.assertEqual(sorted(cv_result_1.keys()), sorted(cv_result_2.keys())) + # compare the values + for k, v in cv_result_1.items(): + if isinstance(v, np.ndarray): + if k.startswith("param_"): # compare the masked array + np.ma.allequal(v, cv_result_2[k]) # type: ignore[no-untyped-call] + elif k == "params": # compare the parameter combination + self.assertEqual(v.tolist(), cv_result_2[k].tolist()) + elif k.endswith("test_score"): # compare the test score + np.testing.assert_allclose(v, cv_result_2[k], rtol=1.0e-1, atol=1.0e-2) + # Do not compare the fit time + + def test_cv_result(self) -> None: + multimetric, cv_results_, best_param_index, scorers = construct_cv_results( + self.RAW_DATA_SP, + SAMPLES["basic"]["IDX_LENGTH"], + SAMPLES["basic"]["PARAM_LENGTH"], + {"return_train_score": False}, + ) + self.assertEqual(multimetric, False) + self.assertEqual(best_param_index, 0) + self._compare_cv_results(cv_results_, SAMPLES["basic"]["CV_RESULT_"]) + self.assertEqual(scorers, {"score"}) + + def test_cv_result_return_train_score(self) -> None: + multimetric, cv_results_, best_param_index, scorers = construct_cv_results( + [Row(val) for val in SAMPLES["return_train_score"]["combine_hex_cv_result"]], + SAMPLES["return_train_score"]["IDX_LENGTH"], + SAMPLES["return_train_score"]["PARAM_LENGTH"], + {"return_train_score": True}, + ) + self.assertEqual(multimetric, False) + self._compare_cv_results(cv_results_, SAMPLES["return_train_score"]["CV_RESULT_"]) + self.assertEqual(scorers, {"score"}) + + def test_cv_result_incorrect_param_length(self) -> None: + with self.assertRaises(ValueError): + construct_cv_results(self.RAW_DATA_SP, SAMPLES["basic"]["IDX_LENGTH"], 1, {"return_train_score": False}) + + def test_cv_result_nan(self) -> None: + # corner cases with nan values + with self.assertRaises(ValueError): 
+ construct_cv_results(self.RAW_DATA_SP, 0, SAMPLES["basic"]["PARAM_LENGTH"], {"return_train_score": False}) + # empty list + with self.assertRaises(ValueError): + construct_cv_results( + [], SAMPLES["basic"]["IDX_LENGTH"], SAMPLES["basic"]["PARAM_LENGTH"], {"return_train_score": False} + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/modeling/_internal/model_trainer_builder.py b/snowflake/ml/modeling/_internal/model_trainer_builder.py index 4c4d7aca..c4947fab 100644 --- a/snowflake/ml/modeling/_internal/model_trainer_builder.py +++ b/snowflake/ml/modeling/_internal/model_trainer_builder.py @@ -3,13 +3,20 @@ import pandas as pd from sklearn import model_selection +from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml.modeling._internal.distributed_hpo_trainer import ( DistributedHPOTrainer, ) -from snowflake.ml.modeling._internal.estimator_utils import is_single_node +from snowflake.ml.modeling._internal.estimator_utils import ( + get_module_name, + is_single_node, +) from snowflake.ml.modeling._internal.model_trainer import ModelTrainer from snowflake.ml.modeling._internal.pandas_trainer import PandasModelTrainer from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer +from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import ( + XGBoostExternalMemoryTrainer, +) from snowflake.snowpark import DataFrame, Session _PROJECT = "ModelDevelopment" @@ -30,6 +37,31 @@ class ModelTrainerBuilder: def _check_if_distributed_hpo_enabled(cls, session: Session) -> bool: return not is_single_node(session) and ModelTrainerBuilder._ENABLE_DISTRIBUTED is True + @classmethod + def _validate_external_memory_params(cls, estimator: object, batch_size: int) -> None: + """ + Validate the params are set appropriately for external memory training. + + Args: + estimator: Model object + batch_size: Number of rows in each batch of data processed during training. + + Raises: + SnowflakeMLException: If the params are not appropriate for the external memory training feature. + """ + module_name = get_module_name(model=estimator) + root_module_name = module_name.split(".")[0] + if root_module_name != "xgboost": + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=RuntimeError("External memory training is only supported for XGBoost models."), + ) + if batch_size <= 0: + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=RuntimeError("Batch size must be >= 0 when using external memory training feature."), + ) + @classmethod def build( cls, @@ -40,6 +72,8 @@ def build( sample_weight_col: Optional[str] = None, autogenerated: bool = False, subproject: str = "", + use_external_memory_version: bool = False, + batch_size: int = -1, ) -> ModelTrainer: """ Builder method that creates an approproiate ModelTrainer instance based on the given params. 
@@ -55,22 +89,32 @@ def build( ) elif isinstance(dataset, DataFrame): trainer_klass = SnowparkModelTrainer + init_args = { + "estimator": estimator, + "dataset": dataset, + "session": dataset._session, + "input_cols": input_cols, + "label_cols": label_cols, + "sample_weight_col": sample_weight_col, + "autogenerated": autogenerated, + "subproject": subproject, + } + assert dataset._session is not None # Make MyPy happpy if isinstance(estimator, model_selection.GridSearchCV) or isinstance( estimator, model_selection.RandomizedSearchCV ): if ModelTrainerBuilder._check_if_distributed_hpo_enabled(session=dataset._session): trainer_klass = DistributedHPOTrainer - return trainer_klass( - estimator=estimator, - dataset=dataset, - session=dataset._session, - input_cols=input_cols, - label_cols=label_cols, - sample_weight_col=sample_weight_col, - autogenerated=autogenerated, - subproject=subproject, - ) + elif use_external_memory_version: + ModelTrainerBuilder._validate_external_memory_params( + estimator=estimator, + batch_size=batch_size, + ) + trainer_klass = XGBoostExternalMemoryTrainer + init_args["batch_size"] = batch_size + + return trainer_klass(**init_args) # type: ignore[arg-type] else: raise TypeError( f"Unexpected dataset type: {type(dataset)}." diff --git a/snowflake/ml/modeling/_internal/model_trainer_builder_test.py b/snowflake/ml/modeling/_internal/model_trainer_builder_test.py new file mode 100644 index 00000000..8fbb37e5 --- /dev/null +++ b/snowflake/ml/modeling/_internal/model_trainer_builder_test.py @@ -0,0 +1,84 @@ +from typing import Any +from unittest import mock + +import inflection +from absl.testing import absltest +from sklearn.datasets import load_iris +from sklearn.linear_model import LinearRegression +from sklearn.model_selection import GridSearchCV +from xgboost import XGBRegressor + +from snowflake.ml.modeling._internal.distributed_hpo_trainer import ( + DistributedHPOTrainer, +) +from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder +from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer +from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import ( + XGBoostExternalMemoryTrainer, +) +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +from snowflake.snowpark import DataFrame, Session + + +class SnowparkHandlersUnitTest(absltest.TestCase): + def setUp(self) -> None: + self._session = Session.builder.configs(SnowflakeLoginOptions()).create() + + def tearDown(self) -> None: + self._session.close() + + def get_snowpark_dataset(self) -> DataFrame: + input_df_pandas = load_iris(as_frame=True).frame + input_df_pandas.columns = [inflection.parameterize(c, "_").upper() for c in input_df_pandas.columns] + input_df_pandas["INDEX"] = input_df_pandas.reset_index().index + input_df: DataFrame = self._session.create_dataframe(input_df_pandas) + return input_df + + def test_sklearn_model_trainer(self) -> None: + model = LinearRegression() + dataset = self.get_snowpark_dataset() + trainer = ModelTrainerBuilder.build(estimator=model, dataset=dataset, input_cols=[]) + + self.assertTrue(isinstance(trainer, SnowparkModelTrainer)) + + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") + def test_distributed_hpo_trainer(self, mock_is_single_node: Any) -> None: + mock_is_single_node.return_value = False + dataset = self.get_snowpark_dataset() + model = GridSearchCV(estimator=LinearRegression(), param_grid={"loss": ["rmsqe", "mae"]}) + trainer = 
ModelTrainerBuilder.build(estimator=model, dataset=dataset, input_cols=[]) + + self.assertTrue(isinstance(trainer, DistributedHPOTrainer)) + + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") + def test_single_node_hpo_trainer(self, mock_is_single_node: Any) -> None: + mock_is_single_node.return_value = True + dataset = self.get_snowpark_dataset() + model = GridSearchCV(estimator=LinearRegression(), param_grid={"loss": ["rmsqe", "mae"]}) + trainer = ModelTrainerBuilder.build(estimator=model, dataset=dataset, input_cols=[]) + + self.assertTrue(isinstance(trainer, SnowparkModelTrainer)) + + def test_xgboost_external_memory_model_trainer(self) -> None: + model = XGBRegressor() + dataset = self.get_snowpark_dataset() + trainer = ModelTrainerBuilder.build( + estimator=model, dataset=dataset, input_cols=[], use_external_memory_version=True, batch_size=1000 + ) + + self.assertTrue(isinstance(trainer, XGBoostExternalMemoryTrainer)) + + def test_xgboost_standard_model_trainer(self) -> None: + model = XGBRegressor() + dataset = self.get_snowpark_dataset() + trainer = ModelTrainerBuilder.build( + estimator=model, + dataset=dataset, + input_cols=[], + ) + + self.assertTrue(isinstance(trainer, SnowparkModelTrainer)) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/_internal/snowpark_trainer.py b/snowflake/ml/modeling/_internal/snowpark_trainer.py index 3d7aaf39..315eded1 100644 --- a/snowflake/ml/modeling/_internal/snowpark_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_trainer.py @@ -12,7 +12,11 @@ exceptions, modeling_error_messages, ) -from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils +from snowflake.ml._internal.utils import ( + identifier, + pkg_version_utils, + snowpark_dataframe_utils, +) from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator from snowflake.ml._internal.utils.temp_file_utils import ( cleanup_temp_files, @@ -253,11 +257,15 @@ def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProc fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) + relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=model_spec.pkgDependencies, session=self.session + ) + fit_wrapper_sproc = self.session.sproc.register( func=self._build_fit_wrapper_sproc(model_spec=model_spec), is_permanent=False, name=fit_sproc_name, - packages=["snowflake-snowpark-python"] + model_spec.pkgDependencies, # type: ignore[arg-type] + packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type] replace=True, session=self.session, statement_params=statement_params, diff --git a/snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py b/snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py new file mode 100644 index 00000000..5f7e5942 --- /dev/null +++ b/snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py @@ -0,0 +1,444 @@ +import inspect +import os +import tempfile +from typing import Any, Dict, List, Optional + +import cloudpickle as cp +import pandas as pd +import pyarrow.parquet as pq + +from snowflake.ml._internal import telemetry +from snowflake.ml._internal.exceptions import ( + error_codes, + exceptions, + modeling_error_messages, +) +from snowflake.ml._internal.utils import pkg_version_utils +from snowflake.ml._internal.utils.query_result_checker import ResultValidator +from 
snowflake.ml._internal.utils.snowpark_dataframe_utils import ( + cast_snowpark_dataframe, +) +from snowflake.ml._internal.utils.temp_file_utils import get_temp_file_path +from snowflake.ml.modeling._internal.model_specifications import ( + ModelSpecifications, + ModelSpecificationsBuilder, +) +from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer +from snowflake.snowpark import ( + DataFrame, + Session, + exceptions as snowpark_exceptions, + functions as F, +) +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) + +_PROJECT = "ModelDevelopment" + + +def get_data_iterator( + file_paths: List[str], + batch_size: int, + input_cols: List[str], + label_cols: List[str], + sample_weight_col: Optional[str] = None, +) -> Any: + from typing import List, Optional + + import xgboost + + class ParquetDataIterator(xgboost.DataIter): + """ + This iterator reads parquet data stored in a specified files and returns + deserialized data, enabling seamless integration with the xgboost framework for + machine learning tasks. + """ + + def __init__( + self, + file_paths: List[str], + batch_size: int, + input_cols: List[str], + label_cols: List[str], + sample_weight_col: Optional[str] = None, + ) -> None: + """ + Initialize the DataIterator. + + Args: + file_paths: List of file paths containing the data. + batch_size: Target number of rows in each batch. + input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for + training. + label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) + to learn. + sample_weight_col: The column name representing the weight of training examples. + """ + self._file_paths = file_paths + self._batch_size = batch_size + self._input_cols = input_cols + self._label_cols = label_cols + self._sample_weight_col = sample_weight_col + + # File index + self._it = 0 + # Pandas dataframe containing temp data + self._df = None + # XGBoost will generate some cache files under current directory with the prefix + # "cache" + cache_dir_name = tempfile.mkdtemp() + super().__init__(cache_prefix=os.path.join(cache_dir_name, "cache")) + + def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def] + """Advance the iterator by 1 step and pass the data to XGBoost's batch_consumer_fn. + This function is called by XGBoost during the construction of ``DMatrix`` + + Args: + batch_consumer_fn: batch consumer function + + Returns: + 0 if there is no more data, else 1. + """ + while (self._df is None) or (self._df.shape[0] < self._batch_size): + # Read files and append data to temp df until batch size is reached. + if self._it == len(self._file_paths): + break + new_df = pq.read_table(self._file_paths[self._it]).to_pandas() + self._it += 1 + + if self._df is None: + self._df = new_df + else: + self._df = pd.concat([self._df, new_df], ignore_index=True) + + if (self._df is None) or (self._df.shape[0] == 0): + # No more data + return 0 + + # Slice the temp df and save the remainder in the temp df + batch_end_index = min(self._batch_size, self._df.shape[0]) + batch_df = self._df.iloc[:batch_end_index] + self._df = self._df.truncate(before=batch_end_index).reset_index(drop=True) + + # TODO(snandamuri): Make it proper to support categorical features, etc. 
+ func_args = { + "data": batch_df[self._input_cols], + "label": batch_df[self._label_cols].squeeze(), + } + if self._sample_weight_col is not None: + func_args["weight"] = batch_df[self._sample_weight_col].squeeze() + + batch_consumer_fn(**func_args) + # Return 1 to let XGBoost know we haven't seen all the files yet. + return 1 + + def reset(self) -> None: + """Reset the iterator to its beginning""" + self._it = 0 + + return ParquetDataIterator( + file_paths=file_paths, + batch_size=batch_size, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + ) + + +def train_xgboost_model( + estimator: object, + file_paths: List[str], + batch_size: int, + input_cols: List[str], + label_cols: List[str], + sample_weight_col: Optional[str] = None, +) -> object: + """ + Function to train XGBoost models using the external memory version of XGBoost. + """ + import xgboost + + def _objective_decorator(func): # type: ignore[no-untyped-def] + def inner(preds, dmatrix): # type: ignore[no-untyped-def] + """internal function""" + labels = dmatrix.get_label() + return func(labels, preds) + + return inner + + assert isinstance(estimator, xgboost.XGBModel) + params = estimator.get_xgb_params() + obj = None + + if isinstance(estimator, xgboost.XGBClassifier): + # TODO (snandamuri): Find better way to get expected_classes + # Set: self.classes_, self.n_classes_ + expected_classes = pd.unique(pq.read_table(file_paths[0]).to_pandas()[label_cols].squeeze()) + estimator.n_classes_ = len(expected_classes) + if callable(estimator.objective): + obj = _objective_decorator(estimator.objective) # type: ignore[no-untyped-call] + # Use default value. Is it really not used ? + params["objective"] = "binary:logistic" + + if len(expected_classes) > 2: + # Switch to using a multiclass objective in the underlying XGB instance + if params.get("objective", None) != "multi:softmax": + params["objective"] = "multi:softprob" + params["num_class"] = len(expected_classes) + + if "tree_method" not in params.keys() or params["tree_method"] is None or params["tree_method"].lower() == "exact": + params["tree_method"] = "hist" + + if ( + "grow_policy" not in params.keys() + or params["grow_policy"] is None + or params["grow_policy"].lower() != "depthwise" + ): + params["grow_policy"] = "depthwise" + + it = get_data_iterator( + file_paths=file_paths, + batch_size=batch_size, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + ) + Xy = xgboost.DMatrix(it) + estimator._Booster = xgboost.train( + params, + Xy, + estimator.get_num_boosting_rounds(), + evals=[], + early_stopping_rounds=estimator.early_stopping_rounds, + evals_result=None, + obj=obj, + custom_metric=estimator.eval_metric, + verbose_eval=None, + xgb_model=None, + callbacks=None, + ) + return estimator + + +cp.register_pickle_by_value(inspect.getmodule(get_data_iterator)) +cp.register_pickle_by_value(inspect.getmodule(train_xgboost_model)) + + +class XGBoostExternalMemoryTrainer(SnowparkModelTrainer): + """ + When working with large datasets, training XGBoost models traditionally requires loading the entire dataset into + memory, which can be costly and sometimes infeasible due to memory constraints. To solve this problem, XGBoost + provides support for loading data from external memory using a built-in data parser. With this feature enabled, + the training process occurs in a two-step approach: + Preprocessing Step: Input data is read and parsed into an internal format, such as CSR, CSC, or sorted CSC. 
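get_data_iterator and train_xgboost_model above wire staged parquet files into XGBoost's external-memory path. The same pattern can be exercised locally along the lines of the sketch below; the temporary file, batch size, and hyperparameters are illustrative only.

import inflection
import xgboost
from sklearn.datasets import load_iris

from snowflake.ml._internal.utils.temp_file_utils import get_temp_file_path
from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import get_data_iterator

# Write a small parquet file to stand in for the staged training data.
df = load_iris(as_frame=True).frame
df.columns = [inflection.parameterize(c, "_").upper() for c in df.columns]
parquet_path = get_temp_file_path()
df.to_parquet(parquet_path)

input_cols = [c for c in df.columns if c != "TARGET"]
it = get_data_iterator(
    file_paths=[parquet_path],
    batch_size=50,
    input_cols=input_cols,
    label_cols=["TARGET"],
)

# The iterator feeds DMatrix construction batch by batch instead of loading everything at once.
dtrain = xgboost.DMatrix(it)
params = {"tree_method": "hist", "objective": "multi:softprob", "num_class": 3}
booster = xgboost.train(params, dtrain, num_boost_round=10)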
+ Processed state is appended to an in-memory buffer. Once the buffer reaches a predefined size, it is + written out to disk as a page. + Tree Construction Step: During the tree construction phase, the data pages stored on disk are streamed via + a multi-threaded pre-fetcher, allowing the model to efficiently access and process the data without + overloading memory. + """ + + def __init__( + self, + estimator: object, + dataset: DataFrame, + session: Session, + input_cols: List[str], + label_cols: Optional[List[str]], + sample_weight_col: Optional[str], + autogenerated: bool = False, + subproject: str = "", + batch_size: int = 10000, + ) -> None: + """ + Initializes the XGBoostExternalMemoryTrainer with a model, a Snowpark DataFrame, feature, and label column + names, etc. + + Args: + estimator: SKLearn compatible estimator or transformer object. + dataset: The dataset used for training the model. + session: Snowflake session object to be used for training. + input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training. + label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn. + sample_weight_col: The column name representing the weight of training examples. + autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not. + subproject: subproject name to be used in telemetry. + batch_size: Number of the rows in the each batch processed during training. + """ + super().__init__( + estimator=estimator, + dataset=dataset, + session=session, + input_cols=input_cols, + label_cols=label_cols, + sample_weight_col=sample_weight_col, + autogenerated=autogenerated, + subproject=subproject, + ) + self._batch_size = batch_size + + def _get_xgb_external_memory_fit_wrapper_sproc( + self, + model_spec: ModelSpecifications, + session: Session, + statement_params: Dict[str, str], + import_file_paths: List[str], + ) -> Any: + fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE) + + relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel( + pkg_versions=model_spec.pkgDependencies, session=self.session + ) + + @F.sproc( + is_permanent=False, + name=fit_sproc_name, + packages=list(["snowflake-snowpark-python"] + relaxed_dependencies), + replace=True, + session=session, + statement_params=statement_params, + anonymous=True, + imports=list(import_file_paths), + ) # type: ignore[misc] + def fit_wrapper_sproc( + session: Session, + stage_transform_file_name: str, + stage_result_file_name: str, + dataset_stage_name: str, + batch_size: int, + input_cols: List[str], + label_cols: List[str], + sample_weight_col: Optional[str], + statement_params: Dict[str, str], + ) -> str: + import os + import sys + + import cloudpickle as cp + + local_transform_file_name = get_temp_file_path() + + session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params) + + local_transform_file_path = os.path.join( + local_transform_file_name, os.listdir(local_transform_file_name)[0] + ) + with open(local_transform_file_path, mode="r+b") as local_transform_file_obj: + estimator = cp.load(local_transform_file_obj) + + data_files = [ + os.path.join(sys._xoptions["snowflake_import_directory"], filename) + for filename in os.listdir(sys._xoptions["snowflake_import_directory"]) + if filename.startswith(dataset_stage_name) + ] + + estimator = train_xgboost_model( + estimator=estimator, + file_paths=data_files, + 
batch_size=batch_size,
+                input_cols=input_cols,
+                label_cols=label_cols,
+                sample_weight_col=sample_weight_col,
+            )
+
+            local_result_file_name = get_temp_file_path()
+            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
+                cp.dump(estimator, local_result_file_obj)
+
+            session.file.put(
+                local_result_file_name,
+                stage_result_file_name,
+                auto_compress=False,
+                overwrite=True,
+                statement_params=statement_params,
+            )
+
+            # Note: you can add something like + "|" + str(df) to the return string
+            # to pass debug information to the caller.
+            return str(os.path.basename(local_result_file_name))
+
+        return fit_wrapper_sproc
+
+    def _write_training_data_to_stage(self, dataset_stage_name: str) -> List[str]:
+        """
+        Materializes the training data to the specified stage and returns the list of stage file paths.
+
+        Args:
+            dataset_stage_name: Target stage to materialize training data.
+
+        Returns:
+            List of stage file paths that contain the materialized data.
+        """
+        # Stage data.
+        dataset = cast_snowpark_dataframe(self.dataset)
+        remote_file_path = f"{dataset_stage_name}/{dataset_stage_name}.parquet"
+        copy_response = dataset.write.copy_into_location(  # type:ignore[call-overload]
+            remote_file_path, file_format_type="parquet", header=True, overwrite=True
+        )
+        ResultValidator(result=copy_response).has_dimensions(expected_rows=1).validate()
+        data_file_paths = [f"@{row.name}" for row in self.session.sql(f"LIST @{dataset_stage_name}").collect()]
+        return data_file_paths
+
+    def train(self) -> object:
+        """
+        Trains the estimator using XGBoost's external memory mode: the dataset is staged as parquet files and
+        fitting runs inside a stored procedure that streams those files in batches.
+
+        Returns:
+            Trained model
+
+        Raises:
+            SnowflakeMLException: For known types of user and system errors.
+            e: For every unexpected exception from SnowflakeClient.
+ """ + temp_stage_name = self._create_temp_stage() + (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(stage_name=temp_stage_name) + data_file_paths = self._write_training_data_to_stage(dataset_stage_name=temp_stage_name) + + # Call fit sproc + statement_params = telemetry.get_function_usage_statement_params( + project=_PROJECT, + subproject=self._subproject, + function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name), + api_calls=[Session.call], + custom_tags=None, + ) + + model_spec = ModelSpecificationsBuilder.build(model=self.estimator) + fit_wrapper = self._get_xgb_external_memory_fit_wrapper_sproc( + model_spec=model_spec, + session=self.session, + statement_params=statement_params, + import_file_paths=data_file_paths, + ) + + try: + sproc_export_file_name = fit_wrapper( + self.session, + stage_transform_file_name, + stage_result_file_name, + temp_stage_name, + self._batch_size, + self.input_cols, + self.label_cols, + self.sample_weight_col, + statement_params, + ) + except snowpark_exceptions.SnowparkClientException as e: + if "fit() missing 1 required positional argument: 'y'" in str(e): + raise exceptions.SnowflakeMLException( + error_code=error_codes.NOT_FOUND, + original_exception=RuntimeError(modeling_error_messages.ATTRIBUTE_NOT_SET.format("label_cols")), + ) from e + raise e + + if "|" in sproc_export_file_name: + fields = sproc_export_file_name.strip().split("|") + sproc_export_file_name = fields[0] + + return self._fetch_model_from_stage( + dir_path=stage_result_file_name, + file_name=sproc_export_file_name, + statement_params=statement_params, + ) diff --git a/snowflake/ml/modeling/_internal/xgboost_external_memory_trainer_test.py b/snowflake/ml/modeling/_internal/xgboost_external_memory_trainer_test.py new file mode 100644 index 00000000..8a663c91 --- /dev/null +++ b/snowflake/ml/modeling/_internal/xgboost_external_memory_trainer_test.py @@ -0,0 +1,100 @@ +import math + +import inflection +import pandas as pd +from absl.testing import absltest +from sklearn.datasets import load_iris + +from snowflake.ml._internal.utils.temp_file_utils import ( + cleanup_temp_files, + get_temp_file_path, +) +from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import ( + get_data_iterator, +) + + +class XGBoostExternalMemoryTrainerTest(absltest.TestCase): + def setUp(self) -> None: + pass + + def tearDown(self) -> None: + pass + + def get_dataset(self) -> pd.DataFrame: + input_df_pandas = load_iris(as_frame=True).frame + input_df_pandas.columns = [inflection.parameterize(c, "_").upper() for c in input_df_pandas.columns] + input_cols = [c for c in input_df_pandas.columns if not c.startswith("TARGET")] + label_col = [c for c in input_df_pandas.columns if c.startswith("TARGET")] + return (input_df_pandas, input_cols, label_col) + + def test_data_iterator_single_file(self) -> None: + df, input_cols, label_col = self.get_dataset() + + num_rows_in_original_dataset = df.shape[0] + batch_size = 20 + + temp_file = get_temp_file_path() + df.to_parquet(temp_file) + + it = get_data_iterator( + file_paths=[temp_file], + batch_size=20, + input_cols=input_cols, + label_cols=label_col, + ) + + num_rows = 0 + num_batches = 0 + + def consumer_func(data: pd.DataFrame, label: pd.DataFrame) -> None: + nonlocal num_rows + nonlocal num_batches + num_rows += data.shape[0] + num_batches += 1 + + while it.next(consumer_func): + pass + + self.assertEqual(num_rows, num_rows_in_original_dataset) + self.assertEqual(num_batches, 
math.ceil(float(num_rows_in_original_dataset) / float(batch_size))) + cleanup_temp_files(temp_file) + + def test_data_iterator_multiple_file(self) -> None: + df, input_cols, label_col = self.get_dataset() + + num_rows_in_original_dataset = df.shape[0] + batch_size = 20 + + temp_file1 = get_temp_file_path() + temp_file2 = get_temp_file_path() + df1, df2 = df.iloc[:70], df.iloc[70:] + df1.to_parquet(temp_file1) + df2.to_parquet(temp_file2) + + it = get_data_iterator( + file_paths=[temp_file1, temp_file2], + batch_size=20, + input_cols=input_cols, + label_cols=label_col, + ) + + num_rows = 0 + num_batches = 0 + + def consumer_func(data: pd.DataFrame, label: pd.DataFrame) -> None: + nonlocal num_rows + nonlocal num_batches + num_rows += data.shape[0] + num_batches += 1 + + while it.next(consumer_func): + pass + + self.assertEqual(num_rows, num_rows_in_original_dataset) + self.assertEqual(num_batches, math.ceil(float(num_rows_in_original_dataset) / float(batch_size))) + cleanup_temp_files([temp_file1, temp_file2]) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/modeling/preprocessing/min_max_scaler.py b/snowflake/ml/modeling/preprocessing/min_max_scaler.py index 9e122cb7..56fff072 100644 --- a/snowflake/ml/modeling/preprocessing/min_max_scaler.py +++ b/snowflake/ml/modeling/preprocessing/min_max_scaler.py @@ -8,8 +8,9 @@ from snowflake import snowpark from snowflake.ml._internal import telemetry +from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml.modeling.framework import _utils, base -from snowflake.snowpark import functions as F +from snowflake.snowpark import functions as F, types as T class MinMaxScaler(base.BaseTransformer): @@ -125,6 +126,18 @@ def _reset(self) -> None: self.data_max_ = {} self.data_range_ = {} + def _check_input_column_types(self, dataset: snowpark.DataFrame) -> None: + for field in dataset.schema.fields: + if field.name in self.input_cols: + if not issubclass(type(field.datatype), T._NumericType): + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_DATA_TYPE, + original_exception=TypeError( + f"Non-numeric input column {field.name} datatype {field.datatype} " + "is not supported by the MinMaxScaler." 
+ ), + ) + @telemetry.send_api_usage_telemetry( project=base.PROJECT, subproject=base.SUBPROJECT, @@ -169,6 +182,7 @@ def _fit_sklearn(self, dataset: pd.DataFrame) -> None: self.data_range_[input_col] = float(sklearn_scaler.data_range_[i]) def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None: + self._check_input_column_types(dataset) computed_states = self._compute(dataset, self.input_cols, self.custom_states) # assign states to the object diff --git a/snowflake/ml/monitoring/tests/BUILD.bazel b/snowflake/ml/monitoring/tests/BUILD.bazel index f005dc1a..594aae84 100644 --- a/snowflake/ml/monitoring/tests/BUILD.bazel +++ b/snowflake/ml/monitoring/tests/BUILD.bazel @@ -1,6 +1,9 @@ load("//bazel:py_rules.bzl", "py_test") -package(default_visibility = ["//snowflake/ml/monitoring"]) +package(default_visibility = [ + "//bazel:snowml_public_common", + "//snowflake/ml/monitoring", +]) SHARD_COUNT = 3 diff --git a/snowflake/ml/registry/BUILD.bazel b/snowflake/ml/registry/BUILD.bazel index 89bb441b..19d7c1a1 100644 --- a/snowflake/ml/registry/BUILD.bazel +++ b/snowflake/ml/registry/BUILD.bazel @@ -46,7 +46,7 @@ py_library( "_schema_upgrade_plans.py", "_schema_version_manager.py", ], - visibility = ["//visibility:private"], + visibility = ["//bazel:snowml_public_common"], deps = [ "//snowflake/ml/_internal/utils:query_result_checker", "//snowflake/ml/_internal/utils:table_manager", @@ -78,7 +78,7 @@ py_test( ) py_library( - name = "registry", + name = "registry_impl", srcs = [ "registry.py", ], @@ -86,12 +86,23 @@ py_library( "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model", "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", - "//snowflake/ml/model/_client/model:model_impl", - "//snowflake/ml/model/_client/model:model_version_impl", - "//snowflake/ml/model/_client/ops:model_ops", - "//snowflake/ml/model/_model_composer:model_composer", + "//snowflake/ml/registry/_manager:model_manager", + ], +) + +py_library( + name = "registry", + srcs = [ + "__init__.py", + ], + deps = [ + ":artifact_manager", + ":model_registry", + ":registry_impl", + ":schema", ], ) @@ -101,15 +112,22 @@ py_test( "registry_test.py", ], deps = [ - ":registry", - "//snowflake/ml/_internal/utils:sql_identifier", - "//snowflake/ml/model/_client/model:model_version_impl", - "//snowflake/ml/model/_model_composer:model_composer", + ":registry_impl", + "//snowflake/ml/model", "//snowflake/ml/test_utils:mock_data_frame", "//snowflake/ml/test_utils:mock_session", ], ) +py_test( + name = "package_visibility_test", + srcs = ["package_visibility_test.py"], + deps = [ + ":model_registry", + ":registry", + ], +) + py_package( name = "model_registry_pkg", packages = ["snowflake.ml"], diff --git a/snowflake/ml/registry/__init__.py b/snowflake/ml/registry/__init__.py new file mode 100644 index 00000000..47275f2d --- /dev/null +++ b/snowflake/ml/registry/__init__.py @@ -0,0 +1,3 @@ +from snowflake.ml.registry.registry import Registry + +__all__ = ["Registry"] diff --git a/snowflake/ml/registry/_manager/BUILD.bazel b/snowflake/ml/registry/_manager/BUILD.bazel new file mode 100644 index 00000000..e75c6e12 --- /dev/null +++ b/snowflake/ml/registry/_manager/BUILD.bazel @@ -0,0 +1,40 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = [ + "//bazel:snowml_public_common", + "//snowflake/ml/registry:__pkg__", +]) + +py_library( + name = "model_manager", + srcs = [ + 
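+        # ModelManager implementation; snowflake.ml.registry.Registry delegates its
+        # log_model/get_model/models/show_models/delete_model calls to this module.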
"model_manager.py", + ], + deps = [ + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model", + "//snowflake/ml/model:model_signature", + "//snowflake/ml/model:type_hints", + "//snowflake/ml/model/_client/model:model_impl", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/model/_client/ops:metadata_ops", + "//snowflake/ml/model/_client/ops:model_ops", + "//snowflake/ml/model/_model_composer:model_composer", + ], +) + +py_test( + name = "model_manager_test", + srcs = [ + "model_manager_test.py", + ], + deps = [ + ":model_manager", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model/_client/model:model_version_impl", + "//snowflake/ml/model/_model_composer:model_composer", + "//snowflake/ml/test_utils:mock_data_frame", + "//snowflake/ml/test_utils:mock_session", + ], +) diff --git a/snowflake/ml/registry/_manager/model_manager.py b/snowflake/ml/registry/_manager/model_manager.py new file mode 100644 index 00000000..cd2f3c87 --- /dev/null +++ b/snowflake/ml/registry/_manager/model_manager.py @@ -0,0 +1,163 @@ +from types import ModuleType +from typing import Any, Dict, List, Optional + +import pandas as pd +from absl.logging import logging + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import model_signature, type_hints as model_types +from snowflake.ml.model._client.model import model_impl, model_version_impl +from snowflake.ml.model._client.ops import metadata_ops, model_ops +from snowflake.ml.model._model_composer import model_composer +from snowflake.snowpark import session + +logger = logging.getLogger(__name__) + + +class ModelManager: + def __init__( + self, + session: session.Session, + *, + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> None: + self._database_name = database_name + self._schema_name = schema_name + self._model_ops = model_ops.ModelOperator( + session, database_name=self._database_name, schema_name=self._schema_name + ) + + def log_model( + self, + model: model_types.SupportedModelType, + *, + model_name: str, + version_name: str, + comment: Optional[str] = None, + metrics: Optional[Dict[str, Any]] = None, + conda_dependencies: Optional[List[str]] = None, + pip_requirements: Optional[List[str]] = None, + python_version: Optional[str] = None, + signatures: Optional[Dict[str, model_signature.ModelSignature]] = None, + sample_input_data: Optional[model_types.SupportedDataType] = None, + code_paths: Optional[List[str]] = None, + ext_modules: Optional[List[ModuleType]] = None, + options: Optional[model_types.ModelSaveOption] = None, + statement_params: Optional[Dict[str, Any]] = None, + ) -> model_version_impl.ModelVersion: + model_name_id = sql_identifier.SqlIdentifier(model_name) + + version_name_id = sql_identifier.SqlIdentifier(version_name) + + if self._model_ops.validate_existence( + model_name=model_name_id, statement_params=statement_params + ) and self._model_ops.validate_existence( + model_name=model_name_id, version_name=version_name_id, statement_params=statement_params + ): + raise ValueError(f"Model {model_name} version {version_name} already existed.") + + stage_path = self._model_ops.prepare_model_stage_path( + statement_params=statement_params, + ) + + logger.info("Start packaging and uploading your model. 
It might take some time based on the size of the model.") + + mc = model_composer.ModelComposer(self._model_ops._session, stage_path=stage_path) + mc.save( + name=model_name_id.resolved(), + model=model, + signatures=signatures, + sample_input=sample_input_data, + conda_dependencies=conda_dependencies, + pip_requirements=pip_requirements, + python_version=python_version, + code_paths=code_paths, + ext_modules=ext_modules, + options=options, + ) + + logger.info("Start creating MODEL object for you in the Snowflake.") + + self._model_ops.create_from_stage( + composed_model=mc, + model_name=model_name_id, + version_name=version_name_id, + statement_params=statement_params, + ) + + mv = model_version_impl.ModelVersion._ref( + self._model_ops, + model_name=model_name_id, + version_name=version_name_id, + ) + + if comment: + mv.comment = comment + + if metrics: + self._model_ops._metadata_ops.save( + metadata_ops.ModelVersionMetadataSchema(metrics=metrics), + model_name=model_name_id, + version_name=version_name_id, + statement_params=statement_params, + ) + + return mv + + def get_model( + self, + model_name: str, + *, + statement_params: Optional[Dict[str, Any]] = None, + ) -> model_impl.Model: + model_name_id = sql_identifier.SqlIdentifier(model_name) + if self._model_ops.validate_existence( + model_name=model_name_id, + statement_params=statement_params, + ): + return model_impl.Model._ref( + self._model_ops, + model_name=model_name_id, + ) + else: + raise ValueError(f"Unable to find model {model_name}") + + def models( + self, + *, + statement_params: Optional[Dict[str, Any]] = None, + ) -> List[model_impl.Model]: + model_names = self._model_ops.list_models_or_versions( + statement_params=statement_params, + ) + return [ + model_impl.Model._ref( + self._model_ops, + model_name=model_name, + ) + for model_name in model_names + ] + + def show_models( + self, + *, + statement_params: Optional[Dict[str, Any]] = None, + ) -> pd.DataFrame: + rows = self._model_ops.show_models_or_versions( + statement_params=statement_params, + ) + return pd.DataFrame([row.as_dict() for row in rows]) + + def delete_model( + self, + model_name: str, + *, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + model_name_id = sql_identifier.SqlIdentifier(model_name) + + self._model_ops.delete_model_or_version( + model_name=model_name_id, + statement_params=statement_params, + ) diff --git a/snowflake/ml/registry/_manager/model_manager_test.py b/snowflake/ml/registry/_manager/model_manager_test.py new file mode 100644 index 00000000..a314961c --- /dev/null +++ b/snowflake/ml/registry/_manager/model_manager_test.py @@ -0,0 +1,351 @@ +from typing import cast +from unittest import mock + +import pandas as pd +from absl.testing import absltest + +from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model._client.model import model_impl, model_version_impl +from snowflake.ml.model._model_composer import model_composer +from snowflake.ml.registry._manager import model_manager +from snowflake.ml.test_utils import mock_session +from snowflake.snowpark import Row, Session + + +class RegistryTest(absltest.TestCase): + def setUp(self) -> None: + self.m_session = mock_session.MockSession(conn=None, test_case=self) + self.c_session = cast(Session, self.m_session) + self.m_r = model_manager.ModelManager( + self.c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("TEST"), + ) + + def test_get_model_1(self) -> None: + m_model = 
model_impl.Model._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + ) + with mock.patch.object(self.m_r._model_ops, "validate_existence", return_value=True) as mock_validate_existence: + m = self.m_r.get_model("MODEL") + self.assertEqual(m, m_model) + mock_validate_existence.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_get_model_2(self) -> None: + with mock.patch.object( + self.m_r._model_ops, "validate_existence", return_value=False + ) as mock_validate_existence: + with self.assertRaisesRegex(ValueError, "Unable to find model MODEL"): + self.m_r.get_model("MODEL") + mock_validate_existence.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + def test_models(self) -> None: + m_model_1 = model_impl.Model._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + ) + m_model_2 = model_impl.Model._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), + ) + with mock.patch.object( + self.m_r._model_ops, + "list_models_or_versions", + return_value=[ + sql_identifier.SqlIdentifier("MODEL"), + sql_identifier.SqlIdentifier("Model", case_sensitive=True), + ], + ) as mock_list_models_or_versions: + m_list = self.m_r.models() + self.assertListEqual(m_list, [m_model_1, m_model_2]) + mock_list_models_or_versions.assert_called_once_with( + statement_params=mock.ANY, + ) + + def test_show_models(self) -> None: + m_list_res = [ + Row( + create_on="06/01", + name="MODEL", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + default_version_name="V1", + ), + Row( + create_on="06/01", + name="Model", + comment="This is a comment", + model_name="MODEL", + database_name="TEMP", + schema_name="test", + default_version_name="v1", + ), + ] + with mock.patch.object( + self.m_r._model_ops, + "show_models_or_versions", + return_value=m_list_res, + ) as mock_show_models_or_versions: + mv_info = self.m_r.show_models() + pd.testing.assert_frame_equal(mv_info, pd.DataFrame([row.as_dict() for row in m_list_res])) + mock_show_models_or_versions.assert_called_once_with( + statement_params=mock.ANY, + ) + + def test_log_model_1(self) -> None: + m_model = mock.MagicMock() + m_conda_dependency = mock.MagicMock() + m_sample_input_data = mock.MagicMock() + m_stage_path = "@TEMP.TEST.MODEL/V1" + with mock.patch.object( + self.m_r._model_ops, "validate_existence", return_value=False + ) as mock_validate_existence, mock.patch.object( + self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path + ) as mock_prepare_model_stage_path, mock.patch.object( + model_composer.ModelComposer, "save" + ) as mock_save, mock.patch.object( + self.m_r._model_ops, "create_from_stage" + ) as mock_create_from_stage: + mv = self.m_r.log_model( + model=m_model, + model_name="MODEL", + version_name="v1", + conda_dependencies=m_conda_dependency, + sample_input_data=m_sample_input_data, + ) + mock_validate_existence.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + mock_prepare_model_stage_path.assert_called_once_with( + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + name="MODEL", + model=m_model, + signatures=None, + sample_input=m_sample_input_data, + conda_dependencies=m_conda_dependency, + pip_requirements=None, + python_version=None, + code_paths=None, + 
ext_modules=None, + options=None, + ) + mock_create_from_stage.assert_called_once_with( + composed_model=mock.ANY, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1"), + statement_params=mock.ANY, + ) + self.assertEqual( + mv, + model_version_impl.ModelVersion._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1"), + ), + ) + + def test_log_model_2(self) -> None: + m_model = mock.MagicMock() + m_pip_requirements = mock.MagicMock() + m_signatures = mock.MagicMock() + m_options = mock.MagicMock() + m_stage_path = "@TEMP.TEST.MODEL/V1" + with mock.patch.object(self.m_r._model_ops, "validate_existence", return_value=False), mock.patch.object( + self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path + ) as mock_prepare_model_stage_path, mock.patch.object( + model_composer.ModelComposer, "save" + ) as mock_save, mock.patch.object( + self.m_r._model_ops, "create_from_stage" + ) as mock_create_from_stage: + mv = self.m_r.log_model( + model=m_model, + model_name="MODEL", + version_name="V1", + pip_requirements=m_pip_requirements, + signatures=m_signatures, + options=m_options, + ) + mock_prepare_model_stage_path.assert_called_once_with( + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + name="MODEL", + model=m_model, + signatures=m_signatures, + sample_input=None, + conda_dependencies=None, + pip_requirements=m_pip_requirements, + python_version=None, + code_paths=None, + ext_modules=None, + options=m_options, + ) + mock_create_from_stage.assert_called_once_with( + composed_model=mock.ANY, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + self.assertEqual( + mv, + model_version_impl.ModelVersion._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + ), + ) + + def test_log_model_3(self) -> None: + m_model = mock.MagicMock() + m_python_version = mock.MagicMock() + m_code_paths = mock.MagicMock() + m_ext_modules = mock.MagicMock() + m_stage_path = "@TEMP.TEST.MODEL/V1" + with mock.patch.object(self.m_r._model_ops, "validate_existence", return_value=False), mock.patch.object( + self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path + ) as mock_prepare_model_stage_path, mock.patch.object( + model_composer.ModelComposer, "save" + ) as mock_save, mock.patch.object( + self.m_r._model_ops, "create_from_stage" + ) as mock_create_from_stage: + mv = self.m_r.log_model( + model=m_model, + model_name="MODEL", + version_name="V1", + python_version=m_python_version, + code_paths=m_code_paths, + ext_modules=m_ext_modules, + ) + mock_prepare_model_stage_path.assert_called_once_with( + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + name="MODEL", + model=m_model, + signatures=None, + sample_input=None, + conda_dependencies=None, + pip_requirements=None, + python_version=m_python_version, + code_paths=m_code_paths, + ext_modules=m_ext_modules, + options=None, + ) + mock_create_from_stage.assert_called_once_with( + composed_model=mock.ANY, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + self.assertEqual( + mv, + model_version_impl.ModelVersion._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + 
version_name=sql_identifier.SqlIdentifier("V1"), + ), + ) + + def test_log_model_4(self) -> None: + m_model = mock.MagicMock() + m_stage_path = "@TEMP.TEST.MODEL/V1" + with mock.patch.object(self.m_r._model_ops, "validate_existence", return_value=False), mock.patch.object( + self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path + ) as mock_prepare_model_stage_path, mock.patch.object( + model_composer.ModelComposer, "save" + ) as mock_save, mock.patch.object( + self.m_r._model_ops, "create_from_stage" + ) as mock_create_from_stage, mock.patch.object( + self.m_r._model_ops, "set_comment" + ) as mock_set_comment, mock.patch.object( + self.m_r._model_ops._metadata_ops, "save" + ) as mock_metadata_save: + mv = self.m_r.log_model( + model=m_model, model_name="MODEL", version_name="V1", comment="this is comment", metrics={"a": 1} + ) + mock_prepare_model_stage_path.assert_called_once_with( + statement_params=mock.ANY, + ) + mock_save.assert_called_once_with( + name="MODEL", + model=m_model, + signatures=None, + sample_input=None, + conda_dependencies=None, + pip_requirements=None, + python_version=None, + code_paths=None, + ext_modules=None, + options=None, + ) + mock_create_from_stage.assert_called_once_with( + composed_model=mock.ANY, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + self.assertEqual( + mv, + model_version_impl.ModelVersion._ref( + self.m_r._model_ops, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + ), + ) + mock_set_comment.assert_called_once_with( + comment="this is comment", + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + mock_metadata_save.assert_called_once_with( + {"metrics": {"a": 1}}, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ) + + def test_log_model_5(self) -> None: + m_model = mock.MagicMock() + with mock.patch.object(self.m_r._model_ops, "validate_existence", return_value=True) as mock_validate_existence: + with self.assertRaisesRegex(ValueError, "Model MODEL version V1 already existed."): + self.m_r.log_model(model=m_model, model_name="MODEL", version_name="V1") + mock_validate_existence.assert_has_calls( + [ + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ), + mock.call( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V1"), + statement_params=mock.ANY, + ), + ] + ) + + def test_delete_model(self) -> None: + with mock.patch.object(self.m_r._model_ops, "delete_model_or_version") as mock_delete_model_or_version: + self.m_r.delete_model( + model_name="MODEL", + ) + mock_delete_model_or_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + statement_params=mock.ANY, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py index 0146e52c..ae8a7a18 100644 --- a/snowflake/ml/registry/model_registry.py +++ b/snowflake/ml/registry/model_registry.py @@ -3,6 +3,7 @@ import sys import textwrap import types +import warnings from typing import ( TYPE_CHECKING, Any, @@ -305,6 +306,17 @@ def __init__( schema_name: Desired name of the schema used by this model registry inside the database. 
create_if_not_exists: create model registry if it's not exists already. """ + + warnings.warn( + """ +The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0. +It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`, +except when specifically required. The old model registry will be removed once all its primary functionalities are +fully integrated into the new registry. + """, + DeprecationWarning, + stacklevel=2, + ) if create_if_not_exists: create_model_registry(session=session, database_name=database_name, schema_name=schema_name) diff --git a/snowflake/ml/registry/notebooks/Using MODEL via Registry in Snowflake.ipynb b/snowflake/ml/registry/notebooks/Using MODEL via Registry in Snowflake.ipynb index 895cf8e9..9d2e5757 100644 --- a/snowflake/ml/registry/notebooks/Using MODEL via Registry in Snowflake.ipynb +++ b/snowflake/ml/registry/notebooks/Using MODEL via Registry in Snowflake.ipynb @@ -231,7 +231,7 @@ "metadata": {}, "outputs": [], "source": [ - "remote_prediction = mv.run(test_features, method_name=\"predict\")" + "remote_prediction = mv.run(test_features, function_name=\"predict\")" ] }, { @@ -260,7 +260,7 @@ "metadata": {}, "outputs": [], "source": [ - "mv.list_methods()" + "mv.show_functions()" ] }, { @@ -269,7 +269,7 @@ "metadata": {}, "outputs": [], "source": [ - "remote_prediction_proba = mv.run(test_features, method_name=\"predict_proba\")" + "remote_prediction_proba = mv.run(test_features, function_name=\"predict_proba\")" ] }, { @@ -321,7 +321,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### List models and versions\n" + "### Show and List models and versions\n" ] }, { @@ -330,7 +330,7 @@ "metadata": {}, "outputs": [], "source": [ - "reg.list_models()" + "reg.show_models()" ] }, { @@ -339,7 +339,25 @@ "metadata": {}, "outputs": [], "source": [ - "m.list_versions()" + "reg.models()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.show_versions()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.versions()" ] }, { @@ -444,7 +462,7 @@ "metadata": {}, "outputs": [], "source": [ - "mv.list_metrics()" + "mv.show_metrics()" ] }, { @@ -501,7 +519,7 @@ "metadata": {}, "outputs": [], "source": [ - "reg.list_models()" + "reg.show_models()" ] }, { @@ -637,7 +655,7 @@ "metadata": {}, "outputs": [], "source": [ - "mv.run(kddcup99_sp_df_test, method_name=\"predict\").show()" + "mv.run(kddcup99_sp_df_test, function_name=\"predict\").show()" ] }, { diff --git a/snowflake/ml/registry/package_visibility_test.py b/snowflake/ml/registry/package_visibility_test.py new file mode 100644 index 00000000..8dd064c9 --- /dev/null +++ b/snowflake/ml/registry/package_visibility_test.py @@ -0,0 +1,21 @@ +from types import ModuleType + +from absl.testing import absltest + +from snowflake.ml import registry +from snowflake.ml.registry import artifact, model_registry + + +class PackageVisibilityTest(absltest.TestCase): + """Ensure that the functions in this package are visible externally.""" + + def test_class_visible(self) -> None: + self.assertIsInstance(registry.Registry, type) + + def test_module_visible(self) -> None: + self.assertIsInstance(model_registry, ModuleType) + self.assertIsInstance(artifact, ModuleType) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/registry/registry.py 
b/snowflake/ml/registry/registry.py index f8031fac..2132d549 100644 --- a/snowflake/ml/registry/registry.py +++ b/snowflake/ml/registry/registry.py @@ -1,12 +1,17 @@ from types import ModuleType -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional + +import pandas as pd from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import model_signature, type_hints as model_types -from snowflake.ml.model._client.model import model_impl, model_version_impl -from snowflake.ml.model._client.ops import model_ops -from snowflake.ml.model._model_composer import model_composer +from snowflake.ml.model import ( + Model, + ModelVersion, + model_signature, + type_hints as model_types, +) +from snowflake.ml.registry._manager import model_manager from snowflake.snowpark import session _TELEMETRY_PROJECT = "MLOps" @@ -21,6 +26,18 @@ def __init__( database_name: Optional[str] = None, schema_name: Optional[str] = None, ) -> None: + """Opens a registry within a pre-created Snowflake schema. + + Args: + session: The Snowpark Session to connect with Snowflake. + database_name: The name of the database. If None, the current database of the session + will be used. Defaults to None. + schema_name: The name of the schema. If None, the current schema of the session + will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None. + + Raises: + ValueError: When there is no specified or active database in the session. + """ if database_name: self._database_name = sql_identifier.SqlIdentifier(database_name) else: @@ -42,12 +59,13 @@ def __init__( else sql_identifier.SqlIdentifier("PUBLIC") ) - self._model_ops = model_ops.ModelOperator( + self._model_manager = model_manager.ModelManager( session, database_name=self._database_name, schema_name=self._schema_name ) @property def location(self) -> str: + """Get the location (database.schema) of the registry.""" return ".".join([self._database_name.identifier(), self._schema_name.identifier()]) @telemetry.send_api_usage_telemetry( @@ -60,6 +78,8 @@ def log_model( *, model_name: str, version_name: str, + comment: Optional[str] = None, + metrics: Optional[Dict[str, Any]] = None, conda_dependencies: Optional[List[str]] = None, pip_requirements: Optional[List[str]] = None, python_version: Optional[str] = None, @@ -68,148 +88,136 @@ def log_model( code_paths: Optional[List[str]] = None, ext_modules: Optional[List[ModuleType]] = None, options: Optional[model_types.ModelSaveOption] = None, - ) -> model_version_impl.ModelVersion: - """Log a model. + ) -> ModelVersion: + """ + Log a model with various parameters and metadata. Args: - model: Model Python object - model_name: A string as name. - version_name: A string as version. model_name and version_name combination must be unique. - signatures: Model data signatures for inputs and output for every target methods. If it is None, + model: Model object of supported types such as Scikit-learn, XGBoost, Snowpark ML, + PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, + Peft-finetuned LLM, or Custom Model. + model_name: Name to identify the model. + version_name: Version identifier for the model. Combination of model_name and version_name must be unique. + comment: Comment associated with the model version. Defaults to None. + metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None. 
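+                For example, ``metrics={"accuracy": 0.96, "f1": 0.92}`` (metric names and values here are
+                purely illustrative).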
+ signatures: Model data signatures for inputs and outputs for various target methods. If it is None, sample_input_data would be used to infer the signatures for those models that cannot automatically - infer the signature. If not None, sample_input should not be specified. Defaults to None. - sample_input_data: Sample input data to infer the model signatures from. If it is None, signatures must be - specified if the model cannot automatically infer the signature. If not None, signatures should not be - specified. Defaults to None. - conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to - specify a dependency. It is a recommended way to specify your dependencies using conda. When channel is - not specified, Snowflake Anaconda Channel will be used. - pip_requirements: List of Pip package specs. - python_version: A string of python version where model is run. Used for user override. If specified as None, - current version would be captured. Defaults to None. - code_paths: Directory of code to import. - ext_modules: External modules that user might want to get pickled with model object. Defaults to None. - options: Model specific kwargs. + infer the signature. If not None, sample_input_data should not be specified. Defaults to None. + sample_input_data: Sample input data to infer model signatures from. Defaults to None. + conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax + to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel + is not specified, Snowflake Anaconda Channel will be used. Defaults to None. + pip_requirements: List of Pip package specifications. Defaults to None. + python_version: Python version in which the model is run. Defaults to None. + code_paths: List of directories containing code to import. Defaults to None. + ext_modules: List of external modules to pickle with the model object. + Only supported when logging the following types of model: + Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None. + options (Dict[str, Any], optional): Additional model saving options. + + Model Saving Options include: + + - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder. + Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda + Channel. Otherwise, defaults to False + - method_options: Per-method saving options including: + - case_sensitive: Indicates whether the method and its signature should be case sensitive. + This means when you refer the method in the SQL, you need to double quote it. + This will be helpful if you need case to tell apart your methods or features, or you have + non-alphabetic characters in your method or feature name. Defaults to False. + - max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse. + Defaults to None, determined automatically by Snowflake. Returns: - A ModelVersion object corresponding to the model just get logged. + ModelVersion: ModelVersion object corresponding to the model just logged. 
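+
+        Example (illustrative sketch; ``session``, the model object, and the names below are
+        placeholders, not part of the API):
+            >>> from snowflake.ml.registry import Registry
+            >>> reg = Registry(session, database_name="ML", schema_name="REGISTRY")
+            >>> mv = reg.log_model(
+            ...     my_sklearn_model,
+            ...     model_name="MY_MODEL",
+            ...     version_name="V1",
+            ...     comment="first version",
+            ...     metrics={"accuracy": 0.96},
+            ...     conda_dependencies=["scikit-learn"],
+            ... )
+            >>> mv.run(test_df, function_name="predict")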
""" statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) - model_name_id = sql_identifier.SqlIdentifier(model_name) - - version_name_id = sql_identifier.SqlIdentifier(version_name) - - stage_path = self._model_ops.prepare_model_stage_path( - statement_params=statement_params, - ) - - mc = model_composer.ModelComposer(self._model_ops._session, stage_path=stage_path) - mc.save( - name=model_name_id.resolved(), + return self._model_manager.log_model( model=model, - signatures=signatures, - sample_input=sample_input_data, + model_name=model_name, + version_name=version_name, + comment=comment, + metrics=metrics, conda_dependencies=conda_dependencies, pip_requirements=pip_requirements, python_version=python_version, + signatures=signatures, + sample_input_data=sample_input_data, code_paths=code_paths, ext_modules=ext_modules, options=options, - ) - self._model_ops.create_from_stage( - composed_model=mc, - model_name=model_name_id, - version_name=version_name_id, statement_params=statement_params, ) - return model_version_impl.ModelVersion._ref( - self._model_ops, - model_name=model_name_id, - version_name=version_name_id, - ) - @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) - def get_model(self, model_name: str) -> model_impl.Model: - """Get the model object. + def get_model(self, model_name: str) -> Model: + """Get the model object by its name. Args: - model_name: The model name. - - Raises: - ValueError: Raised when the model requested does not exist. + model_name: The name of the model. Returns: - The model object. + The corresponding model object. """ - model_name_id = sql_identifier.SqlIdentifier(model_name) - statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) - if self._model_ops.validate_existence( - model_name=model_name_id, - statement_params=statement_params, - ): - return model_impl.Model._ref( - self._model_ops, - model_name=model_name_id, - ) - else: - raise ValueError(f"Unable to find model {model_name}") + return self._model_manager.get_model(model_name=model_name, statement_params=statement_params) @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) - def list_models(self) -> List[model_impl.Model]: - """List all models in the schema where the registry is opened. + def models(self) -> List[Model]: + """Get all models in the schema where the registry is opened. Returns: - A List of Model= object representing all models in the schema where the registry is opened. + A list of Model objects representing all models in the opened registry. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) - model_names = self._model_ops.list_models_or_versions( - statement_params=statement_params, + return self._model_manager.models(statement_params=statement_params) + + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, + ) + def show_models(self) -> pd.DataFrame: + """Show information of all models in the schema where the registry is opened. + + Returns: + A Pandas DataFrame containing information of all models in the schema. 
+ """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_MODEL_TELEMETRY_SUBPROJECT, ) - return [ - model_impl.Model._ref( - self._model_ops, - model_name=model_name, - ) - for model_name in model_names - ] + return self._model_manager.show_models(statement_params=statement_params) @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) def delete_model(self, model_name: str) -> None: - """Delete the model. + """ + Delete the model by its name. Args: - model_name: The model name, can be fully qualified one. - If not, use database name and schema name of the registry. + model_name: The name of the model to be deleted. """ - model_name_id = sql_identifier.SqlIdentifier(model_name) - statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) - self._model_ops.delete_model_or_version( - model_name=model_name_id, - statement_params=statement_params, - ) + self._model_manager.delete_model(model_name=model_name, statement_params=statement_params) diff --git a/snowflake/ml/registry/registry_test.py b/snowflake/ml/registry/registry_test.py index b59e27f9..ef42d1ed 100644 --- a/snowflake/ml/registry/registry_test.py +++ b/snowflake/ml/registry/registry_test.py @@ -3,9 +3,6 @@ from absl.testing import absltest -from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model._client.model import model_impl, model_version_impl -from snowflake.ml.model._model_composer import model_composer from snowflake.ml.registry import registry from snowflake.ml.test_utils import mock_session from snowflake.snowpark import Session @@ -85,211 +82,86 @@ def setUp(self) -> None: self.c_session = cast(Session, self.m_session) self.m_r = registry.Registry(self.c_session, database_name="TEMP", schema_name="TEST") - def test_get_model_1(self) -> None: - m_model = model_impl.Model._ref( - self.m_r._model_ops, - model_name=sql_identifier.SqlIdentifier("MODEL"), - ) - with mock.patch.object(self.m_r._model_ops, "validate_existence", return_value=True) as mock_validate_existence: - m = self.m_r.get_model("MODEL") - self.assertEqual(m, m_model) - mock_validate_existence.assert_called_once_with( - model_name=sql_identifier.SqlIdentifier("MODEL"), + def test_get_model(self) -> None: + with mock.patch.object(self.m_r._model_manager, "get_model", return_value=True) as mock_get_model: + self.m_r.get_model("MODEL") + mock_get_model.assert_called_once_with( + model_name="MODEL", statement_params=mock.ANY, ) - def test_get_model_2(self) -> None: + def test_models(self) -> None: with mock.patch.object( - self.m_r._model_ops, "validate_existence", return_value=False - ) as mock_validate_existence: - with self.assertRaisesRegex(ValueError, "Unable to find model MODEL"): - self.m_r.get_model("MODEL") - mock_validate_existence.assert_called_once_with( - model_name=sql_identifier.SqlIdentifier("MODEL"), + self.m_r._model_manager, + "models", + ) as mock_show_models: + self.m_r.models() + mock_show_models.assert_called_once_with( statement_params=mock.ANY, ) - def test_list_models(self) -> None: - m_model_1 = model_impl.Model._ref( - self.m_r._model_ops, - model_name=sql_identifier.SqlIdentifier("MODEL"), - ) - m_model_2 = model_impl.Model._ref( - self.m_r._model_ops, - model_name=sql_identifier.SqlIdentifier("Model", case_sensitive=True), - ) + def test_show_models(self) -> None: with mock.patch.object( - self.m_r._model_ops, - "list_models_or_versions", - 
return_value=[ - sql_identifier.SqlIdentifier("MODEL"), - sql_identifier.SqlIdentifier("Model", case_sensitive=True), - ], - ) as mock_list_models_or_versions: - m_list = self.m_r.list_models() - self.assertListEqual(m_list, [m_model_1, m_model_2]) - mock_list_models_or_versions.assert_called_once_with( + self.m_r._model_manager, + "show_models", + ) as mock_show_models: + self.m_r.show_models() + mock_show_models.assert_called_once_with( statement_params=mock.ANY, ) - def test_log_model_1(self) -> None: + def test_log_model(self) -> None: m_model = mock.MagicMock() m_conda_dependency = mock.MagicMock() m_sample_input_data = mock.MagicMock() - m_stage_path = "@TEMP.TEST.MODEL/V1" - with mock.patch.object( - self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path - ) as mock_prepare_model_stage_path, mock.patch.object( - model_composer.ModelComposer, "save" - ) as mock_save, mock.patch.object( - self.m_r._model_ops, "create_from_stage" - ) as mock_create_from_stage: - mv = self.m_r.log_model( - model=m_model, - model_name="MODEL", - version_name="v1", - conda_dependencies=m_conda_dependency, - sample_input_data=m_sample_input_data, - ) - mock_prepare_model_stage_path.assert_called_once_with( - statement_params=mock.ANY, - ) - mock_save.assert_called_once_with( - name="MODEL", - model=m_model, - signatures=None, - sample_input=m_sample_input_data, - conda_dependencies=m_conda_dependency, - pip_requirements=None, - python_version=None, - code_paths=None, - ext_modules=None, - options=None, - ) - mock_create_from_stage.assert_called_once_with( - composed_model=mock.ANY, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("v1"), - statement_params=mock.ANY, - ) - self.assertEqual( - mv, - model_version_impl.ModelVersion._ref( - self.m_r._model_ops, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("v1"), - ), - ) - - def test_log_model_2(self) -> None: - m_model = mock.MagicMock() m_pip_requirements = mock.MagicMock() m_signatures = mock.MagicMock() m_options = mock.MagicMock() - m_stage_path = "@TEMP.TEST.MODEL/V1" - with mock.patch.object( - self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path - ) as mock_prepare_model_stage_path, mock.patch.object( - model_composer.ModelComposer, "save" - ) as mock_save, mock.patch.object( - self.m_r._model_ops, "create_from_stage" - ) as mock_create_from_stage: - mv = self.m_r.log_model( - model=m_model, - model_name="MODEL", - version_name="V1", - pip_requirements=m_pip_requirements, - signatures=m_signatures, - options=m_options, - ) - mock_prepare_model_stage_path.assert_called_once_with( - statement_params=mock.ANY, - ) - mock_save.assert_called_once_with( - name="MODEL", - model=m_model, - signatures=m_signatures, - sample_input=None, - conda_dependencies=None, - pip_requirements=m_pip_requirements, - python_version=None, - code_paths=None, - ext_modules=None, - options=m_options, - ) - mock_create_from_stage.assert_called_once_with( - composed_model=mock.ANY, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("V1"), - statement_params=mock.ANY, - ) - self.assertEqual( - mv, - model_version_impl.ModelVersion._ref( - self.m_r._model_ops, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("V1"), - ), - ) - - def test_log_model_3(self) -> None: - m_model = mock.MagicMock() m_python_version = mock.MagicMock() m_code_paths = 
mock.MagicMock() m_ext_modules = mock.MagicMock() - m_stage_path = "@TEMP.TEST.MODEL/V1" - with mock.patch.object( - self.m_r._model_ops, "prepare_model_stage_path", return_value=m_stage_path - ) as mock_prepare_model_stage_path, mock.patch.object( - model_composer.ModelComposer, "save" - ) as mock_save, mock.patch.object( - self.m_r._model_ops, "create_from_stage" - ) as mock_create_from_stage: - mv = self.m_r.log_model( + m_comment = mock.MagicMock() + m_metrics = mock.MagicMock() + with mock.patch.object(self.m_r._model_manager, "log_model") as mock_log_model: + self.m_r.log_model( model=m_model, model_name="MODEL", - version_name="V1", + version_name="v1", + comment=m_comment, + metrics=m_metrics, + conda_dependencies=m_conda_dependency, + pip_requirements=m_pip_requirements, python_version=m_python_version, + signatures=m_signatures, + sample_input_data=m_sample_input_data, code_paths=m_code_paths, ext_modules=m_ext_modules, + options=m_options, ) - mock_prepare_model_stage_path.assert_called_once_with( - statement_params=mock.ANY, - ) - mock_save.assert_called_once_with( - name="MODEL", + mock_log_model.assert_called_once_with( model=m_model, - signatures=None, - sample_input=None, - conda_dependencies=None, - pip_requirements=None, + model_name="MODEL", + version_name="v1", + comment=m_comment, + metrics=m_metrics, + conda_dependencies=m_conda_dependency, + pip_requirements=m_pip_requirements, python_version=m_python_version, + signatures=m_signatures, + sample_input_data=m_sample_input_data, code_paths=m_code_paths, ext_modules=m_ext_modules, - options=None, - ) - mock_create_from_stage.assert_called_once_with( - composed_model=mock.ANY, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("V1"), + options=m_options, statement_params=mock.ANY, ) - self.assertEqual( - mv, - model_version_impl.ModelVersion._ref( - self.m_r._model_ops, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("V1"), - ), - ) def test_delete_model(self) -> None: - with mock.patch.object(self.m_r._model_ops, "delete_model_or_version") as mock_delete_model_or_version: + with mock.patch.object(self.m_r._model_manager, "delete_model") as mock_delete_model: self.m_r.delete_model( model_name="MODEL", ) - mock_delete_model_or_version.assert_called_once_with( - model_name=sql_identifier.SqlIdentifier("MODEL"), + mock_delete_model.assert_called_once_with( + model_name="MODEL", statement_params=mock.ANY, ) diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index d6dfd85a..793bbe93 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. 
-VERSION = "1.1.2" +VERSION = "1.2.0" diff --git a/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py b/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py index 35813d87..487a8576 100644 --- a/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py +++ b/tests/integ/snowflake/ml/_internal/env_utils_integ_test.py @@ -14,33 +14,32 @@ def tearDown(self) -> None: self._session.close() def test_validate_requirement_in_snowflake_conda_channel(self) -> None: - res = env_utils.validate_requirements_in_information_schema( + res = env_utils.get_matched_package_versions_in_information_schema( session=self._session, reqs=[requirements.Requirement("xgboost")], python_version=snowml_env.PYTHON_VERSION ) - self.assertNotEmpty(res) + self.assertNotEmpty(res["xgboost"]) - res = env_utils.validate_requirements_in_information_schema( + res = env_utils.get_matched_package_versions_in_information_schema( session=self._session, reqs=[requirements.Requirement("xgboost"), requirements.Requirement("pytorch")], python_version=snowml_env.PYTHON_VERSION, ) - self.assertNotEmpty(res) - - self.assertIsNone( - env_utils.validate_requirements_in_information_schema( - session=self._session, - reqs=[requirements.Requirement("xgboost==1.0.*")], - python_version=snowml_env.PYTHON_VERSION, - ) + self.assertNotEmpty(res["xgboost"]) + self.assertNotEmpty(res["pytorch"]) + + res = env_utils.get_matched_package_versions_in_information_schema( + session=self._session, + reqs=[requirements.Requirement("xgboost==1.0.*")], + python_version=snowml_env.PYTHON_VERSION, ) + self.assertEmpty(res["xgboost"]) - self.assertIsNone( - env_utils.validate_requirements_in_information_schema( - session=self._session, - reqs=[requirements.Requirement("python-package")], - python_version=snowml_env.PYTHON_VERSION, - ) + res = env_utils.get_matched_package_versions_in_information_schema( + session=self._session, + reqs=[requirements.Requirement("python-package")], + python_version=snowml_env.PYTHON_VERSION, ) + self.assertNotIn("python-package", res) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel index 9dfdf6d3..d34395ef 100644 --- a/tests/integ/snowflake/ml/extra_tests/BUILD.bazel +++ b/tests/integ/snowflake/ml/extra_tests/BUILD.bazel @@ -138,6 +138,18 @@ py_test( ], ) +py_test( + name = "fit_transform_test", + srcs = ["fit_transform_test.py"], + shard_count = 3, + deps = [ + "//snowflake/ml/modeling/manifold:mds", + "//snowflake/ml/modeling/manifold:spectral_embedding", + "//snowflake/ml/modeling/manifold:tsne", + "//snowflake/ml/utils:connection_params", + ], +) + py_test( name = "decimal_type_test", srcs = ["decimal_type_test.py"], @@ -146,3 +158,13 @@ py_test( "//snowflake/ml/utils:connection_params", ], ) + +py_test( + name = "xgboost_external_memory_training_test", + srcs = ["xgboost_external_memory_training_test.py"], + deps = [ + "//snowflake/ml/modeling/metrics:classification", + "//snowflake/ml/modeling/xgboost:xgb_classifier", + "//snowflake/ml/utils:connection_params", + ], +) diff --git a/tests/integ/snowflake/ml/extra_tests/fit_transform_test.py b/tests/integ/snowflake/ml/extra_tests/fit_transform_test.py new file mode 100644 index 00000000..d29a611c --- /dev/null +++ b/tests/integ/snowflake/ml/extra_tests/fit_transform_test.py @@ -0,0 +1,73 @@ +import numpy as np +import pandas as pd +from absl.testing.absltest import TestCase, main +from sklearn.datasets import load_digits +from sklearn.manifold import ( + MDS as SKMDS, + 
TSNE as SKTSNE, + SpectralEmbedding as SKSpectralEmbedding, +) + +from snowflake.ml.modeling.manifold import MDS, TSNE, SpectralEmbedding +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +from snowflake.snowpark import Session + + +class FitTransformTest(TestCase): + def _load_data(self): + X, _ = load_digits(return_X_y=True) + self._input_df_pandas = pd.DataFrame(X)[:100] + self._input_df_pandas.columns = [str(c) for c in self._input_df_pandas.columns] + self._input_df = self._session.create_dataframe(self._input_df_pandas) + self._input_cols = self._input_df.columns + self._output_cols = [str(c) for c in range(100)] + + def setUp(self): + """Creates Snowpark and Snowflake environments for testing.""" + self._session = Session.builder.configs(SnowflakeLoginOptions()).create() + self._load_data() + + def tearDown(self): + self._session.close() + + def testMDS(self): + sk_embedding = SKMDS(n_components=2, normalized_stress="auto", random_state=2024) + + embedding = MDS( + input_cols=self._input_cols, + output_cols=self._output_cols, + n_components=2, + normalized_stress="auto", + random_state=2024, + ) + sk_X_transformed = sk_embedding.fit_transform(self._input_df_pandas) + X_transformed = embedding.fit_transform(self._input_df) + np.testing.assert_allclose(sk_X_transformed, X_transformed, rtol=1.0e-1, atol=1.0e-2) + + def testSpectralEmbedding(self): + sk_embedding = SKSpectralEmbedding(n_components=2, random_state=2024) + sk_X_transformed = sk_embedding.fit_transform(self._input_df_pandas) + + embedding = SpectralEmbedding( + input_cols=self._input_cols, output_cols=self._output_cols, n_components=2, random_state=2024 + ) + X_transformed = embedding.fit_transform(self._input_df) + np.testing.assert_allclose(sk_X_transformed, X_transformed, rtol=1.0e-1, atol=1.0e-2) + + def testTSNE(self): + sk_embedding = SKTSNE(n_components=2, random_state=2024, n_jobs=1) + sk_X_transformed = sk_embedding.fit_transform(self._input_df_pandas) + + embedding = TSNE( + input_cols=self._input_cols, + output_cols=self._output_cols, + n_components=2, + random_state=2024, + n_jobs=1, + ) + X_transformed = embedding.fit_transform(self._input_df) + np.testing.assert_allclose(sk_X_transformed.shape, X_transformed.shape, rtol=1.0e-1, atol=1.0e-2) + + +if __name__ == "__main__": + main() diff --git a/tests/integ/snowflake/ml/extra_tests/xgboost_external_memory_training_test.py b/tests/integ/snowflake/ml/extra_tests/xgboost_external_memory_training_test.py new file mode 100644 index 00000000..1ab9100d --- /dev/null +++ b/tests/integ/snowflake/ml/extra_tests/xgboost_external_memory_training_test.py @@ -0,0 +1,81 @@ +import numpy as np +from absl.testing.absltest import TestCase, main +from sklearn.metrics import accuracy_score as sk_accuracy_score +from xgboost import XGBClassifier as NativeXGBClassifier + +from snowflake.ml.modeling.xgboost import XGBClassifier +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +from snowflake.snowpark import Session, functions as F + +categorical_columns = [ + "AGE", + "CAMPAIGN", + "CONTACT", + "DAY_OF_WEEK", + "EDUCATION", + "HOUSING", + "JOB", + "LOAN", + "MARITAL", + "MONTH", + "POUTCOME", + "DEFAULT", +] +numerical_columns = [ + "CONS_CONF_IDX", + "CONS_PRICE_IDX", + "DURATION", + "EMP_VAR_RATE", + "EURIBOR3M", + "NR_EMPLOYED", + "PDAYS", + "PREVIOUS", +] +label_column = ["LABEL"] +feature_cols = categorical_columns + numerical_columns + ["ROW_INDEX"] + + +class XGBoostExternalMemoryTrainingTest(TestCase): + def setUp(self): + 
"""Creates Snowpark and Snowflake environments for testing.""" + self._session = Session.builder.configs(SnowflakeLoginOptions()).create() + + def tearDown(self): + self._session.close() + + def test_fit_and_compare_results(self) -> None: + input_df = ( + self._session.sql( + """SELECT *, IFF(Y = 'yes', 1.0, 0.0) as LABEL + FROM ML_DATASETS.PUBLIC.UCI_BANK_MARKETING_20COLUMNS""" + ) + .drop("Y") + .withColumn("ROW_INDEX", F.monotonically_increasing_id()) + ) + pd_df = input_df.to_pandas().sort_values(by=["ROW_INDEX"])[numerical_columns + ["ROW_INDEX", "LABEL"]] + sp_df = self._session.create_dataframe(pd_df) + + sk_reg = NativeXGBClassifier(random_state=0) + sk_reg.fit(pd_df[numerical_columns], pd_df["LABEL"]) + sk_result = sk_reg.predict(pd_df[numerical_columns]) + + sk_accuracy = sk_accuracy_score(pd_df["LABEL"], sk_result) + + reg = XGBClassifier( + random_state=0, + input_cols=numerical_columns, + label_cols=label_column, + use_external_memory_version=True, + batch_size=10000, + ) + reg.fit(sp_df) + result = reg.predict(sp_df) + + result_pd = result.to_pandas().sort_values(by="ROW_INDEX")[["LABEL", "OUTPUT_LABEL"]] + accuracy = sk_accuracy_score(result_pd["LABEL"], result_pd["OUTPUT_LABEL"]) + + np.testing.assert_allclose(sk_accuracy, accuracy, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + main() diff --git a/tests/integ/snowflake/ml/image_builds/BUILD.bazel b/tests/integ/snowflake/ml/image_builds/BUILD.bazel index bb31ad2e..8e3e15f9 100644 --- a/tests/integ/snowflake/ml/image_builds/BUILD.bazel +++ b/tests/integ/snowflake/ml/image_builds/BUILD.bazel @@ -5,9 +5,9 @@ py_test( timeout = "long", srcs = ["image_registry_client_integ_test.py"], deps = [ + "//snowflake/ml/_internal/container_services/image_registry:registry_client", "//snowflake/ml/_internal/utils:identifier", "//snowflake/ml/_internal/utils:query_result_checker", - "//snowflake/ml/model/_deploy_client/utils:image_registry_client", "//snowflake/ml/model/_deploy_client/utils:snowservice_client", "//tests/integ/snowflake/ml/test_utils:spcs_integ_test_base", ], diff --git a/tests/integ/snowflake/ml/image_builds/image_registry_client_integ_test.py b/tests/integ/snowflake/ml/image_builds/image_registry_client_integ_test.py index af70f7a7..e502c894 100644 --- a/tests/integ/snowflake/ml/image_builds/image_registry_client_integ_test.py +++ b/tests/integ/snowflake/ml/image_builds/image_registry_client_integ_test.py @@ -1,10 +1,10 @@ from absl.testing import absltest -from snowflake.ml._internal.utils import identifier, query_result_checker -from snowflake.ml.model._deploy_client.utils import ( - image_registry_client, - snowservice_client, +from snowflake.ml._internal.container_services.image_registry import ( + registry_client as image_registry_client, ) +from snowflake.ml._internal.utils import identifier, query_result_checker +from snowflake.ml.model._deploy_client.utils import snowservice_client from tests.integ.snowflake.ml.test_utils import spcs_integ_test_base diff --git a/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel index d5bfd8b7..f3410613 100644 --- a/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/model/_client/model/BUILD.bazel @@ -4,8 +4,9 @@ py_test( name = "model_impl_integ_test", timeout = "long", srcs = ["model_impl_integ_test.py"], - shard_count = 6, deps = [ + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/registry", 
"//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/test_utils:db_manager", @@ -18,8 +19,8 @@ py_test( name = "model_version_impl_integ_test", timeout = "long", srcs = ["model_version_impl_integ_test.py"], - shard_count = 6, deps = [ + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model/_client/model:model_version_impl", "//snowflake/ml/registry", diff --git a/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py index 48760ac5..5256c1e0 100644 --- a/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py +++ b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py @@ -4,6 +4,7 @@ from absl.testing import absltest, parameterized from packaging import version +from snowflake.ml._internal.utils import identifier, snowflake_env from snowflake.ml.registry import registry from snowflake.ml.utils import connection_params from snowflake.snowpark import Session @@ -18,6 +19,14 @@ VERSION_NAME2 = "V2" +@unittest.skipUnless( + test_env_utils.get_current_snowflake_version() >= version.parse("8.0.0"), + "New model only available when the Snowflake Version is newer than 8.0.0", +) +@unittest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "New model only available in AWS", +) class TestModelImplInteg(parameterized.TestCase): @classmethod def setUpClass(self) -> None: @@ -37,11 +46,6 @@ def setUpClass(self) -> None: } ).create() - current_sf_version = test_env_utils.get_current_snowflake_version(self._session) - - if current_sf_version < version.parse("8.0.0"): - raise unittest.SkipTest("This test requires Snowflake Version 8.0.0 or higher.") - self._db_manager = db_manager.DBManager(self._session) self._db_manager.create_database(self._test_db) self._db_manager.create_schema(self._test_schema) @@ -63,11 +67,21 @@ def setUpClass(self) -> None: ) self._model = self.registry.get_model(model_name=MODEL_NAME) + self._tag_name1 = "MYTAG" + self._tag_name2 = '"live_version"' + + self._session.sql(f"CREATE TAG {self._tag_name1}").collect() + self._session.sql(f"CREATE TAG {self._tag_name2}").collect() + @classmethod def tearDownClass(self) -> None: self._db_manager.drop_database(self._test_db) self._session.close() + def test_versions(self) -> None: + self.assertEqual(self._model.versions(), [self._mv, self._mv2]) + self.assertLen(self._model.show_versions(), 2) + def test_description(self) -> None: description = "test description" self._model.description = description @@ -79,6 +93,41 @@ def test_default(self) -> None: self._model.default = VERSION_NAME2 self.assertEqual(self._model.default.version_name, VERSION_NAME2) + @unittest.skipUnless( + test_env_utils.get_current_snowflake_version() >= version.parse("8.2.0"), + "TAG on model only available when the Snowflake Version is newer than 8.2.0", + ) + def test_tag(self) -> None: + fq_tag_name1 = identifier.get_schema_level_object_identifier(self._test_db, self._test_schema, self._tag_name1) + fq_tag_name2 = identifier.get_schema_level_object_identifier(self._test_db, self._test_schema, self._tag_name2) + self.assertDictEqual({}, self._model.show_tags()) + self.assertIsNone(self._model.get_tag(self._tag_name1)) + self._model.set_tag(self._tag_name1, "val1") + self.assertEqual( + "val1", + self._model.get_tag(fq_tag_name1), + ) + self.assertDictEqual( + {fq_tag_name1: "val1"}, + self._model.show_tags(), + 
) + self._model.set_tag(fq_tag_name2, "v2") + self.assertEqual("v2", self._model.get_tag(self._tag_name2)) + self.assertDictEqual( + { + fq_tag_name1: "val1", + fq_tag_name2: "v2", + }, + self._model.show_tags(), + ) + self._model.unset_tag(fq_tag_name2) + self.assertDictEqual( + {fq_tag_name1: "val1"}, + self._model.show_tags(), + ) + self._model.unset_tag(self._tag_name1) + self.assertDictEqual({}, self._model.show_tags()) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py index ca7b367d..0ac15c80 100644 --- a/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py +++ b/tests/integ/snowflake/ml/model/_client/model/model_version_impl_integ_test.py @@ -4,6 +4,7 @@ from absl.testing import absltest, parameterized from packaging import version +from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.registry import registry from snowflake.ml.utils import connection_params from snowflake.snowpark import Session @@ -17,6 +18,14 @@ VERSION_NAME = "V1" +@unittest.skipUnless( + test_env_utils.get_current_snowflake_version() >= version.parse("8.0.0"), + "New model only available when the Snowflake Version is newer than 8.0.0", +) +@unittest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "New model only available in AWS", +) class TestModelVersionImplInteg(parameterized.TestCase): @classmethod def setUpClass(self) -> None: @@ -36,11 +45,6 @@ def setUpClass(self) -> None: } ).create() - current_sf_version = test_env_utils.get_current_snowflake_version(self._session) - - if current_sf_version < version.parse("8.0.0"): - raise unittest.SkipTest("This test requires Snowflake Version 8.0.0 or higher.") - self._db_manager = db_manager.DBManager(self._session) self._db_manager.create_database(self._test_db) self._db_manager.create_schema(self._test_schema) @@ -72,11 +76,11 @@ def test_metrics(self) -> None: self._mv.set_metric(k, v) self.assertEqual(self._mv.get_metric("a"), expected_metrics["a"]) - self.assertDictEqual(self._mv.list_metrics(), expected_metrics) + self.assertDictEqual(self._mv.show_metrics(), expected_metrics) expected_metrics.pop("b") self._mv.delete_metric("b") - self.assertDictEqual(self._mv.list_metrics(), expected_metrics) + self.assertDictEqual(self._mv.show_metrics(), expected_metrics) with self.assertRaises(KeyError): self._mv.get_metric("b") diff --git a/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel b/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel index 2d1bcbee..e80541e7 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel +++ b/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel @@ -45,3 +45,15 @@ py_test( "//snowflake/ml/utils:connection_params", ], ) + +py_test( + name = "check_output_hpo_integ_test", + timeout = "long", + srcs = ["check_output_hpo_integ_test.py"], + shard_count = 5, + deps = [ + "//snowflake/ml/modeling/linear_model:linear_regression", + "//snowflake/ml/modeling/model_selection:grid_search_cv", + "//snowflake/ml/utils:connection_params", + ], +) diff --git a/tests/integ/snowflake/ml/modeling/model_selection/check_output_hpo_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/check_output_hpo_integ_test.py new file mode 100644 index 00000000..95326f58 --- /dev/null +++ 
b/tests/integ/snowflake/ml/modeling/model_selection/check_output_hpo_integ_test.py @@ -0,0 +1,243 @@ +""" +The main purpose of this file is to use Linear Regression, +to match all kinds of input and output for GridSearchCV/RandomSearchCV. +""" + +from typing import Any, Dict, List, Tuple, Union +from unittest import mock + +import inflection +import numpy as np +import numpy.typing as npt +import pandas as pd +from absl.testing import absltest, parameterized +from sklearn.datasets import load_iris +from sklearn.linear_model import LinearRegression as SkLinearRegression +from sklearn.model_selection import GridSearchCV as SkGridSearchCV, KFold +from sklearn.model_selection._split import BaseCrossValidator + +from snowflake.ml.modeling.linear_model import ( # type: ignore[attr-defined] + LinearRegression, +) +from snowflake.ml.modeling.model_selection import ( # type: ignore[attr-defined] + GridSearchCV, +) +from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +from snowflake.snowpark import Session + + +def _load_iris_data() -> Tuple[pd.DataFrame, List[str], List[str]]: + input_df_pandas = load_iris(as_frame=True).frame + input_df_pandas.columns = [inflection.parameterize(c, "_").upper() for c in input_df_pandas.columns] + input_df_pandas["INDEX"] = input_df_pandas.reset_index().index + + input_cols = [c for c in input_df_pandas.columns if not c.startswith("TARGET")] + label_col = [c for c in input_df_pandas.columns if c.startswith("TARGET")] + + return input_df_pandas, input_cols, label_col + + +class GridSearchCVTest(parameterized.TestCase): + def setUp(self) -> None: + """Creates Snowpark and Snowflake environments for testing.""" + self._session = Session.builder.configs(SnowflakeLoginOptions()).create() + + pd_data, input_col, label_col = _load_iris_data() + self._input_df_pandas = pd_data + self._input_cols = input_col + self._label_col = label_col + self._input_df = self._session.create_dataframe(self._input_df_pandas) + + def tearDown(self) -> None: + self._session.close() + + def _compare_cv_results(self, cv_result_1: Dict[str, Any], cv_result_2: Dict[str, Any]) -> None: + # compare the keys + self.assertEqual(cv_result_1.keys(), cv_result_2.keys()) + # compare the values + for k, v in cv_result_1.items(): + if isinstance(v, np.ndarray): + if k.startswith("param_"): # compare the masked array + np.ma.allequal(v, cv_result_2[k]) # type: ignore[no-untyped-call] + elif k == "params": # compare the parameter combination + self.assertEqual(v.tolist(), cv_result_2[k]) + elif k.endswith("test_score"): # compare the test score + np.testing.assert_allclose(v, cv_result_2[k], rtol=1.0e-7, atol=1.0e-7) + # Do not compare the fit time + + def _compare_global_variables(self, sk_obj: SkLinearRegression, sklearn_reg: SkLinearRegression) -> None: + # the result of SnowML grid search cv should behave the same as sklearn's + # TODO - check scorer_ + assert isinstance(sk_obj.refit_time_, float) + np.testing.assert_allclose(sk_obj.best_score_, sklearn_reg.best_score_) + self.assertEqual(sk_obj.multimetric_, sklearn_reg.multimetric_) + self.assertEqual(sk_obj.best_index_, sklearn_reg.best_index_) + if hasattr(sk_obj, "n_splits_"): # n_splits_ is only available in RandomSearchCV + self.assertEqual(sk_obj.n_splits_, sklearn_reg.n_splits_) + if hasattr(sk_obj, "best_estimator_"): + for variable_name in sk_obj.best_estimator_.__dict__.keys(): + if variable_name != "n_jobs": + if isinstance(getattr(sk_obj.best_estimator_, variable_name), np.ndarray): + if getattr(sk_obj.best_estimator_, 
variable_name).dtype == "object": + self.assertEqual( + getattr(sk_obj.best_estimator_, variable_name).tolist(), + getattr(sklearn_reg.best_estimator_, variable_name).tolist(), + ) + else: + np.testing.assert_allclose( + getattr(sk_obj.best_estimator_, variable_name), + getattr(sklearn_reg.best_estimator_, variable_name), + rtol=1.0e-7, + atol=1.0e-7, + ) + else: + np.testing.assert_allclose( + getattr(sk_obj.best_estimator_, variable_name), + getattr(sklearn_reg.best_estimator_, variable_name), + rtol=1.0e-7, + atol=1.0e-7, + ) + self.assertEqual(sk_obj.n_features_in_, sklearn_reg.n_features_in_) + if hasattr(sk_obj, "feature_names_in_") and hasattr( + sklearn_reg, "feature_names_in_" + ): # feature_names_in_ variable is only available when `best_estimator_` is defined + self.assertEqual(sk_obj.feature_names_in_.tolist(), sklearn_reg.feature_names_in_.tolist()) + if hasattr(sk_obj, "classes_"): + self.assertEqual(sk_obj.classes_, sklearn_reg.classes_) + self._compare_cv_results(sk_obj.cv_results_, sklearn_reg.cv_results_) + if not sk_obj.multimetric_: + self.assertEqual(sk_obj.best_params_, sklearn_reg.best_params_) + + @parameterized.parameters( # type: ignore[misc] + # Standard Sklearn sample + { + "is_single_node": False, + "params": {"copy_X": [True, False], "fit_intercept": [True, False]}, + "cv": 5, + "kwargs": dict(), + }, + # param_grid: list of dictionary + { + "is_single_node": False, + "params": [ + {"copy_X": [True], "fit_intercept": [True, False]}, + {"copy_X": [False], "fit_intercept": [True, False]}, + ], + "cv": 5, + "kwargs": dict(), + }, + # cv: CV splitter + { + "is_single_node": False, + "params": [ + {"copy_X": [True], "fit_intercept": [True, False]}, + {"copy_X": [False], "fit_intercept": [True, False]}, + ], + "cv": KFold(5), + "kwargs": dict(), + }, + # cv: iterator + { + "is_single_node": False, + "params": [ + {"copy_X": [True], "fit_intercept": [True, False]}, + {"copy_X": [False], "fit_intercept": [True, False]}, + ], + "cv": [ + ( + np.array([i for i in range(30, 150)]), + np.array([i for i in range(30)]), + ), + ( + np.array([i for i in range(30)] + [i for i in range(60, 150)]), + np.array([i for i in range(30, 60)]), + ), + ( + np.array([i for i in range(60)] + [i for i in range(90, 150)]), + np.array([i for i in range(60, 90)]), + ), + ( + np.array([i for i in range(90)] + [i for i in range(120, 150)]), + np.array([i for i in range(90, 120)]), + ), + ( + np.array([i for i in range(120)]), + np.array([i for i in range(120, 150)]), + ), + ], + "kwargs": dict(), + }, + { + "is_single_node": False, + "params": [ + {"copy_X": [True], "fit_intercept": [True, False]}, + {"copy_X": [False], "fit_intercept": [True, False]}, + ], + "cv": [ + ( + [i for i in range(30, 150)], + [i for i in range(30)], + ), + ( + [i for i in range(30)] + [i for i in range(60, 150)], + [i for i in range(30, 60)], + ), + ( + [i for i in range(60)] + [i for i in range(90, 150)], + [i for i in range(60, 90)], + ), + ( + [i for i in range(90)] + [i for i in range(120, 150)], + [i for i in range(90, 120)], + ), + ( + [i for i in range(120)], + [i for i in range(120, 150)], + ), + ], + "kwargs": dict(), + }, + # TODO: scoring + { + "is_single_node": False, + "params": {"copy_X": [True, False], "fit_intercept": [True, False]}, + "cv": 5, + "kwargs": dict(scoring=["accuracy", "f1_macro"], refit="f1_macro", return_train_score=True), + }, + # TODO: refit + # TODO: error_score + # return_train_score: True + { + "is_single_node": False, + "params": {"copy_X": [True, False], "fit_intercept": [True, 
False]}, + "cv": 5, + "kwargs": dict(return_train_score=True), + }, + ) + @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") + def test_fit_and_compare_results( + self, + mock_is_single_node: mock.MagicMock, + is_single_node: bool, + params: Union[Dict[str, Any], List[Dict[str, Any]]], + cv: Union[int, BaseCrossValidator, List[Tuple[Union[List[int], npt.NDArray[np.int_]]]]], + kwargs: Dict[str, Any], + ) -> None: + mock_is_single_node.return_value = is_single_node + + reg = GridSearchCV(estimator=LinearRegression(), param_grid=params, cv=cv, **kwargs) + sklearn_reg = SkGridSearchCV(estimator=SkLinearRegression(), param_grid=params, cv=cv, **kwargs) + reg.set_input_cols(self._input_cols) + output_cols = ["OUTPUT_" + c for c in self._label_col] + reg.set_output_cols(output_cols) + reg.set_label_cols(self._label_col) + + reg.fit(self._input_df) + sklearn_reg.fit(X=self._input_df_pandas[self._input_cols], y=self._input_df_pandas[self._label_col].squeeze()) + sk_obj = reg.to_sklearn() + + self._compare_global_variables(sk_obj, sklearn_reg) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py index c152748a..7465acab 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py @@ -97,6 +97,14 @@ def test_fit_and_compare_results(self, mock_is_single_node) -> None: "kwargs": dict(), "estimator_kwargs": dict(random_state=0), }, + { + "is_single_node": False, + "skmodel": SkRandomForestClassifier, + "model": RandomForestClassifier, + "params": {"n_estimators": [50, 200], "min_samples_split": [1.0, 2, 3], "max_depth": [3, 8]}, + "kwargs": dict(return_train_score=True), + "estimator_kwargs": dict(random_state=0), + }, { "is_single_node": False, "skmodel": SkSVC, @@ -105,6 +113,14 @@ def test_fit_and_compare_results(self, mock_is_single_node) -> None: "kwargs": dict(), "estimator_kwargs": dict(random_state=0), }, + { + "is_single_node": False, + "skmodel": SkSVC, + "model": SVC, + "params": {"kernel": ("linear", "rbf"), "C": [1, 10, 80]}, + "kwargs": dict(return_train_score=True), + "estimator_kwargs": dict(random_state=0), + }, { "is_single_node": False, "skmodel": SkXGBClassifier, @@ -113,6 +129,14 @@ def test_fit_and_compare_results(self, mock_is_single_node) -> None: "kwargs": dict(scoring=["accuracy", "f1_macro"], refit="f1_macro"), "estimator_kwargs": dict(seed=42), }, + { + "is_single_node": False, + "skmodel": SkXGBClassifier, + "model": XGBClassifier, + "params": {"max_depth": [2, 6], "learning_rate": [0.1, 0.01]}, + "kwargs": dict(scoring=["accuracy", "f1_macro"], refit="f1_macro", return_train_score=True), + "estimator_kwargs": dict(seed=42), + }, ) @mock.patch("snowflake.ml.modeling._internal.model_trainer_builder.is_single_node") def test_fit_and_compare_results_distributed( diff --git a/tests/integ/snowflake/ml/modeling/preprocessing/min_max_scaler_test.py b/tests/integ/snowflake/ml/modeling/preprocessing/min_max_scaler_test.py index ec4ca781..c200dd11 100644 --- a/tests/integ/snowflake/ml/modeling/preprocessing/min_max_scaler_test.py +++ b/tests/integ/snowflake/ml/modeling/preprocessing/min_max_scaler_test.py @@ -19,6 +19,7 @@ from snowflake.snowpark import Session from tests.integ.snowflake.ml.modeling.framework import utils as framework_utils from 
tests.integ.snowflake.ml.modeling.framework.utils import ( + CATEGORICAL_COLS, DATA, DATA_CLIP, ID_COL, @@ -42,6 +43,22 @@ def tearDown(self) -> None: if os.path.exists(filepath): os.remove(filepath) + def test_fit_non_numeric_raises_exception(self) -> None: + """ + Fitting scaler with non-numeric columns should raise an exception.. + + Raises + ------ + AssertionError + If the expected exception is not raised. + """ + input_cols = CATEGORICAL_COLS + _, df = framework_utils.get_df(self._session, DATA, SCHEMA, np.nan) + + scaler = MinMaxScaler().set_input_cols(input_cols) + with self.assertRaises(TypeError): + scaler.fit(df) + def test_fit(self) -> None: """ Verify fitted states. diff --git a/tests/integ/snowflake/ml/registry/model/BUILD.bazel b/tests/integ/snowflake/ml/registry/model/BUILD.bazel index 4dd54706..727e7374 100644 --- a/tests/integ/snowflake/ml/registry/model/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/model/BUILD.bazel @@ -7,6 +7,7 @@ py_library( testonly = True, srcs = ["registry_model_test_base.py"], deps = [ + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/model:type_hints", "//snowflake/ml/registry", "//snowflake/ml/utils:connection_params", diff --git a/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py index 27311c9f..13a85a09 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py +++ b/tests/integ/snowflake/ml/registry/model/registry_model_test_base.py @@ -6,6 +6,7 @@ from absl.testing import absltest from packaging import version +from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.model import type_hints as model_types from snowflake.ml.registry import registry from snowflake.ml.utils import connection_params @@ -13,6 +14,14 @@ from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils +@unittest.skipUnless( + test_env_utils.get_current_snowflake_version() >= version.parse("8.0.0"), + "New model only available when the Snowflake Version is newer than 8.0.0", +) +@unittest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "New model only available in AWS", +) class RegistryModelTestBase(absltest.TestCase): def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" @@ -31,11 +40,6 @@ def setUp(self) -> None: } ).create() - current_sf_version = test_env_utils.get_current_snowflake_version(self._session) - - if current_sf_version < version.parse("8.0.0"): - raise unittest.SkipTest("This test requires Snowflake Version 8.0.0 or higher.") - self._db_manager = db_manager.DBManager(self._session) self._db_manager.create_database(self._test_db) self._db_manager.create_schema(self._test_schema) @@ -73,13 +77,15 @@ def _test_registry_model( ) for target_method, (test_input, check_func) in prediction_assert_fns.items(): - res = mv.run(test_input, method_name=target_method) + res = mv.run(test_input, function_name=target_method) check_func(res) + self.registry.show_models() + self.registry.delete_model(model_name=name) - self.assertNotIn(mv.model_name, [m.name for m in self.registry.list_models()]) + self.assertNotIn(mv.model_name, [m.name for m in self.registry.models()]) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py index 48398502..20af6c7f 100644 --- 
a/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_tensorflow_model_test.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd +import pytest import tensorflow as tf from absl.testing import absltest @@ -25,6 +26,7 @@ def __call__(self, tensor: tf.Tensor) -> tf.Tensor: return self.a_variable * tensor + self.non_trainable_variable +@pytest.mark.pip_incompatible class TestRegistryTensorflowModelInteg(registry_model_test_base.RegistryModelTestBase): def test_tf_tensor_as_sample( self, diff --git a/tests/integ/snowflake/ml/registry/model_registry_compat_test.py b/tests/integ/snowflake/ml/registry/model_registry_compat_test.py index 781ae871..7a762c0d 100644 --- a/tests/integ/snowflake/ml/registry/model_registry_compat_test.py +++ b/tests/integ/snowflake/ml/registry/model_registry_compat_test.py @@ -69,6 +69,8 @@ def prepare_registry_and_log_model(session: session.Session, registry_name: str, registry = model_registry.ModelRegistry(session=session, database_name=registry_name) iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True) + # Normalize the column name to avoid set it as case_sensitive where there was a BCR in 1.1.2 + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] # LogisticRegression is for classfication task, such as iris regr = linear_model.LogisticRegression() regr.fit(iris_X, iris_y) @@ -101,6 +103,7 @@ def test_log_model_compat(self, permanent: bool) -> None: deployment_name=deployment_name, target_method="predict", permanent=permanent ) iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True) + iris_X.columns = [s.replace(" (CM)", "").replace(" ", "") for s in iris_X.columns.str.upper()] model_ref.predict(deployment_name, iris_X) diff --git a/tests/integ/snowflake/ml/test_utils/BUILD.bazel b/tests/integ/snowflake/ml/test_utils/BUILD.bazel index e55d4da4..2e8ca7d8 100644 --- a/tests/integ/snowflake/ml/test_utils/BUILD.bazel +++ b/tests/integ/snowflake/ml/test_utils/BUILD.bazel @@ -47,9 +47,10 @@ py_library( ], deps = [ ":_snowml_requirements", + ":test_env_utils", + "//snowflake/ml/_internal:env", "//snowflake/ml/_internal:env_utils", "//snowflake/ml/_internal:file_utils", - "//snowflake/ml/utils:connection_params", ], ) @@ -75,6 +76,8 @@ py_library( "//snowflake/ml/_internal:env", "//snowflake/ml/_internal:env_utils", "//snowflake/ml/_internal/utils:query_result_checker", + "//snowflake/ml/_internal/utils:snowflake_env", + "//snowflake/ml/utils:connection_params", ], ) @@ -83,7 +86,9 @@ py_library( testonly = True, srcs = ["spcs_integ_test_base.py"], deps = [ + ":test_env_utils", "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:snowflake_env", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/test_utils:db_manager", ], diff --git a/tests/integ/snowflake/ml/test_utils/common_test_base.py b/tests/integ/snowflake/ml/test_utils/common_test_base.py index 5695c2b2..11724c09 100644 --- a/tests/integ/snowflake/ml/test_utils/common_test_base.py +++ b/tests/integ/snowflake/ml/test_utils/common_test_base.py @@ -9,11 +9,10 @@ from packaging import requirements from typing_extensions import Concatenate, ParamSpec -from snowflake.ml._internal import env_utils, file_utils -from snowflake.ml.utils import connection_params +from snowflake.ml._internal import env, env_utils, file_utils from snowflake.snowpark import functions as F, session from snowflake.snowpark._internal import udf_utils, 
utils as snowpark_utils -from tests.integ.snowflake.ml.test_utils import _snowml_requirements +from tests.integ.snowflake.ml.test_utils import _snowml_requirements, test_env_utils _V = TypeVar("_V", bound="CommonTestBase") _T_args = ParamSpec("_T_args") @@ -40,11 +39,7 @@ def get_function_body(func: Callable[..., Any]) -> str: class CommonTestBase(parameterized.TestCase): def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" - self.session = ( - session._get_active_session() - if snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call] # - else session.Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() - ) + self.session = test_env_utils.get_available_session() def tearDown(self) -> None: if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] @@ -242,10 +237,12 @@ def {func_name}({first_arg_name}: snowflake.snowpark.Session, {", ".join(arg_lis actual_method(self, *args, **kwargs) additional_cases = [ - {"_snowml_pkg_ver": pkg_ver} - for pkg_ver in env_utils.get_matched_package_versions_in_snowflake_conda_channel( - req=requirements.Requirement(f"snowflake-ml-python{version_range}") - ) + {"_snowml_pkg_ver": str(pkg_ver)} + for pkg_ver in env_utils.get_matched_package_versions_in_information_schema( + test_env_utils.get_available_session(), + [requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}{version_range}")], + python_version=env.PYTHON_VERSION, + )[env_utils.SNOWPARK_ML_PKG_NAME] ] modified_test_cases = [{**t1, **t2} for t1 in test_cases for t2 in additional_cases] diff --git a/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py b/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py index 6d714034..256aa67a 100644 --- a/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py +++ b/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py @@ -1,27 +1,25 @@ +import unittest import uuid -from unittest import SkipTest from absl.testing import absltest +from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.utils import connection_params from snowflake.snowpark import Session -from tests.integ.snowflake.ml.test_utils import db_manager +from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils +@unittest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "SPCS only available in AWS", +) class SpcsIntegTestBase(absltest.TestCase): - _SNOWSERVICE_CONNECTION_NAME = "regtest" _TEST_CPU_COMPUTE_POOL = "REGTEST_INFERENCE_CPU_POOL" _TEST_GPU_COMPUTE_POOL = "REGTEST_INFERENCE_GPU_POOL" def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" - try: - login_options = connection_params.SnowflakeLoginOptions(connection_name=self._SNOWSERVICE_CONNECTION_NAME) - except KeyError: - raise SkipTest( - "SnowService connection parameters not present: skipping " - "TestModelRegistryIntegWithSnowServiceDeployment." 
- ) + login_options = connection_params.SnowflakeLoginOptions() self._run_id = uuid.uuid4().hex[:2] self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() diff --git a/tests/integ/snowflake/ml/test_utils/test_env_utils.py b/tests/integ/snowflake/ml/test_utils/test_env_utils.py index f84c241a..74a066f1 100644 --- a/tests/integ/snowflake/ml/test_utils/test_env_utils.py +++ b/tests/integ/snowflake/ml/test_utils/test_env_utils.py @@ -1,70 +1,44 @@ import functools -import textwrap -from typing import List from packaging import requirements, version -import snowflake.connector from snowflake.ml._internal import env, env_utils -from snowflake.ml._internal.utils import query_result_checker +from snowflake.ml._internal.utils import snowflake_env +from snowflake.ml.utils import connection_params from snowflake.snowpark import session +from snowflake.snowpark._internal import utils as snowpark_utils -def get_current_snowflake_version(session: session.Session) -> version.Version: - res = session.sql("SELECT CURRENT_VERSION() AS CURRENT_VERSION").collect()[0] - version_str = res.CURRENT_VERSION - assert isinstance(version_str, str) - - version_str = "+".join(version_str.split()) - return version.parse(version_str) +def get_available_session() -> session.Session: + return ( + session._get_active_session() + if snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call] # + else session.Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() + ) @functools.lru_cache -def get_package_versions_in_server( - session: session.Session, - package_req_str: str, - python_version: str = env.PYTHON_VERSION, -) -> List[version.Version]: - package_req = requirements.Requirement(package_req_str) - parsed_python_version = version.Version(python_version) - sql = textwrap.dedent( - f""" - SELECT PACKAGE_NAME, VERSION - FROM information_schema.packages - WHERE package_name = '{package_req.name}' - AND language = 'python' - AND runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}'; - """ - ) +def get_current_snowflake_version() -> version.Version: + return snowflake_env.get_current_snowflake_version(get_available_session()) - version_list = [] - try: - result = ( - query_result_checker.SqlResultValidator( - session=session, - query=sql, - ) - .has_column("VERSION") - .has_dimensions(expected_rows=None, expected_cols=2) - .validate() - ) - for row in result: - req_ver = version.parse(row["VERSION"]) - version_list.append(req_ver) - except snowflake.connector.DataError: - return [] - available_version_list = list(package_req.specifier.filter(version_list)) - return available_version_list + +@functools.lru_cache +def get_current_snowflake_cloud_type() -> snowflake_env.SnowflakeCloudType: + sess = get_available_session() + region = snowflake_env.get_regions(sess)[snowflake_env.get_current_region_id(sess)] + return region["cloud"] @functools.lru_cache def get_latest_package_version_spec_in_server( - session: session.Session, + sess: session.Session, package_req_str: str, python_version: str = env.PYTHON_VERSION, ) -> str: package_req = requirements.Requirement(package_req_str) - available_version_list = get_package_versions_in_server(session, package_req_str, python_version) + available_version_list = env_utils.get_matched_package_versions_in_information_schema( + sess, [package_req], python_version + ).get(package_req.name, []) if len(available_version_list) == 0: return str(package_req) return 
f"{package_req.name}=={max(available_version_list)}" @@ -74,7 +48,7 @@ def get_latest_package_version_spec_in_server( def get_latest_package_version_spec_in_conda(package_req_str: str, python_version: str = env.PYTHON_VERSION) -> str: package_req = requirements.Requirement(package_req_str) available_version_list = env_utils.get_matched_package_versions_in_snowflake_conda_channel( - req=requirements.Requirement(package_req_str), python_version=python_version + package_req, python_version=python_version ) if len(available_version_list) == 0: return str(package_req) diff --git a/third_party/rules_mypy/BUILD.bazel b/third_party/rules_mypy/BUILD.bazel index 6f60a504..eda5ef8e 100644 --- a/third_party/rules_mypy/BUILD.bazel +++ b/third_party/rules_mypy/BUILD.bazel @@ -2,10 +2,9 @@ load("@rules_python//python:defs.bzl", "py_binary") package(default_visibility = ["//visibility:public"]) -exports_files(["mypy.sh.tpl"]) - py_binary( name = "mypy", srcs = ["main.py"], + legacy_create_init = 0, main = "main.py", ) diff --git a/third_party/rules_mypy/main.py b/third_party/rules_mypy/main.py index 262f860b..5f23779a 100644 --- a/third_party/rules_mypy/main.py +++ b/third_party/rules_mypy/main.py @@ -1,77 +1,11 @@ -import argparse -import json -import subprocess -import sys -import tempfile - -MYPY_ENTRYPOINT_CODE = """ import sys try: from mypy.main import main except ImportError as e: raise ImportError( - f"Unable to import mypy. Make sure mypy is added to the bazel conda environment. Actual error: {{e}}" + f"Unable to import mypy. Make sure mypy is added to the bazel conda environment. Actual error: {e}" ) if __name__ == "__main__": main(stdout=sys.stdout, stderr=sys.stderr) - -""" - - -def mypy_checker() -> None: - # To parse the arguments that bazel provides. - parser = argparse.ArgumentParser( - # Without this, the second path documented in main below fails. - fromfile_prefix_chars="@" - ) - parser.add_argument("--out") - parser.add_argument("--persistent_worker", action="store_true") - - args = parser.parse_args() - - with tempfile.NamedTemporaryFile(suffix=".py") as mypy_entrypoint: - mypy_entrypoint.write(MYPY_ENTRYPOINT_CODE.encode()) - mypy_entrypoint.flush() - first_run = True - while args.persistent_worker or first_run: - data = sys.stdin.readline() - req = json.loads(data) - mypy_args = req["arguments"] - process = subprocess.Popen( - # We use this to make sure we are invoking mypy that is installed in the same environment of the current - # Python. 
- [sys.executable, mypy_entrypoint.name] + mypy_args, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - process.wait() - text, _ = process.communicate() - - if process.returncode: - header = "=" * 20 + " MYPY TYPE CHECKING REPORT BEGIN " + "=" * 20 + "\n" - footer = "=" * 20 + " MYPY TYPE CHECKING REPORT END " + "=" * 20 + "\n" - - message = "".join([header, text.decode(), footer]) - else: - message = "" - - with open(args.out, "w") as output: - output.write(message) - sys.stderr.flush() - sys.stdout.write( - json.dumps( - { - "exitCode": process.returncode, - "output": message, - "requestId": req.get("requestId", 0), - } - ) - ) - sys.stdout.flush() - first_run = False - - -if __name__ == "__main__": - mypy_checker() diff --git a/third_party/rules_mypy/mypy.bzl b/third_party/rules_mypy/mypy.bzl index eea954b2..6cd64ea3 100644 --- a/third_party/rules_mypy/mypy.bzl +++ b/third_party/rules_mypy/mypy.bzl @@ -1,5 +1,7 @@ "Public API" +load("@bazel_skylib//lib:sets.bzl", "sets") +load("@bazel_skylib//lib:shell.bzl", "shell") load("@rules_mypy//:rules.bzl", "MyPyStubsInfo") MyPyAspectInfo = provider( @@ -33,8 +35,29 @@ DEFAULT_ATTRS = { default = Label("@//:mypy.ini"), allow_single_file = True, ), + "_template": attr.label( + default = Label("@rules_mypy//templates:mypy.sh.tpl"), + allow_single_file = True, + ), } +def _sources_to_cache_map_triples(srcs, is_aspect): + triples_as_flat_list = [] + for f in srcs: + if is_aspect: + f_path = f.path + else: + # "The path of this file relative to its root. This excludes the aforementioned root, i.e. configuration-specific fragments of the path. + # This is also the path under which the file is mapped if it's in the runfiles of a binary." + # - https://docs.bazel.build/versions/master/skylark/lib/File.html + f_path = f.short_path + triples_as_flat_list.extend([ + shell.quote(f_path), + shell.quote("{}.meta.json".format(f_path)), + shell.quote("{}.data.json".format(f_path)), + ]) + return triples_as_flat_list + def _is_external_dep(dep): return dep.label.workspace_root.startswith("external/") @@ -68,11 +91,28 @@ def _extract_stub_deps(deps): stub_files.append(src_f) return stub_files -def _mypy_rule_impl(ctx): - base_rule = ctx.rule +def _extract_imports(imports, label): + # NOTE: Bazel's implementation of this for py_binary, py_test is at + # src/main/java/com/google/devtools/build/lib/bazel/rules/python/BazelPythonSemantics.java + mypypath_parts = [] + for import_ in imports: + if import_.startswith("/"): + # buildifier: disable=print + print("ignoring invalid absolute path '{}'".format(import_)) + elif import_ in ["", "."]: + mypypath_parts.append(label.package) + else: + mypypath_parts.append("{}/{}".format(label.package, import_)) + return mypypath_parts + +def _mypy_rule_impl(ctx, is_aspect = False): + base_rule = ctx + if is_aspect: + base_rule = ctx.rule mypy_config_file = ctx.file._mypy_config + mypypath_parts = [] direct_src_files = [] transitive_srcs_depsets = [] stub_files = [] @@ -84,80 +124,99 @@ def _mypy_rule_impl(ctx): transitive_srcs_depsets = _extract_transitive_deps(base_rule.attr.deps) stub_files = _extract_stub_deps(base_rule.attr.deps) + if hasattr(base_rule.attr, "imports"): + mypypath_parts = _extract_imports(base_rule.attr.imports, ctx.label) + final_srcs_depset = depset(transitive = transitive_srcs_depsets + [depset(direct = direct_src_files)]) src_files = [f for f in final_srcs_depset.to_list() if not _is_external_src(f)] if not src_files: return None - out = ctx.actions.declare_file("%s_dummy_out" % 
ctx.rule.attr.name) - runfiles_name = "%s.mypy_runfiles" % ctx.rule.attr.name + mypypath_parts += [src_f.dirname for src_f in stub_files] + mypypath = ":".join(mypypath_parts) + + # Ideally, a file should be passed into this rule. If this is an executable + # rule, then we default to the implicit executable file, otherwise we create + # a stub. + if not is_aspect: + if hasattr(ctx, "outputs"): + exe = ctx.outputs.executable + else: + exe = ctx.actions.declare_file( + "%s_mypy_exe" % base_rule.attr.name, + ) + out = None + else: + out = ctx.actions.declare_file("%s_dummy_out" % ctx.rule.attr.name) + exe = ctx.actions.declare_file( + "%s_mypy_exe" % ctx.rule.attr.name, + ) # Compose a list of the files needed for use. Note that aspect rules can use # the project version of mypy however, other rules should fall back on their # relative runfiles. + runfiles = ctx.runfiles(files = src_files + stub_files + [mypy_config_file]) + if not is_aspect: + runfiles = runfiles.merge(ctx.attr._mypy_cli.default_runfiles) - src_run_files = [] - direct_src_run_files = [] - stub_run_files = [] - - for f in src_files + stub_files: - run_file_path = runfiles_name + "/" + f.short_path - run_file = ctx.actions.declare_file(run_file_path) - ctx.actions.symlink( - output = run_file, - target_file = f, - ) - if f in src_files: - src_run_files.append(run_file) - if f in direct_src_files: - direct_src_run_files.append(run_file) - if f in stub_files: - stub_run_files.append(run_file) - - src_root_path = src_run_files[0].path - src_root_path = src_root_path[0:(src_root_path.find(runfiles_name) + len(runfiles_name))] - - # arguments sent to mypy - args = ["--cache-dir", ctx.bin_dir.path + "/.mypy_cache", "--package-root", src_root_path, "--config-file", mypy_config_file.path] + [f.path for f in direct_src_run_files] - - worker_arg_file = ctx.actions.declare_file(ctx.rule.attr.name + ".worker_args") - ctx.actions.write( - output = worker_arg_file, - content = "\n".join(args), + src_root_paths = sets.to_list( + sets.make([f.root.path for f in src_files]), ) - return MyPyAspectInfo( - exe = ctx.executable._mypy_cli, - args = worker_arg_file, - runfiles = src_run_files + stub_run_files + [mypy_config_file, worker_arg_file], - out = out, + ctx.actions.expand_template( + template = ctx.file._template, + output = exe, + substitutions = { + "{CACHE_MAP_TRIPLES}": " ".join(_sources_to_cache_map_triples(src_files, is_aspect)), + "{MYPYPATH_PATH}": mypypath if mypypath else "", + "{MYPY_EXE}": ctx.executable._mypy_cli.path, + "{MYPY_INI_PATH}": mypy_config_file.path, + "{MYPY_ROOT}": ctx.executable._mypy_cli.root.path, + "{OUTPUT}": out.path if out else "", + "{PACKAGE_ROOTS}": " ".join([ + "--package-root " + shell.quote(path or ".") + for path in src_root_paths + ]), + "{SRCS}": " ".join([ + shell.quote(f.path) if is_aspect else shell.quote(f.short_path) + for f in src_files + ]), + "{VERBOSE_BASH}": "set -x" if DEBUG else "", + "{VERBOSE_OPT}": "--verbose" if DEBUG else "", + }, + is_executable = True, ) + if is_aspect: + return [ + DefaultInfo(executable = exe, runfiles = runfiles), + MyPyAspectInfo(exe = exe, out = out), + ] + return DefaultInfo(executable = exe, runfiles = runfiles) + def _mypy_aspect_impl(_, ctx): if (ctx.rule.kind not in ["py_binary", "py_library", "py_test", "mypy_test"] or ctx.label.workspace_root.startswith("external")): return [] - aspect_info = _mypy_rule_impl( + providers = _mypy_rule_impl( ctx, + is_aspect = True, ) - if not aspect_info: + if not providers: return [] + info = providers[0] + 
aspect_info = providers[1] + ctx.actions.run( outputs = [aspect_info.out], - inputs = aspect_info.runfiles, - tools = [aspect_info.exe], + inputs = info.default_runfiles.files, + tools = [ctx.executable._mypy_cli], executable = aspect_info.exe, mnemonic = "MyPy", progress_message = "Type-checking %s" % ctx.label, - execution_requirements = { - "requires-worker-protocol": "json", - "supports-workers": "1", - }, - # out is required for worker to write the output. - arguments = ["--out", aspect_info.out.path, "@" + aspect_info.args.path], use_default_shell_env = True, ) return [ @@ -166,8 +225,21 @@ def _mypy_aspect_impl(_, ctx): ), ] +def _mypy_test_impl(ctx): + info = _mypy_rule_impl(ctx, is_aspect = False) + if not info: + fail("A list of python deps are required for mypy_test") + return info + mypy_aspect = aspect( implementation = _mypy_aspect_impl, attr_aspects = ["deps"], attrs = DEFAULT_ATTRS, ) + +mypy_test = rule( + implementation = _mypy_test_impl, + test = True, + attrs = dict(DEFAULT_ATTRS.items() + + [("deps", attr.label_list(aspects = [mypy_aspect]))]), +) diff --git a/third_party/rules_mypy/templates/BUILD.bazel b/third_party/rules_mypy/templates/BUILD.bazel new file mode 100644 index 00000000..820e9a31 --- /dev/null +++ b/third_party/rules_mypy/templates/BUILD.bazel @@ -0,0 +1 @@ +exports_files(["mypy.sh.tpl"]) diff --git a/third_party/rules_mypy/templates/mypy.sh.tpl b/third_party/rules_mypy/templates/mypy.sh.tpl new file mode 100644 index 00000000..4ba83e73 --- /dev/null +++ b/third_party/rules_mypy/templates/mypy.sh.tpl @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +{VERBOSE_BASH} +set -o errexit +set -o nounset +set -o pipefail + +main() { + local output + local report_file + local status + local root + local mypy + + report_file="{OUTPUT}" + root="{MYPY_ROOT}/" + mypy="{MYPY_EXE}" + + export MYPYPATH="$(pwd):{MYPYPATH_PATH}" + + # Workspace rules run in a different location from aspect rules. Here we + # normalize if the external source isn't found. + if [ ! -f $mypy ]; then + mypy=${mypy#${root}} + fi + + # We need the return code of mypy. + set +o errexit + output=$($mypy {VERBOSE_OPT} --bazel {PACKAGE_ROOTS} --config-file {MYPY_INI_PATH} --cache-map {CACHE_MAP_TRIPLES} -- {SRCS} 2>&1) + status=$? + set -o errexit + + if [ ! -z "$report_file" ]; then + echo "${output}" > "${report_file}" + fi + + if [[ $status -ne 0 ]]; then + echo "${output}" # Show MyPy's error to end-user via Bazel's console logging + exit 1 + fi + +} + +main "$@"