diff --git a/.bazelrc b/.bazelrc
index f4ea3bf0..22d4e20a 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -25,6 +25,7 @@ build:_extended_gpu_oss --platforms //bazel/platforms:extended_conda_gpu_env --h
 build:py3.8 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.8
 build:py3.9 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.9
 build:py3.10 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.10
+build:py3.11 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.11

 # Default
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c8e1429f..4e4da97a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,28 @@
 # Release History

+## 1.2.3
+
+### Bug Fixes
+
+- Registry: Providing a Decimal-type column for a DOUBLE or FLOAT feature no longer errors out; the column is
+  auto-cast with a warning.
+- Registry: Improve the error message shown when the currently unsupported `pip_requirements` argument is specified.
+- Model Development: Fix incorrect `precision_recall_fscore_support` results when `average="samples"`.
+- Model Registry: Fix an issue that caused description, metrics, or tags not to be returned correctly in a newly
+  created Model Registry (PrPr) due to Snowflake BCR [2024_01](
+  https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_01/bcr-1483)
+
+### Behavior Changes
+
+- Feature Store: `FeatureStore.suspend_feature_view` and `FeatureStore.resume_feature_view` no longer mutate the input
+  feature view argument. The updated status is reflected only in the returned feature view object.
+
+### New Features
+
+- Model Development: Support the `score_samples` method for all applicable classes, including Pipeline,
+  GridSearchCV, RandomizedSearchCV, PCA, IsolationForest, ...
+- Registry: Support deleting a version of a model.
+
 ## 1.2.2

 ### Bug Fixes
diff --git a/README.md b/README.md
index af37ef3a..0d75f737 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ If you don't have a Snowflake account yet, you can [sign up for a 30-day free tr
 Follow the [installation instructions](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#installing-snowpark-ml)
 in the Snowflake documentation.

-Python versions 3.8, 3.9 & 3.10 are supported. You can use [miniconda](https://docs.conda.io/en/latest/miniconda.html) or
+Python versions 3.8 to 3.11 are supported. You can use [miniconda](https://docs.conda.io/en/latest/miniconda.html) or
 [anaconda](https://www.anaconda.com/) to create a Conda environment (recommended), or
 [virtualenv](https://docs.python.org/3/tutorial/venv.html) to create a virtual environment.
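To make the "Support deleting a version of a model" changelog entry concrete, here is a minimal sketch of the new flow. It assumes an existing Snowpark session and a previously logged model; `connection_parameters` and the model/version names are placeholders. Only `Model.delete_version` is introduced by this change; the surrounding `Registry` calls are the pre-existing API as assumed here.

```python
from snowflake.snowpark import Session
from snowflake.ml.registry import Registry

# connection_parameters is a placeholder dict (account, user, password, role, ...).
session = Session.builder.configs(connection_parameters).create()

reg = Registry(session=session)
model = reg.get_model("MY_MODEL")   # hypothetical model name
model.delete_version("V1")          # new in 1.2.3: drops only version "V1"
print(model.show_versions())        # the model and its other versions are kept
```

Dropping the whole model is unchanged; `delete_version` maps to the new `ALTER MODEL ... DROP VERSION ...` statement added in `model_version.py` further down in this diff.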
diff --git a/WORKSPACE b/WORKSPACE index 83e0e3ba..78d0bdae 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -4,9 +4,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_jar") http_jar( name = "bazel_diff", - sha256 = "9c4546623a8b9444c06370165ea79a897fcb9881573b18fa5c9ee5c8ba0867e2", + sha256 = "eca2d221f5c3ec9545c841ed62d319bbb59e447a1ceade563bc8f8e1b9186a34", urls = [ - "https://github.com/Tinder/bazel-diff/releases/download/4.3.0/bazel-diff_deploy.jar", + "https://github.com/Tinder/bazel-diff/releases/download/5.0.1/bazel-diff_deploy.jar", ], ) diff --git a/bazel/get_affected_targets.sh b/bazel/get_affected_targets.sh index 166cfaf6..13d78f2b 100755 --- a/bazel/get_affected_targets.sh +++ b/bazel/get_affected_targets.sh @@ -69,38 +69,34 @@ trap 'rm -rf "${working_dir}"' EXIT starting_hashes_json="${working_dir}/starting_hashes.json" final_hashes_json="${working_dir}/final_hashes.json" impacted_targets_path="${working_dir}/impacted_targets.txt" -bazel_diff="${working_dir}/bazel_diff" -seed_file="${working_dir}/bazel_diff_seed" -ci_hash_file="${working_dir}/ci_scripts_hash" - -cat <"${seed_file}" -${ci_hash_file} -${workspace_path}/requirements.yml -SeedFileContent - -"${bazel}" run --config=pre_build :bazel-diff --script_path="${bazel_diff}" +ci_hash_file_pr="${working_dir}/ci_hash_file_pr" +ci_hash_file_base="${working_dir}/ci_hash_file_base" git -C "${workspace_path}" checkout "${pr_revision}" --quiet trap 'git -C "${workspace_path}" checkout "${current_revision}" --quiet' EXIT echo "Generating Hashes for Revision '${pr_revision}'" -git ls-files -s "${workspace_path}/ci" | git hash-object --stdin >"${ci_hash_file}" +git ls-files -s "${workspace_path}/ci" "${workspace_path}/bazel" | git hash-object --stdin >"${ci_hash_file_pr}" -"${bazel_diff}" generate-hashes -w "${workspace_path}" -b "${bazel}" -s "${seed_file}" "${final_hashes_json}" +"${bazel}" run --config=pre_build :bazel-diff -- generate-hashes -w "${workspace_path}" -b "${bazel}" "${final_hashes_json}" MERGE_BASE_MAIN=$(git merge-base "${pr_revision}" main) git -C "${workspace_path}" checkout "${MERGE_BASE_MAIN}" --quiet echo "Generating Hashes for merge base ${MERGE_BASE_MAIN}" -git ls-files -s "${workspace_path}/ci" | git hash-object --stdin >"${ci_hash_file}" +git ls-files -s "${workspace_path}/ci" "${workspace_path}/bazel" | git hash-object --stdin >"${ci_hash_file_base}" -$"${bazel_diff}" generate-hashes -w "${workspace_path}" -b "${bazel}" -s "${seed_file}" "${starting_hashes_json}" +"${bazel}" run --config=pre_build :bazel-diff -- generate-hashes -w "${workspace_path}" -b "${bazel}" "${starting_hashes_json}" git -C "${workspace_path}" checkout "${pr_revision}" --quiet echo "Determining Impacted Targets and output to ${output_path}" -$"${bazel_diff}" get-impacted-targets -sh "${starting_hashes_json}" -fh "${final_hashes_json}" -o "${impacted_targets_path}" +"${bazel}" run --config=pre_build :bazel-diff -- get-impacted-targets -sh "${starting_hashes_json}" -fh "${final_hashes_json}" -o "${impacted_targets_path}" + +if ! cmp -s "$ci_hash_file_pr" "$ci_hash_file_base"; then + echo '//...' 
>> "${impacted_targets_path}" +fi filter_query_rules_file="${working_dir}/filter_query_rules" diff --git a/bazel/requirements/templates/meta.tpl.yaml b/bazel/requirements/templates/meta.tpl.yaml index 2ddac5b8..6eab3374 100644 --- a/bazel/requirements/templates/meta.tpl.yaml +++ b/bazel/requirements/templates/meta.tpl.yaml @@ -13,7 +13,7 @@ requirements: - python - bazel >=6.0.0 run: - - python>=3.8,<3.11 + - python>=3.8,<3.12 run_constrained: - pytorch<2.1.0 # [win] diff --git a/bazel/requirements/templates/pyproject.toml b/bazel/requirements/templates/pyproject.toml index 60ada30c..7ca3a073 100644 --- a/bazel/requirements/templates/pyproject.toml +++ b/bazel/requirements/templates/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Database", "Topic :: Software Development", "Topic :: Software Development :: Libraries", @@ -29,7 +30,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Scientific/Engineering :: Information Analysis" ] -requires-python = ">=3.8, <4" +requires-python = ">=3.8, <3.12" dynamic = ["version", "readme"] [project.urls] diff --git a/ci/build_and_run_tests.sh b/ci/build_and_run_tests.sh index e6fe56ca..53f4d51e 100755 --- a/ci/build_and_run_tests.sh +++ b/ci/build_and_run_tests.sh @@ -40,6 +40,7 @@ WITH_SNOWPARK=false MODE="continuous_run" PYTHON_VERSION=3.8 PYTHON_JENKINS_ENABLE="/opt/rh/rh-python38/enable" +PYTHON_ENABLE_SCRIPT="bin/activate" SNOWML_DIR="snowml" SNOWPARK_DIR="snowpark-python" IS_NT=false @@ -96,21 +97,6 @@ while (($#)); do shift done -case ${PYTHON_VERSION} in - 3.8) - PYTHON_EXECUTABLE="python3.8" - PYTHON_JENKINS_ENABLE="/opt/rh/rh-python38/enable" - ;; - 3.9) - PYTHON_EXECUTABLE="python3.9" - PYTHON_JENKINS_ENABLE="/opt/rh/rh-python39/enable" - ;; - 3.10) - PYTHON_EXECUTABLE="python3.10" - PYTHON_JENKINS_ENABLE="/opt/rh/rh-python310/enable" - ;; -esac - echo "Running build_and_run_tests with PYTHON_VERSION ${PYTHON_VERSION}" EXT="" @@ -148,9 +134,53 @@ case "${PLATFORM}_${ARCH}" in ;; esac -# Verify that the requested python version exists -# TODO(SNOW-845592): ideally we should download python from conda if it's not present. Currently we just fail. 
-if [ "${ENV}" = "pip" ]; then +if [ ${IS_NT} = true ]; then + EXT=".exe" + PYTHON_ENABLE_SCRIPT="Scripts/activate" + BAZEL_ADDITIONAL_BUILD_FLAGS+=(--nobuild_python_zip) + BAZEL_ADDITIONAL_BUILD_FLAGS+=(--enable_runfiles) + BAZEL_ADDITIONAL_BUILD_FLAGS+=(--action_env="USERPROFILE=${USERPROFILE}") + BAZEL_ADDITIONAL_BUILD_FLAGS+=(--host_action_env="USERPROFILE=${USERPROFILE}") + BAZEL_ADDITIONAL_STARTUP_FLAGS+=(--output_user_root=C:/broot) +fi + +case ${PYTHON_VERSION} in + 3.8) + if [ ${IS_NT} = true ]; then + PYTHON_EXECUTABLE="py -3.8" + else + PYTHON_EXECUTABLE="python3.8" + fi + PYTHON_JENKINS_ENABLE="/opt/rh/rh-python38/enable" + ;; + 3.9) + if [ ${IS_NT} = true ]; then + PYTHON_EXECUTABLE="py -3.9" + else + PYTHON_EXECUTABLE="python3.9" + fi + PYTHON_JENKINS_ENABLE="/opt/rh/rh-python39/enable" + ;; + 3.10) + if [ ${IS_NT} = true ]; then + PYTHON_EXECUTABLE="py -3.10" + else + PYTHON_EXECUTABLE="python3.10" + fi + PYTHON_JENKINS_ENABLE="/opt/rh/rh-python310/enable" + ;; + 3.11) + if [ ${IS_NT} = true ]; then + PYTHON_EXECUTABLE="py -3.11" + else + PYTHON_EXECUTABLE="python3.11" + fi + PYTHON_JENKINS_ENABLE="/opt/rh/rh-python311/enable" + ;; +esac + +# TODO(SNOW-901629): Use native python provided in the node once SNOW-1046060 resolved +if [[ "${ENV}" = "pip" && ${IS_NT} = false ]]; then set +eu # shellcheck source=/dev/null source ${PYTHON_JENKINS_ENABLE} @@ -163,13 +193,6 @@ if [ "${ENV}" = "pip" ]; then set -eu fi -if [ ${IS_NT} = true ]; then - EXT=".exe" - BAZEL_ADDITIONAL_BUILD_FLAGS+=(--nobuild_python_zip) - BAZEL_ADDITIONAL_BUILD_FLAGS+=(--enable_runfiles) - BAZEL_ADDITIONAL_STARTUP_FLAGS+=(--output_user_root=D:/broot) -fi - cd "${WORKSPACE}" # Check and download yq if not presented. @@ -210,12 +233,7 @@ pushd "${TEMP_TEST_DIR}" rsync -av --exclude-from "${EXCLUDE_TESTS}" "../${SNOWML_DIR}/tests" . popd -# Bazel on windows is consuming a lot of memory, let's clean it before proceed to avoid OOM. -if [ ${IS_NT} = true ]; then - "${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" clean --expunge - "${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" shutdown -fi - +"${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" clean --expunge popd # Build snowml package @@ -228,10 +246,12 @@ if [ "${ENV}" = "pip" ]; then pushd ${SNOWPARK_DIR} rm -rf venv ${PYTHON_EXECUTABLE} -m venv venv - source venv/bin/activate - ${PYTHON_EXECUTABLE} -m pip install -U pip setuptools wheel + # shellcheck disable=SC1090 + source "venv/${PYTHON_ENABLE_SCRIPT}" + python --version + python -m pip install -U pip setuptools wheel echo "Building snowpark wheel from main:$(git rev-parse HEAD)." - pip wheel . --no-deps + python -m pip wheel . --no-deps cp "$(find . 
-maxdepth 1 -iname 'snowflake_snowpark_python-*.whl')" "${WORKSPACE}" deactivate popd @@ -240,7 +260,7 @@ if [ "${ENV}" = "pip" ]; then # Build SnowML pushd ${SNOWML_DIR} "${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" build "${BAZEL_ADDITIONAL_BUILD_FLAGS[@]+"${BAZEL_ADDITIONAL_BUILD_FLAGS[@]}"}" //:wheel - cp "$(${BAZEL} info bazel-bin)/dist/snowflake_ml_python-${VERSION}-py3-none-any.whl" "${WORKSPACE}" + cp "$("${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" info bazel-bin)/dist/snowflake_ml_python-${VERSION}-py3-none-any.whl" "${WORKSPACE}" popd else # Clean conda cache @@ -284,21 +304,23 @@ if [ "${ENV}" = "pip" ]; then # Create testing env ${PYTHON_EXECUTABLE} -m venv testenv - source testenv/bin/activate + # shellcheck disable=SC1090 + source "testenv/${PYTHON_ENABLE_SCRIPT}" # Install all of the packages in single line, # otherwise it will fail in dependency resolution. - ${PYTHON_EXECUTABLE} -m pip install --upgrade pip - ${PYTHON_EXECUTABLE} -m pip list - ${PYTHON_EXECUTABLE} -m pip install "snowflake_ml_python-${VERSION}-py3-none-any.whl[all]" "pytest-xdist[psutil]==2.5.0" -r "${WORKSPACE}/${SNOWML_DIR}/requirements.txt" --no-cache-dir --force-reinstall + python --version + python -m pip install --upgrade pip + python -m pip list + python -m pip install "snowflake_ml_python-${VERSION}-py3-none-any.whl[all]" "pytest-xdist[psutil]==2.5.0" -r "${WORKSPACE}/${SNOWML_DIR}/requirements.txt" --no-cache-dir --force-reinstall if [ "${WITH_SNOWPARK}" = true ]; then cp "$(find "${WORKSPACE}" -maxdepth 1 -iname 'snowflake_snowpark_python-*.whl')" "${TEMP_TEST_DIR}" - ${PYTHON_EXECUTABLE} -m pip install "$(find . -maxdepth 1 -iname 'snowflake_snowpark_python-*.whl')" --no-deps --force-reinstall + python -m pip install "$(find . -maxdepth 1 -iname 'snowflake_snowpark_python-*.whl')" --no-deps --force-reinstall fi - ${PYTHON_EXECUTABLE} -m pip list + python -m pip list # Run the tests set +e - TEST_SRCDIR="${TEMP_TEST_DIR}" ${PYTHON_EXECUTABLE} -m pytest "${COMMON_PYTEST_FLAG[@]}" -m "not pip_incompatible" tests/integ/ + TEST_SRCDIR="${TEMP_TEST_DIR}" python -m pytest "${COMMON_PYTEST_FLAG[@]}" -m "not pip_incompatible" tests/integ/ TEST_RETCODE=$? 
set -e else diff --git a/ci/conda_recipe/README.md b/ci/conda_recipe/README.md index af412a73..54824f7d 100644 --- a/ci/conda_recipe/README.md +++ b/ci/conda_recipe/README.md @@ -6,7 +6,7 @@ Conda's guide on building a conda package from a wheel: To invoke conda build: ```sh -conda build --prefix-length=0 --python=[3.8|3.9|3.10] ci/conda_recipe +conda build --prefix-length=0 --python=[3.8|3.9|3.10|3.11] ci/conda_recipe ``` - `--prefix-length=0`: prevent the conda build environment from being created in diff --git a/ci/conda_recipe/conda_build_config.yaml b/ci/conda_recipe/conda_build_config.yaml index afb39f11..00ccdcbf 100644 --- a/ci/conda_recipe/conda_build_config.yaml +++ b/ci/conda_recipe/conda_build_config.yaml @@ -3,3 +3,4 @@ python: - 3.8 - 3.9 - 3.10 + - 3.11 diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 20347c94..486baec4 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.2.2 + version: 1.2.3 requirements: build: - python @@ -26,7 +26,7 @@ requirements: - absl-py>=0.15,<2 - aiohttp!=4.0.0a0, !=4.0.0a1 - anyio>=3.5.0,<4 - - cachetools>=3.1.1,<5 + - cachetools>=3.1.1,<6 - cloudpickle>=2.0.0 - fsspec>=2022.11,<2024 - importlib_resources>=5.1.4, <6 @@ -46,7 +46,7 @@ requirements: - sqlparse>=0.4,<1 - typing-extensions>=4.1.0,<5 - xgboost>=1.7.3,<2 - - python>=3.8,<3.11 + - python>=3.8,<3.12 run_constrained: - lightgbm==3.3.5 - mlflow>=2.1.0,<2.4 diff --git a/codegen/sklearn_wrapper_generator.py b/codegen/sklearn_wrapper_generator.py index 637a7317..793aaf37 100644 --- a/codegen/sklearn_wrapper_generator.py +++ b/codegen/sklearn_wrapper_generator.py @@ -583,6 +583,7 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: self.fit_predict_docstring = "" self.fit_transform_docstring = "" self.predict_proba_docstring = "" + self.score_samples_docstring = "" self.score_docstring = "" self.predict_log_proba_docstring = "" self.decision_function_docstring = "" @@ -728,6 +729,7 @@ def _populate_function_doc_fields(self) -> None: "transform", "score", "kneighbors", + "score_samples", ] _CLASS_FUNC = {name: func for name, func in inspect.getmembers(self.class_object[1])} for _each_method in _METHODS: @@ -757,6 +759,7 @@ def _populate_function_doc_fields(self) -> None: self.predict_docstring = self.estimator_function_docstring["predict"] self.fit_predict_docstring = self.estimator_function_docstring["fit_predict"] self.predict_proba_docstring = self.estimator_function_docstring["predict_proba"] + self.score_samples_docstring = self.estimator_function_docstring["score_samples"] self.predict_log_proba_docstring = self.estimator_function_docstring["predict_log_proba"] self.decision_function_docstring = self.estimator_function_docstring["decision_function"] self.score_docstring = self.estimator_function_docstring["score"] diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index dfed46b6..068a7646 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -601,6 +601,38 @@ class {transform.original_class_name}(BaseTransformer): return output_df + @available_if(original_estimator_has_callable("score_samples")) # type: ignore[misc] + @telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + custom_tags=dict([("autogen", True)]), + ) + def score_samples( + self, dataset: Union[DataFrame, pd.DataFrame], 
output_cols_prefix: str = "score_samples_" + ) -> Union[DataFrame, pd.DataFrame]: + """{transform.score_samples_docstring} + output_cols_prefix: Prefix for the response columns + + Returns: + Output dataset with probability of the sample for each class in the model. + """ + super()._check_dataset_type(dataset) + if isinstance(dataset, DataFrame): + output_df = self._batch_inference( + dataset=dataset, + inference_method="score_samples", + expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + expected_output_cols_type="float" + ) + elif isinstance(dataset, pd.DataFrame): + output_df = self._sklearn_inference( + dataset=dataset, + inference_method="score_samples", + expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + ) + + return output_df + @available_if(original_estimator_has_callable("score")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( project=_PROJECT, diff --git a/codegen/transformer_autogen_test_template.py_template b/codegen/transformer_autogen_test_template.py_template index 488aa272..94f8d3a0 100644 --- a/codegen/transformer_autogen_test_template.py_template +++ b/codegen/transformer_autogen_test_template.py_template @@ -183,7 +183,7 @@ class {transform.test_class_name}(TestCase): else: np.testing.assert_allclose(actual_arr, sklearn_numpy_arr, rtol=1.e-1, atol=1.e-2) - expected_methods = ["predict_proba", "predict_log_proba", "decision_function", "kneighbors"] + expected_methods = ["predict_proba", "predict_log_proba", "decision_function", "kneighbors", "score_samples"] for m in expected_methods: assert not ( callable(getattr(sklearn_reg, m, None)) @@ -234,6 +234,14 @@ class {transform.test_class_name}(TestCase): np.testing.assert_allclose( actual_inference_result[:, 0, :], sklearn_inference_result[:, 0, :], rtol=1.e-1, atol=1.e-2 ) + elif ( + m == "score_samples" + and reg.__class__.__name__ == 'BernoulliRBM' + ): + # score_samples is not deterministic for BernoulliRBM: + # it computes a quantity called the free energy on X, + # then on a randomly corrupted version of X, and returns the log of the logistic function of the difference. 
+ assert actual_inference_result.shape == sklearn_inference_result.shape else: np.testing.assert_allclose( actual_inference_result, sklearn_inference_result, rtol=1.e-1, atol=1.e-2 diff --git a/requirements.txt b/requirements.txt index a0a23af1..8b2074ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,7 +45,7 @@ snowflake-snowpark-python==1.8.0 sphinx==5.0.2 sqlparse==0.4.4 starlette==0.27.0 -tensorflow==2.13.1 +tensorflow==2.13.0 tokenizers==0.13.2 toml==0.10.2 torch==2.0.1 diff --git a/requirements.yml b/requirements.yml index 24bd26cc..80eafd96 100644 --- a/requirements.yml +++ b/requirements.yml @@ -255,7 +255,7 @@ version_requirements: '>=0.4,<1' - name: tensorflow dev_version_conda: 2.10.0 - dev_version_pypi: 2.13.1 + dev_version_pypi: 2.13.0 version_requirements: '>=2.9,<3,!=2.12.0' requirements_extra_tags: - tensorflow @@ -306,7 +306,7 @@ from_channel: conda-forge - name: cachetools dev_version: 4.2.2 - version_requirements: '>=3.1.1,<5' + version_requirements: '>=3.1.1,<6' - name: pytimeparse dev_version: 1.1.8 version_requirements: '>=1.1.8,<2' diff --git a/snowflake/ml/feature_store/_internal/scripts/install-snowpark-ml-conda.sh b/snowflake/ml/feature_store/_internal/scripts/install-snowpark-ml-conda.sh index 969a28df..4eb55c8f 100755 --- a/snowflake/ml/feature_store/_internal/scripts/install-snowpark-ml-conda.sh +++ b/snowflake/ml/feature_store/_internal/scripts/install-snowpark-ml-conda.sh @@ -3,7 +3,7 @@ # Setup a conda environment & installs snowpark ML. # # Usage -# install-snowpark-ml-conda.sh [-d ] [-n ] [-p 3.8|3.9|3.10] [-h] +# install-snowpark-ml-conda.sh [-d ] [-n ] [-p ] [-h] set -o pipefail set -eu @@ -36,9 +36,9 @@ DEFAULT_FILENAME=$(dirname "$PROG")/snowflake-ml-python-1.0.12-fs-0.2.0-conda.zi function help() { exitcode=$1 && shift - echo "Usage: ${PROG} [-d ] [-n ] [-p 3.8|3.9|3.10] [-h]" + echo "Usage: ${PROG} [-d ] [-n ] [-p ] [-h]" echo " -d OUTPUT_DIR: Optional, default is ${CHANNEL_HOME}" - echo " -p PY_VERSION: Optional, default is 3.8. Options are 3.9, 3.10." + echo " -p PY_VERSION: Optional, default is 3.8." if [ "${CONDA_DEFAULT_ENV-}" ]; then echo " -n CONDA_ENV_NAME: Optional, default is \`${CONDA_DEFAULT_ENV}\` (current environment). If an existing env provided, it will reuse. It will create otherwise." 
else @@ -59,12 +59,7 @@ while (($#)); do ;; -p) shift - if [[ $1 = "3.8" || $1 = "3.9" || $1 == "3.10" ]]; then - PY_VERSION=$1 - else - echo "Invalid python version: $1" - help 1 - fi + PY_VERSION=$1 ;; -h|--help) help 0 diff --git a/snowflake/ml/feature_store/entity.py b/snowflake/ml/feature_store/entity.py index d88546a6..3140e900 100644 --- a/snowflake/ml/feature_store/entity.py +++ b/snowflake/ml/feature_store/entity.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional from snowflake.ml._internal.utils.sql_identifier import ( SqlIdentifier, @@ -35,7 +35,8 @@ def __init__(self, name: str, join_keys: List[str], desc: str = "") -> None: self.name: SqlIdentifier = SqlIdentifier(name) self.join_keys: List[SqlIdentifier] = to_sql_identifiers(join_keys) - self.desc = desc + self.owner: Optional[str] = None + self.desc: str = desc def _validate(self, name: str, join_keys: List[str]) -> None: if len(name) > _ENTITY_NAME_LENGTH_LIMIT: @@ -62,6 +63,12 @@ def _to_dict(self) -> Dict[str, str]: entity_dict[k] = str(v) return entity_dict + @staticmethod + def _construct_entity(name: str, join_keys: List[str], desc: str, owner: str) -> "Entity": + e = Entity(name, join_keys, desc) + e.owner = owner + return e + def __repr__(self) -> str: states = (f"{k}={v}" for k, v in vars(self).items()) return f"{type(self).__name__}({', '.join(states)})" @@ -70,4 +77,9 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Entity): return False - return self.name == other.name and self.desc == other.desc and self.join_keys == other.join_keys + return ( + self.name == other.name + and self.desc == other.desc + and self.join_keys == other.join_keys + and self.owner == other.owner + ) diff --git a/snowflake/ml/feature_store/feature_store.py b/snowflake/ml/feature_store/feature_store.py index e130b650..2517dce3 100644 --- a/snowflake/ml/feature_store/feature_store.py +++ b/snowflake/ml/feature_store/feature_store.py @@ -87,12 +87,17 @@ def switch_warehouse( @functools.wraps(f) def wrapper(self: FeatureStore, /, *args: _Args.args, **kargs: _Args.kwargs) -> _RT: original_warehouse = self._session.get_current_warehouse() + if original_warehouse is not None: + original_warehouse = SqlIdentifier(original_warehouse) + warehouse_updated = False try: - if self._default_warehouse is not None: + if original_warehouse != self._default_warehouse: self._session.use_warehouse(self._default_warehouse) + warehouse_updated = True return f(self, *args, **kargs) finally: - self._session.use_warehouse(original_warehouse) # type: ignore[arg-type] + if warehouse_updated and original_warehouse is not None: + self._session.use_warehouse(original_warehouse) return wrapper @@ -278,10 +283,15 @@ def register_feature_view( """ Materialize a FeatureView to Snowflake backend. Incremental maintenance for updates on the source data will be automated if refresh_freq is set. - NOTE: Each new materialization will trigger a full FeatureView history refresh for the data included in the FeatureView. + Examples: + ... + draft_fv = FeatureView(name="my_fv", entities=[entities], feature_df) + registered_fv = fs.register_feature_view(feature_view=draft_fv, version="v1") + ... + Args: feature_view: FeatureView instance to materialize. version: version of the registered FeatureView. 
@@ -306,12 +316,13 @@ def register_feature_view( version = FeatureViewVersion(version) if feature_view.status != FeatureViewStatus.DRAFT: - raise snowml_exceptions.SnowflakeMLException( - error_code=error_codes.OBJECT_ALREADY_EXISTS, - original_exception=ValueError( - f"FeatureView {feature_view.name}/{feature_view.version} has already been registered." - ), + warnings.warn( + f"FeatureView {feature_view.name}/{feature_view.version} has already been registered. " + + "Skipping registration.", + stacklevel=2, + category=UserWarning, ) + return feature_view # TODO: ideally we should move this to FeatureView creation time for e in feature_view.entities: @@ -472,8 +483,9 @@ def list_feature_views( fvs = self._find_feature_views(entity_name, feature_view_name) else: fvs = [] + entities = self.list_entities().collect() for row in self._get_fv_backend_representations(feature_view_name, prefix_match=True): - fvs.append(self._compose_feature_view(row)) + fvs.append(self._compose_feature_view(row, entities)) if as_dataframe: result = None @@ -511,7 +523,7 @@ def get_feature_view(self, name: str, version: str) -> FeatureView: original_exception=ValueError(f"Failed to find FeatureView {name}/{version}: {results}"), ) - return self._compose_feature_view(results[0]) + return self._compose_feature_view(results[0], self.list_entities().collect()) @dispatch_decorator(prpr_version="1.0.8") def merge_features( @@ -620,7 +632,7 @@ def resume_feature_view(self, feature_view: FeatureView) -> FeatureView: feature_view: FeatureView to resume. Returns: - FeatureView with updated status. + A new feature view with updated status. Raises: SnowflakeMLException: [ValueError] FeatureView is not in suspended status. @@ -646,7 +658,7 @@ def suspend_feature_view(self, feature_view: FeatureView) -> FeatureView: feature_view: FeatureView to suspend. Returns: - FeatureView with updated status. + A new feature view with updated status. Raises: SnowflakeMLException: [ValueError] FeatureView is not in running status. 
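As a quick illustration of the suspend/resume behavior change documented in the docstring hunks above (and in the 1.2.3 changelog), here is a minimal sketch of the new non-mutating contract; the `fs` feature store handle and the `"MY_FV"`/`"v1"` identifiers are assumptions, not part of this diff.

```python
# fs is assumed to be an already-constructed FeatureStore; names are placeholders.
fv = fs.get_feature_view("MY_FV", "v1")

suspended_fv = fs.suspend_feature_view(fv)
# The input object is no longer mutated; the updated status lives only in the
# returned feature view.
print(fv.status)            # unchanged, still the status it had before the call
print(suspended_fv.status)  # suspended

resumed_fv = fs.resume_feature_view(suspended_fv)
print(resumed_fv.status)    # running again; yet another new FeatureView object
```

Callers that previously relied on the argument being updated in place should switch to the returned object, e.g. `fv = fs.suspend_feature_view(fv)`.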
@@ -712,6 +724,7 @@ def list_entities(self) -> DataFrame: F.col('"name"').substr(prefix_len, _ENTITY_NAME_LENGTH_LIMIT).alias("NAME"), F.col('"allowed_values"').alias("JOIN_KEYS"), F.col('"comment"').alias("DESC"), + F.col('"owner"').alias("OWNER"), ), ) @@ -747,10 +760,12 @@ def get_entity(self, name: str) -> Entity: raw_join_keys = result[0]["JOIN_KEYS"] join_keys = raw_join_keys.strip("[]").split(",") - return Entity( - name=result[0]["NAME"], + + return Entity._construct_entity( + name=SqlIdentifier(result[0]["NAME"], case_sensitive=True).identifier(), join_keys=join_keys, desc=result[0]["DESC"], + owner=result[0]["OWNER"], ) @dispatch_decorator(prpr_version="1.0.8") @@ -1339,7 +1354,7 @@ def join_cols(cols: List[SqlIdentifier], end_comma: bool, rename: bool, prefix: {f_ts_col} {s_ts_col}, {join_keys_str}, {join_cols(s_only_cols, end_comma=True, rename=False, prefix='null AS ')} - {join_cols(f_only_cols,end_comma=False, rename=False)} + {join_cols(f_only_cols, end_comma=False, rename=False)} FROM {f_table_name}""" union_cte = f""" unioned_{layer} AS ( @@ -1422,9 +1437,8 @@ def _update_feature_view_status(self, feature_view: FeatureView, operation: str) original_exception=RuntimeError(f"Failed to update feature view {fully_qualified_name}'s status: {e}"), ) from e - feature_view._status = self.get_feature_view(feature_view.name, feature_view.version).status logger.info(f"Successfully {operation} FeatureView {feature_view.name}/{feature_view.version}.") - return feature_view + return self.get_feature_view(feature_view.name, feature_view.version) def _find_feature_views( self, entity_name: SqlIdentifier, feature_view_name: Optional[SqlIdentifier] @@ -1464,6 +1478,8 @@ def _find_feature_views( error_code=error_codes.INTERNAL_SNOWPARK_ERROR, original_exception=RuntimeError(f"Failed to retrieve feature views' information: {e}"), ) from e + + entities = self.list_entities().collect() outputs = [] for r in results: if entity_name == SqlIdentifier(r["TAG_VALUE"], case_sensitive=True): @@ -1472,14 +1488,25 @@ def _find_feature_views( obj_name = SqlIdentifier(r["OBJECT_NAME"], case_sensitive=True) if feature_view_name is not None: if fv_name == feature_view_name: - outputs.append(self._compose_feature_view(fv_maps[obj_name])) + outputs.append(self._compose_feature_view(fv_maps[obj_name], entities)) else: continue else: - outputs.append(self._compose_feature_view(fv_maps[obj_name])) + outputs.append(self._compose_feature_view(fv_maps[obj_name], entities)) return outputs - def _compose_feature_view(self, row: Row) -> FeatureView: + def _compose_feature_view(self, row: Row, entity_list: List[Row]) -> FeatureView: + def find_and_compose_entity(name: str) -> Entity: + name = SqlIdentifier(name).resolved() + for e in entity_list: + if e["NAME"] == name: + return Entity( + name=SqlIdentifier(e["NAME"], case_sensitive=True).identifier(), + join_keys=e["JOIN_KEYS"].strip("[]").split(","), + desc=e["DESC"], + ) + raise RuntimeError(f"Cannot find entity {name} from retrieved entity list: {entity_list}") + name, version = row["name"].split(_FEATURE_VIEW_NAME_DELIMITER) name = SqlIdentifier(name, case_sensitive=True) m = re.match(_DT_OR_VIEW_QUERY_PATTERN, row["text"]) @@ -1494,7 +1521,7 @@ def _compose_feature_view(self, row: Row) -> FeatureView: df = self._session.sql(query) desc = m.group("comment") entity_names = m.group("entities") - entities = [self.get_entity(n) for n in entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] + entities = [find_and_compose_entity(n) for n in 
entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] ts_col = m.group("ts_col") timestamp_col = ts_col if ts_col != _TIMESTAMP_COL_PLACEHOLDER else None @@ -1515,6 +1542,7 @@ def _compose_feature_view(self, row: Row) -> FeatureView: warehouse=SqlIdentifier(row["warehouse"], case_sensitive=True).identifier(), refresh_mode=row["refresh_mode"], refresh_mode_reason=row["refresh_mode_reason"], + owner=row["owner"], ) return fv else: @@ -1522,7 +1550,7 @@ def _compose_feature_view(self, row: Row) -> FeatureView: df = self._session.sql(query) desc = m.group("comment") entity_names = m.group("entities") - entities = [self.get_entity(n) for n in entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] + entities = [find_and_compose_entity(n) for n in entity_names.split(_FEATURE_VIEW_ENTITY_TAG_DELIMITER)] ts_col = m.group("ts_col") timestamp_col = ts_col if ts_col != _TIMESTAMP_COL_PLACEHOLDER else None @@ -1541,6 +1569,7 @@ def _compose_feature_view(self, row: Row) -> FeatureView: warehouse=None, refresh_mode=None, refresh_mode_reason=None, + owner=row["owner"], ) return fv diff --git a/snowflake/ml/feature_store/feature_view.py b/snowflake/ml/feature_store/feature_view.py index ccceff88..1f4b8b67 100644 --- a/snowflake/ml/feature_store/feature_view.py +++ b/snowflake/ml/feature_store/feature_view.py @@ -138,6 +138,7 @@ def __init__( self._warehouse: Optional[SqlIdentifier] = None self._refresh_mode: Optional[str] = None self._refresh_mode_reason: Optional[str] = None + self._owner: Optional[str] = None self._validate() def slice(self, names: List[str]) -> FeatureViewSlice: @@ -291,6 +292,10 @@ def refresh_mode(self) -> Optional[str]: def refresh_mode_reason(self) -> Optional[str]: return self._refresh_mode_reason + @property + def owner(self) -> Optional[str]: + return self._owner + def _get_query(self) -> str: if len(self._feature_df.queries["queries"]) != 1: raise ValueError( @@ -354,6 +359,7 @@ def __eq__(self, other: object) -> bool: and self.warehouse == other.warehouse and self.refresh_mode == other.refresh_mode and self.refresh_mode_reason == other.refresh_mode_reason + and self._owner == other._owner ) def _to_dict(self) -> Dict[str, str]: @@ -394,9 +400,15 @@ def from_json(cls, json_str: str, session: Session) -> FeatureView: if _FEATURE_OBJ_TYPE not in json_dict or json_dict[_FEATURE_OBJ_TYPE] != cls.__name__: raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}") + entities = [] + for e_json in json_dict["_entities"]: + e = Entity(e_json["name"], e_json["join_keys"], e_json["desc"]) + e.owner = e_json["owner"] + entities.append(e) + return FeatureView._construct_feature_view( name=json_dict["_name"], - entities=[Entity(**e) for e in json_dict["_entities"]], + entities=entities, feature_df=session.sql(json_dict["_query"]), timestamp_col=json_dict["_timestamp_col"], desc=json_dict["_desc"], @@ -409,6 +421,7 @@ def from_json(cls, json_str: str, session: Session) -> FeatureView: warehouse=json_dict["_warehouse"], refresh_mode=json_dict["_refresh_mode"], refresh_mode_reason=json_dict["_refresh_mode_reason"], + owner=json_dict["_owner"], ) @staticmethod @@ -439,6 +452,7 @@ def _construct_feature_view( warehouse: Optional[str], refresh_mode: Optional[str], refresh_mode_reason: Optional[str], + owner: Optional[str], ) -> FeatureView: fv = FeatureView( name=name, @@ -455,5 +469,6 @@ def _construct_feature_view( fv._warehouse = SqlIdentifier(warehouse) if warehouse is not None else None fv._refresh_mode = refresh_mode fv._refresh_mode_reason = refresh_mode_reason + fv._owner = 
owner fv.attach_feature_desc(feature_descs) return fv diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb index 8c4aec7b..4ae6b64e 100644 --- a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.ipynb @@ -5,20 +5,9 @@ "id": "0bb54abc", "metadata": {}, "source": [ - "- snowflake-ml-python version: 1.2.0\n", - "- Feature Store PrPr Version: 0.4.0\n", - "- Updated date: 1/3/2024" - ] - }, - { - "cell_type": "markdown", - "id": "144c09b2", - "metadata": {}, - "source": [ - "## Before getting started\n", - "\n", - "### Watch out object name case sensitivity\n", - "The Model Registry and Feature Store are not consistent with each other in the way they case names for databases, schemas, and other SQL objects. (Keep in mind that the objects in both APIs are Snowflake objects on the back end.) The model registry preserves the case of names for these objects, while the feature store converts names to uppercase unless you enclose them in double quotes. The way the feature store handles names is consistent with Snowflake’s identifier requirements. We are working to make this more consistent. In the meantime, we suggest using uppercase names in both APIs to ensure correct interoperation between the feature store and the model registry." + "- snowflake-ml-python version: 1.2.2\n", + "- Feature Store PrPr Version: 0.5.1\n", + "- Updated date: 2/12/2024" ] }, { diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf index af5b0a63..c96e8ec5 100644 Binary files a/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf and b/snowflake/ml/feature_store/notebooks/customer_demo/Basic_Feature_Demo.pdf differ diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.ipynb b/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.ipynb index 91d4c115..a0fad0cf 100644 --- a/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.ipynb @@ -5,8 +5,9 @@ "id": "5f46aef7-1fc7-408e-acf1-0dc030981c58", "metadata": {}, "source": [ - "- snowflake-ml-python version: 1.2.1\n", - "- Last updated on: 1/30/2024" + "- snowflake-ml-python version: 1.2.2\n", + "- Feature Store PrPr version: 0.5.1\n", + "- Last updated on: 2/12/2024" ] }, { @@ -525,7 +526,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.18" } }, "nbformat": 4, diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.pdf b/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.pdf index 97666219..d1fd0e3a 100644 Binary files a/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.pdf and b/snowflake/ml/feature_store/notebooks/customer_demo/DBT_External_Feature_Pipeline_Demo.pdf differ diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb index 009d2e10..be8b3766 100644 --- 
a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb +++ b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.ipynb @@ -5,20 +5,9 @@ "id": "4f029c96", "metadata": {}, "source": [ - "- snowflake-ml-python version: 1.2.0\n", - "- Feature Store PrPr version: 0.4.0\n", - "- Updated date: 1/3/2024" - ] - }, - { - "cell_type": "markdown", - "id": "5ba51119", - "metadata": {}, - "source": [ - "## Before getting started\n", - "\n", - "### Watch out object name case sensitivity\n", - "The Model Registry and Feature Store are not consistent with each other in the way they case names for databases, schemas, and other SQL objects. (Keep in mind that the objects in both APIs are Snowflake objects on the back end.) The model registry preserves the case of names for these objects, while the feature store converts names to uppercase unless you enclose them in double quotes. The way the feature store handles names is consistent with Snowflake’s identifier requirements. We are working to make this more consistent. In the meantime, we suggest using uppercase names in both APIs to ensure correct interoperation between the feature store and the model registry." + "- snowflake-ml-python version: 1.2.2\n", + "- Feature Store PrPr version: 0.5.1\n", + "- Updated date: 2/12/2024" ] }, { diff --git a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf index 0a0cb99b..5d49da4b 100644 Binary files a/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf and b/snowflake/ml/feature_store/notebooks/customer_demo/Time_Series_Feature_Demo.pdf differ diff --git a/snowflake/ml/model/_client/model/model_impl.py b/snowflake/ml/model/_client/model/model_impl.py index f1591305..d70629f9 100644 --- a/snowflake/ml/model/_client/model/model_impl.py +++ b/snowflake/ml/model/_client/model/model_impl.py @@ -216,8 +216,25 @@ def show_versions(self) -> pd.DataFrame: ) return pd.DataFrame([row.as_dict() for row in rows]) + @telemetry.send_api_usage_telemetry( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) def delete_version(self, version_name: str) -> None: - raise NotImplementedError("Deleting version has not been supported yet.") + """Drop a version of the model. + + Args: + version_name: The name of the version. 
+ """ + statement_params = telemetry.get_statement_params( + project=_TELEMETRY_PROJECT, + subproject=_TELEMETRY_SUBPROJECT, + ) + self._model_ops.delete_model_or_version( + model_name=self._model_name, + version_name=sql_identifier.SqlIdentifier(version_name), + statement_params=statement_params, + ) @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, diff --git a/snowflake/ml/model/_client/model/model_impl_test.py b/snowflake/ml/model/_client/model/model_impl_test.py index c097e09d..95a519e0 100644 --- a/snowflake/ml/model/_client/model/model_impl_test.py +++ b/snowflake/ml/model/_client/model/model_impl_test.py @@ -183,6 +183,15 @@ def test_default_setter(self) -> None: statement_params=mock.ANY, ) + def test_delete_version(self) -> None: + with mock.patch.object(self.m_model._model_ops, "delete_model_or_version") as mock_delete_model_or_version: + self.m_model.delete_version(version_name="V2") + mock_delete_model_or_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + statement_params=mock.ANY, + ) + def test_show_tags(self) -> None: m_res = {'DB."schema".MYTAG': "tag content", 'MYDB.SCHEMA."my_another_tag"': "1"} with mock.patch.object(self.m_model._model_ops, "show_tags", return_value=m_res) as mock_show_tags: diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index d7ae1f8c..f06dde43 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -447,8 +447,14 @@ def delete_model_or_version( version_name: Optional[sql_identifier.SqlIdentifier] = None, statement_params: Optional[Dict[str, Any]] = None, ) -> None: - # TODO: Delete version is not supported yet. 
- self._model_client.drop_model( - model_name=model_name, - statement_params=statement_params, - ) + if version_name: + self._model_version_client.drop_version( + model_name=model_name, + version_name=version_name, + statement_params=statement_params, + ) + else: + self._model_client.drop_model( + model_name=model_name, + statement_params=statement_params, + ) diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index 48729902..2c7709e2 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -892,7 +892,7 @@ def test_set_default_version_2(self) -> None: ) mock_set_default_version.assert_not_called() - def test_delete_model_or_version(self) -> None: + def test_delete_model_or_version_1(self) -> None: with mock.patch.object( self.m_ops._model_client, "drop_model", @@ -906,6 +906,22 @@ def test_delete_model_or_version(self) -> None: statement_params=self.m_statement_params, ) + def test_delete_model_or_version_2(self) -> None: + with mock.patch.object( + self.m_ops._model_version_client, + "drop_version", + ) as mock_drop_version: + self.m_ops.delete_model_or_version( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + statement_params=self.m_statement_params, + ) + mock_drop_version.assert_called_once_with( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + statement_params=self.m_statement_params, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/sql/model_version.py b/snowflake/ml/model/_client/sql/model_version.py index 7861d2a4..6e64d07c 100644 --- a/snowflake/ml/model/_client/sql/model_version.py +++ b/snowflake/ml/model/_client/sql/model_version.py @@ -222,3 +222,16 @@ def set_metadata( ), statement_params=statement_params, ).has_dimensions(expected_rows=1, expected_cols=1).validate() + + def drop_version( + self, + *, + model_name: sql_identifier.SqlIdentifier, + version_name: sql_identifier.SqlIdentifier, + statement_params: Optional[Dict[str, Any]] = None, + ) -> None: + query_result_checker.SqlResultValidator( + self._session, + f"ALTER MODEL {self.fully_qualified_model_name(model_name)} DROP VERSION {version_name.identifier()}", + statement_params=statement_params, + ).has_dimensions(expected_rows=1, expected_cols=1).validate() diff --git a/snowflake/ml/model/_client/sql/model_version_test.py b/snowflake/ml/model/_client/sql/model_version_test.py index 0fc1c3e1..abc29f8d 100644 --- a/snowflake/ml/model/_client/sql/model_version_test.py +++ b/snowflake/ml/model/_client/sql/model_version_test.py @@ -240,6 +240,23 @@ def test_set_metadata(self) -> None: statement_params=m_statement_params, ) + def test_drop_version(self) -> None: + m_statement_params = {"test": "1"} + m_df = mock_data_frame.MockDataFrame( + collect_result=[Row("Model MODEL successfully altered.")], collect_statement_params=m_statement_params + ) + self.m_session.add_mock_sql("""ALTER MODEL TEMP."test".MODEL DROP VERSION V2""", m_df) + c_session = cast(Session, self.m_session) + model_version_sql.ModelVersionSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).drop_version( + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("V2"), + statement_params=m_statement_params, + ) + if __name__ == "__main__": 
absltest.main() diff --git a/snowflake/ml/model/_signatures/snowpark_test.py b/snowflake/ml/model/_signatures/snowpark_test.py index 1a9502e1..242efb95 100644 --- a/snowflake/ml/model/_signatures/snowpark_test.py +++ b/snowflake/ml/model/_signatures/snowpark_test.py @@ -224,6 +224,23 @@ def test_validate_data_with_features(self) -> None: ) model_signature._validate_snowpark_data(df, fts) + fts = [ + core.FeatureSpec("a", core.DataType.INT16), + core.FeatureSpec("b", core.DataType.FLOAT), + ] + schema = spt.StructType( + [spt.StructField('"a"', spt.DecimalType(6, 0)), spt.StructField('"b"', spt.DecimalType(6, 2))] + ) + df = self._session.create_dataframe( + [ + [decimal.Decimal(1), decimal.Decimal(2.5)], + [decimal.Decimal(1), decimal.Decimal(6.8)], + ], + schema, + ) + with self.assertWarnsRegex(UserWarning, "is being automatically converted to DOUBLE in the Snowpark DataFrame"): + model_signature._validate_snowpark_data(df, fts) + fts = [ core.FeatureSpec("a", core.DataType.INT16), core.FeatureSpec("b", core.DataType.FLOAT), diff --git a/snowflake/ml/model/model_signature.py b/snowflake/ml/model/model_signature.py index 1a2fc5fd..81db40d2 100644 --- a/snowflake/ml/model/model_signature.py +++ b/snowflake/ml/model/model_signature.py @@ -528,11 +528,8 @@ def _validate_snowpark_type_feature( if not ( isinstance( field_data_type, - (spt._IntegralType, spt.FloatType, spt.DoubleType), + (spt._IntegralType, spt.FloatType, spt.DoubleType, spt.DecimalType), ) - # We are not allowing > 0 scale as it will become a decimal.Decimal - # Although it is castable, the support will be done as another effort. - or (isinstance(field_data_type, spt.DecimalType) and field_data_type.scale == 0) ): raise snowml_exceptions.SnowflakeMLException( error_code=error_codes.INVALID_DATA, @@ -550,6 +547,17 @@ def _validate_snowpark_type_feature( f"because of its original type {field_data_type} is non-Numeric." ), ) + if isinstance(field_data_type, spt.DecimalType) and field_data_type.scale > 0: + warnings.warn( + ( + f"Type {field_data_type} is being automatically converted to DOUBLE in the Snowpark DataFrame. " + "This automatic conversion may lead to potential precision loss and rounding errors. " + "If you wish to prevent this conversion, you should manually perform " + "the necessary data type conversion." 
+ ), + category=UserWarning, + stacklevel=2, + ) min_v, max_v = value_range if ( max_v > np.finfo(ft_type._numpy_type).max # type: ignore[arg-type] diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py index aff9ec13..55c4cec7 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py @@ -283,7 +283,7 @@ def fit_search_snowpark( "snowflake-snowpark-python<2", "fastparquet<2023.11", "pyarrow<14", - "cachetools<5", + "cachetools<6", ] @sproc( # type: ignore[misc] diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py index ece6f771..477f596a 100644 --- a/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py @@ -83,6 +83,8 @@ def batch_inference( statement_params=statement_params, ) def vec_batch_infer(ds: PandasSeries[dict]) -> PandasSeries[dict]: # type: ignore[type-arg] + import numbers + import numpy as np import pandas as pd @@ -120,6 +122,9 @@ def vec_batch_infer(ds: PandasSeries[dict]) -> PandasSeries[dict]: # type: igno ): # In case of kneighbors, functions return a tuple of ndarrays. transformed_numpy_array = np.stack(inference_res, axis=1) + elif isinstance(inference_res, numbers.Number): + # In case of BernoulliRBM, functions return a float + transformed_numpy_array = np.array([inference_res]) else: transformed_numpy_array = inference_res diff --git a/snowflake/ml/modeling/metrics/BUILD.bazel b/snowflake/ml/modeling/metrics/BUILD.bazel index 5e798929..114033ca 100644 --- a/snowflake/ml/modeling/metrics/BUILD.bazel +++ b/snowflake/ml/modeling/metrics/BUILD.bazel @@ -11,6 +11,7 @@ py_library( ":init", ":metrics_utils", "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/_internal/utils:result", ], ) diff --git a/snowflake/ml/modeling/metrics/classification.py b/snowflake/ml/modeling/metrics/classification.py index 82fd3cd2..048e9894 100644 --- a/snowflake/ml/modeling/metrics/classification.py +++ b/snowflake/ml/modeling/metrics/classification.py @@ -7,11 +7,14 @@ import cloudpickle import numpy as np import numpy.typing as npt +import sklearn +from packaging import version from sklearn import exceptions, metrics import snowflake.snowpark._internal.utils as snowpark_utils from snowflake import snowpark from snowflake.ml._internal import telemetry +from snowflake.ml._internal.utils import result from snowflake.ml.modeling.metrics import metrics_utils from snowflake.snowpark import functions as F, types as T from snowflake.snowpark._internal.utils import ( @@ -791,6 +794,80 @@ def precision_recall_fscore_support( support - None (if average is not None) or array of int, shape = [n_unique_labels] The number of occurrences of each label in the y true column(s). 
""" + if average == "samples": + metrics_utils.check_label_columns(y_true_col_names, y_pred_col_names) + + session = df._session + assert session is not None + sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE) + sklearn_release = version.parse(sklearn.__version__).release + statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT) + + cols = metrics_utils.flatten_cols([y_true_col_names, y_pred_col_names, sample_weight_col_name]) + queries = df[cols].queries["queries"] + + pickled_result_module = cloudpickle.dumps(result) + + @F.sproc( # type: ignore[misc] + is_permanent=False, + session=session, + name=sproc_name, + replace=True, + packages=[ + "cloudpickle", + f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*", + "snowflake-snowpark-python", + ], + statement_params=statement_params, + anonymous=True, + ) + def precision_recall_fscore_support_anon_sproc(session: snowpark.Session) -> bytes: + for query in queries[:-1]: + _ = session.sql(query).collect(statement_params=statement_params) + sp_df = session.sql(queries[-1]) + df = sp_df.to_pandas(statement_params=statement_params) + df.columns = sp_df.columns + + y_true = df[y_true_col_names] + y_pred = df[y_pred_col_names] + sample_weight = df[sample_weight_col_name] if sample_weight_col_name else None + + with warnings.catch_warnings(record=True) as w: + p, r, f, s = metrics.precision_recall_fscore_support( + y_true, + y_pred, + beta=beta, + labels=labels, + pos_label=pos_label, + average=average, + warn_for=warn_for, + sample_weight=sample_weight, + zero_division=zero_division, + ) + + # handle zero_division warnings + warning = None + if len(w) > 0 and issubclass(w[-1].category, exceptions.UndefinedMetricWarning): + warning = w[-1] + + result_module = cloudpickle.loads(pickled_result_module) + return result_module.serialize(session, (p, r, f, s, warning)) # type: ignore[no-any-return] + + kwargs = telemetry.get_sproc_statement_params_kwargs( + precision_recall_fscore_support_anon_sproc, statement_params + ) + result_object = result.deserialize(session, precision_recall_fscore_support_anon_sproc(session, **kwargs)) + + res: Union[ + Tuple[float, float, float, None], + Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], + ] = result_object[:4] + warning = result_object[-1] + if warning: + warnings.warn(warning.message, category=warning.category, stacklevel=2) + return res + + # Distributed when average != "samples" session = df._session assert session is not None diff --git a/snowflake/ml/modeling/model_selection/grid_search_cv.py b/snowflake/ml/modeling/model_selection/grid_search_cv.py index 87cc794c..93442a97 100644 --- a/snowflake/ml/modeling/model_selection/grid_search_cv.py +++ b/snowflake/ml/modeling/model_selection/grid_search_cv.py @@ -691,6 +691,45 @@ def decision_function( return output_df + @available_if(original_estimator_has_callable("score_samples")) # type: ignore[misc] + @telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + custom_tags=dict([("autogen", True)]), + ) + def score_samples( + self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "score_samples_" + ) -> Union[DataFrame, pd.DataFrame]: + """Call score_samples on the estimator with the best found parameters. + Only available if refit=True and the underlying estimator supports score_samples. + + Args: + dataset (Union[DataFrame, pd.DataFrame]): + Snowpark or Pandas DataFrame. 
+            output_cols_prefix (str):
+                Prefix for the response columns. Defaults to "score_samples_".
+
+        Returns:
+            Union[DataFrame, pd.DataFrame]:
+                Output dataset with the score_samples results for the samples in the input dataset.
+        """
+        super()._check_dataset_type(dataset)
+        if isinstance(dataset, DataFrame):
+            output_df = self._batch_inference(
+                dataset=dataset,
+                inference_method="score_samples",
+                expected_output_cols_list=self._get_output_column_names(output_cols_prefix),
+                expected_output_cols_type="float",
+            )
+        elif isinstance(dataset, pd.DataFrame):
+            output_df = self._sklearn_inference(
+                dataset=dataset,
+                inference_method="score_samples",
+                expected_output_cols_list=self._get_output_column_names(output_cols_prefix),
+            )
+
+        return output_df
+
     @available_if(original_estimator_has_callable("score"))  # type: ignore[misc]
     def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float:
         """
diff --git a/snowflake/ml/modeling/model_selection/randomized_search_cv.py b/snowflake/ml/modeling/model_selection/randomized_search_cv.py
index 23c3dd54..36a50637 100644
--- a/snowflake/ml/modeling/model_selection/randomized_search_cv.py
+++ b/snowflake/ml/modeling/model_selection/randomized_search_cv.py
@@ -702,6 +702,45 @@ def decision_function(
         return output_df

+    @available_if(original_estimator_has_callable("score_samples"))  # type: ignore[misc]
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject=_SUBPROJECT,
+        custom_tags=dict([("autogen", True)]),
+    )
+    def score_samples(
+        self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "score_samples_"
+    ) -> Union[DataFrame, pd.DataFrame]:
+        """Call score_samples on the estimator with the best found parameters.
+        Only available if refit=True and the underlying estimator supports score_samples.
+
+        Args:
+            dataset (Union[DataFrame, pd.DataFrame]):
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix (str):
+                Prefix for the response columns. Defaults to "score_samples_".
+
+        Returns:
+            Union[DataFrame, pd.DataFrame]:
+                Output dataset with the score_samples results for the samples in the input dataset.
+ """ + super()._check_dataset_type(dataset) + if isinstance(dataset, DataFrame): + output_df = self._batch_inference( + dataset=dataset, + inference_method="score_samples", + expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + expected_output_cols_type="float", + ) + elif isinstance(dataset, pd.DataFrame): + output_df = self._sklearn_inference( + dataset=dataset, + inference_method="score_samples", + expected_output_cols_list=self._get_output_column_names(output_cols_prefix), + ) + + return output_df + @available_if(original_estimator_has_callable("score")) # type: ignore[misc] def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: """ diff --git a/snowflake/ml/modeling/pipeline/pipeline.py b/snowflake/ml/modeling/pipeline/pipeline.py index c57871f0..d4806792 100644 --- a/snowflake/ml/modeling/pipeline/pipeline.py +++ b/snowflake/ml/modeling/pipeline/pipeline.py @@ -392,6 +392,25 @@ def predict(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[sno """ return self._invoke_estimator_func("predict", dataset) + @metaestimators.available_if(_final_step_has("score_samples")) # type: ignore[misc] + @telemetry.send_api_usage_telemetry( + project=_PROJECT, + subproject=_SUBPROJECT, + ) + def score_samples( + self, dataset: Union[snowpark.DataFrame, pd.DataFrame] + ) -> Union[snowpark.DataFrame, pd.DataFrame]: + """ + Transform the dataset by applying all the transformers in order and predict using the estimator. + + Args: + dataset: Input dataset. + + Returns: + Output dataset. + """ + return self._invoke_estimator_func("score_samples", dataset) + @metaestimators.available_if(_final_step_has("predict_proba")) # type: ignore[misc] @telemetry.send_api_usage_telemetry( project=_PROJECT, diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py index ae8a7a18..f19de18c 100644 --- a/snowflake/ml/registry/model_registry.py +++ b/snowflake/ml/registry/model_registry.py @@ -167,7 +167,7 @@ def _create_registry_views( # Create views on most recent metadata items. 
metadata_view_name_prefix = identifier.concat_names([metadata_table_name, "_LAST_"]) metadata_view_template = formatting.unwrap( - """CREATE OR REPLACE VIEW {database}.{schema}.{attribute_view} COPY GRANTS AS + """CREATE OR REPLACE TEMPORARY VIEW {database}.{schema}.{attribute_view} COPY GRANTS AS SELECT DISTINCT MODEL_ID, {select_expression} AS {final_attribute_name} FROM {metadata_table} WHERE ATTRIBUTE_NAME = '{attribute_name}'""" ) @@ -177,7 +177,9 @@ def _create_registry_views( metadata_select_fields = [] for attribute_name in _LIST_METADATA_ATTRIBUTE: view_name = identifier.concat_names([metadata_view_name_prefix, attribute_name]) - select_expression = f"(LAST_VALUE(VALUE) OVER (PARTITION BY MODEL_ID ORDER BY SEQUENCE_ID))['{attribute_name}']" + select_expression = ( + f"(LAST_VALUE(VALUE) OVER (PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['{attribute_name}']" + ) sql = metadata_view_template.format( database=database_name, schema=schema_name, @@ -221,7 +223,7 @@ def _create_registry_views( registry_view_name = identifier.concat_names([registry_table_name, "_VIEW"]) metadata_select_fields_formatted = ",".join(metadata_select_fields) session.sql( - f"""CREATE OR REPLACE VIEW {fully_qualified_schema_name}.{registry_view_name} COPY GRANTS AS + f"""CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{registry_view_name} COPY GRANTS AS SELECT {registry_table_name}.*, {metadata_select_fields_formatted} FROM {registry_table_name} {metadata_views_join}""" ).collect(statement_params=statement_params) @@ -229,7 +231,7 @@ def _create_registry_views( # Create artifact view. it joins artifact tables with registry table on model id. artifact_view_name = identifier.concat_names([artifact_table_name, "_VIEW"]) session.sql( - f"""CREATE OR REPLACE VIEW {fully_qualified_schema_name}.{artifact_view_name} COPY GRANTS AS + f"""CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{artifact_view_name} COPY GRANTS AS SELECT {registry_table_name}.NAME AS MODEL_NAME, {registry_table_name}.VERSION AS MODEL_VERSION, @@ -265,7 +267,7 @@ def _create_active_permanent_deployment_view( # Active deployments are those whose last operation is not DROP. active_deployments_view_name = identifier.concat_names([deployment_table_name, "_VIEW"]) active_deployments_view_expr = f""" - CREATE OR REPLACE VIEW {fully_qualified_schema_name}.{active_deployments_view_name} + CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{active_deployments_view_name} COPY GRANTS AS SELECT DEPLOYMENT_NAME, @@ -343,6 +345,17 @@ def __init__( statement_params = self._get_statement_params(inspect.currentframe()) self._svm.validate_schema_version(statement_params) + _create_registry_views( + session, + self._name, + self._schema, + self._registry_table, + self._metadata_table, + self._deployment_table, + self._artifact_table, + statement_params, + ) + # Private methods def _get_statement_params(self, frame: Optional[types.FrameType]) -> Dict[str, Any]: @@ -2067,10 +2080,6 @@ def create_model_registry( # These might be exposed as parameters in the future. 
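
With the changes above, the per-attribute metadata views are created as temporary views and
the "last value" selection is ordered by EVENT_TIMESTAMP rather than SEQUENCE_ID. For
reference, the template renders roughly as follows for the DESCRIPTION attribute (database,
schema and table identifiers below are placeholders):

    CREATE OR REPLACE TEMPORARY VIEW MY_DB.MY_SCHEMA.<METADATA_TABLE>_LAST_DESCRIPTION COPY GRANTS AS
        SELECT DISTINCT
            MODEL_ID,
            (LAST_VALUE(VALUE) OVER (PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['DESCRIPTION'] AS DESCRIPTION
        FROM <METADATA_TABLE>
        WHERE ATTRIBUTE_NAME = 'DESCRIPTION'
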
database_name = identifier.get_inferred_name(database_name) schema_name = identifier.get_inferred_name(schema_name) - registry_table_name = identifier.get_inferred_name(_MODELS_TABLE_NAME) - metadata_table_name = identifier.get_inferred_name(_METADATA_TABLE_NAME) - deployment_table_name = identifier.get_inferred_name(_DEPLOYMENT_TABLE_NAME) - artifact_table_name = identifier.get_inferred_name(_initial_schema._ARTIFACT_TABLE_NAME) statement_params = telemetry.get_function_usage_statement_params( project=_TELEMETRY_PROJECT, @@ -2090,16 +2099,6 @@ def create_model_registry( svm.try_upgrade(statement_params) - _create_registry_views( - session, - database_name, - schema_name, - registry_table_name, - metadata_table_name, - deployment_table_name, - artifact_table_name, - statement_params, - ) finally: if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call] # Restore the db & schema to the original ones diff --git a/snowflake/ml/registry/model_registry_test.py b/snowflake/ml/registry/model_registry_test.py index ba539850..5052b9a6 100644 --- a/snowflake/ml/registry/model_registry_test.py +++ b/snowflake/ml/registry/model_registry_test.py @@ -433,6 +433,7 @@ def setup_open_call(self) -> None: ) self._mock_show_version_table_exists({}) self._mock_select_from_version_table({}, _schema._CURRENT_SCHEMA_VERSION) + self.setup_create_views_call() def setup_list_model_call(self) -> mock_data_frame.MockDataFrame: """Setup the expected calls originating from list_model.""" @@ -449,7 +450,7 @@ def setup_create_views_call(self) -> None: """Setup the expected calls originating from _create_views.""" self.add_session_mock_sql( query=( - f"""CREATE OR REPLACE VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_DEPLOYMENTS_TABLE_NAME}_VIEW + f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_DEPLOYMENTS_TABLE_NAME}_VIEW COPY GRANTS AS SELECT DEPLOYMENT_NAME, @@ -474,12 +475,13 @@ def setup_create_views_call(self) -> None: ) self.add_session_mock_sql( query=( - f"""CREATE OR REPLACE VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_DESCRIPTION + f"""CREATE OR REPLACE TEMPORARY VIEW + {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_DESCRIPTION COPY GRANTS AS SELECT DISTINCT MODEL_ID, (LAST_VALUE(VALUE) OVER ( - PARTITION BY MODEL_ID ORDER BY SEQUENCE_ID))['DESCRIPTION'] + PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['DESCRIPTION'] as DESCRIPTION FROM {_METADATA_TABLE_NAME} WHERE ATTRIBUTE_NAME = 'DESCRIPTION'""" ), @@ -489,12 +491,12 @@ def setup_create_views_call(self) -> None: ) self.add_session_mock_sql( query=( - f"""CREATE OR REPLACE VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_METRICS + f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_METRICS COPY GRANTS AS SELECT DISTINCT MODEL_ID, (LAST_VALUE(VALUE) OVER ( - PARTITION BY MODEL_ID ORDER BY SEQUENCE_ID))['METRICS'] + PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['METRICS'] as METRICS FROM {_METADATA_TABLE_NAME} WHERE ATTRIBUTE_NAME = 'METRICS'""" ), @@ -504,12 +506,12 @@ def setup_create_views_call(self) -> None: ) self.add_session_mock_sql( query=( - f"""CREATE OR REPLACE VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_TAGS + f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_TAGS COPY GRANTS AS SELECT DISTINCT MODEL_ID, (LAST_VALUE(VALUE) OVER ( - PARTITION BY MODEL_ID ORDER BY SEQUENCE_ID))['TAGS'] + PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['TAGS'] as 
TAGS FROM {_METADATA_TABLE_NAME} WHERE ATTRIBUTE_NAME = 'TAGS'""" ), @@ -519,7 +521,7 @@ def setup_create_views_call(self) -> None: ) self.add_session_mock_sql( query=( - f"""CREATE OR REPLACE VIEW + f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_METADATA_TABLE_NAME}_LAST_REGISTRATION COPY GRANTS AS SELECT DISTINCT MODEL_ID, EVENT_TIMESTAMP as REGISTRATION_TIMESTAMP @@ -531,7 +533,7 @@ def setup_create_views_call(self) -> None: ) self.add_session_mock_sql( query=( - f"""CREATE OR REPLACE VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME}_VIEW + f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_REGISTRY_TABLE_NAME}_VIEW COPY GRANTS AS SELECT {_REGISTRY_TABLE_NAME}.*, {_METADATA_TABLE_NAME}_LAST_DESCRIPTION.DESCRIPTION AS DESCRIPTION, @@ -555,7 +557,7 @@ def setup_create_views_call(self) -> None: ) self.add_session_mock_sql( query=( - f"""CREATE OR REPLACE VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_ARTIFACTS_TABLE_NAME}_VIEW + f"""CREATE OR REPLACE TEMPORARY VIEW {_DATABASE_NAME}.{_SCHEMA_NAME}.{_ARTIFACTS_TABLE_NAME}_VIEW COPY GRANTS AS SELECT {_REGISTRY_TABLE_NAME}.NAME AS MODEL_NAME, @@ -596,6 +598,7 @@ def setup_open_existing(self) -> None: ) self._mock_show_version_table_exists({}) self._mock_select_from_version_table({}, _schema._CURRENT_SCHEMA_VERSION) + self.setup_create_views_call() def setup_schema_upgrade_calls(self, statement_params: Dict[str, str]) -> None: self._mock_show_version_table_exists(statement_params) @@ -755,8 +758,6 @@ def test_create_new(self) -> None: self.setup_schema_upgrade_calls(statement_params) - self.setup_create_views_call() - model_registry.create_model_registry( session=cast(snowpark.Session, self._session), database_name=_DATABASE_NAME, @@ -781,7 +782,6 @@ def test_create_if_not_exists(self) -> None: self._mock_create_artifacts_table_not_exists(statement_params) self.setup_schema_upgrade_calls(statement_params) - self.setup_create_views_call() # 2. SQL queries issued by ModelRegistry constructor. self.setup_open_existing() diff --git a/snowflake/ml/registry/registry.py b/snowflake/ml/registry/registry.py index 57fb8701..22387889 100644 --- a/snowflake/ml/registry/registry.py +++ b/snowflake/ml/registry/registry.py @@ -108,6 +108,8 @@ def log_model( to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel is not specified, Snowflake Anaconda Channel will be used. Defaults to None. pip_requirements: List of Pip package specifications. Defaults to None. + Currently it is not supported since Model can only executed in Snowflake Warehouse where all + dependencies are required to be retrieved from Snowflake Anaconda Channel. python_version: Python version in which the model is run. Defaults to None. code_paths: List of directories containing code to import. Defaults to None. ext_modules: List of external modules to pickle with the model object. @@ -130,6 +132,9 @@ def log_model( - max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse. Defaults to None, determined automatically by Snowflake. + Raises: + NotImplementedError: `pip_requirements` is not supported. + Returns: ModelVersion: ModelVersion object corresponding to the model just logged. 
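
        Example (illustrative sketch; `registry` is an open `Registry` instance, `model` is a
        supported model object, and the names used are placeholders):

            # Supported: dependencies resolved from the Snowflake Anaconda Channel.
            mv = registry.log_model(
                model=model,
                model_name="MY_MODEL",
                version_name="V1",
                conda_dependencies=["scikit-learn"],
            )

            # Not supported yet: passing `pip_requirements` raises NotImplementedError.
            registry.log_model(
                model=model,
                model_name="MY_MODEL",
                version_name="V2",
                pip_requirements=["scikit-learn"],
            )
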
""" @@ -138,6 +143,12 @@ def log_model( project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT, ) + if pip_requirements: + raise NotImplementedError( + "Currently `pip_requirements` is not supported since Model can only executed " + "in Snowflake Warehouse where all dependencies are required to be retrieved " + "from Snowflake Anaconda Channel." + ) return self._model_manager.log_model( model=model, model_name=model_name, @@ -145,7 +156,7 @@ def log_model( comment=comment, metrics=metrics, conda_dependencies=conda_dependencies, - pip_requirements=pip_requirements, + pip_requirements=None, python_version=python_version, signatures=signatures, sample_input_data=sample_input_data, diff --git a/snowflake/ml/registry/registry_test.py b/snowflake/ml/registry/registry_test.py index ef42d1ed..886bfc05 100644 --- a/snowflake/ml/registry/registry_test.py +++ b/snowflake/ml/registry/registry_test.py @@ -114,7 +114,6 @@ def test_log_model(self) -> None: m_model = mock.MagicMock() m_conda_dependency = mock.MagicMock() m_sample_input_data = mock.MagicMock() - m_pip_requirements = mock.MagicMock() m_signatures = mock.MagicMock() m_options = mock.MagicMock() m_python_version = mock.MagicMock() @@ -130,7 +129,6 @@ def test_log_model(self) -> None: comment=m_comment, metrics=m_metrics, conda_dependencies=m_conda_dependency, - pip_requirements=m_pip_requirements, python_version=m_python_version, signatures=m_signatures, sample_input_data=m_sample_input_data, @@ -145,7 +143,7 @@ def test_log_model(self) -> None: comment=m_comment, metrics=m_metrics, conda_dependencies=m_conda_dependency, - pip_requirements=m_pip_requirements, + pip_requirements=None, python_version=m_python_version, signatures=m_signatures, sample_input_data=m_sample_input_data, diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index ac48ae7f..6e677b35 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. -VERSION = "1.2.2" +VERSION = "1.2.3" diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py index 02d89bf8..a936b07b 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_case_sensitivity_test.py @@ -216,7 +216,8 @@ def test_entity_names(self, equi_names: List[str], diff_names: List[str]) -> Non # retrieve with equivalent name is fine. for equi_name in equi_names: - fs.get_entity(equi_name) + e_2 = fs.get_entity(equi_name) + self.assertEqual(e_2.name, SqlIdentifier(equi_name)) # delete with different names will fail for diff_name in diff_names: @@ -224,8 +225,8 @@ def test_entity_names(self, equi_names: List[str], diff_names: List[str]) -> Non fs.delete_entity(diff_name) # register with different names is fine - e_2 = Entity(name=diff_names[0], join_keys=["a"]) - fs.register_entity(e_2) + e_3 = Entity(name=diff_names[0], join_keys=["a"]) + fs.register_entity(e_3) # registered two entiteis. 
self.assertEqual(len(fs.list_entities().collect()), 2) diff --git a/tests/integ/snowflake/ml/feature_store/feature_store_test.py b/tests/integ/snowflake/ml/feature_store/feature_store_test.py index 0e616f01..0b816cff 100644 --- a/tests/integ/snowflake/ml/feature_store/feature_store_test.py +++ b/tests/integ/snowflake/ml/feature_store/feature_store_test.py @@ -199,6 +199,7 @@ def test_create_and_delete_entities(self) -> None: "NAME": ["aD", "PRODUCT", "USER"], "JOIN_KEYS": ['["AID"]', '["CID","PID"]', '["uid"]'], "DESC": ["", "", ""], + "OWNER": ["REGTEST_RL", "REGTEST_RL", "REGTEST_RL"], }, sort_cols=["NAME"], ) @@ -220,6 +221,7 @@ def test_create_and_delete_entities(self) -> None: "NAME": ["PRODUCT", "USER"], "JOIN_KEYS": ['["CID","PID"]', '["uid"]'], "DESC": ["", ""], + "OWNER": ["REGTEST_RL", "REGTEST_RL"], }, sort_cols=["NAME"], ) @@ -249,12 +251,17 @@ def test_retrieve_entity(self) -> None: e1 = Entity(name="foo", join_keys=["a", "b"], desc="my foo") e2 = Entity(name="bar", join_keys=["c"]) - fs.register_entity(e1) fs.register_entity(e2) + re1 = fs.get_entity("foo") + re2 = fs.get_entity("bar") - self.assertEqual(e1, fs.get_entity("foo")) - self.assertEqual(e2, fs.get_entity("bar")) + self.assertEqual(e1.name, re1.name) + self.assertEqual(e1.join_keys, re1.join_keys) + self.assertEqual(e1.desc, re1.desc) + self.assertEqual(e2.name, re2.name) + self.assertEqual(e2.join_keys, re2.join_keys) + self.assertEqual(e2.desc, re2.desc) compare_dataframe( actual_df=fs.list_entities().to_pandas(), @@ -262,6 +269,7 @@ def test_retrieve_entity(self) -> None: "NAME": ["FOO", "BAR"], "JOIN_KEYS": ['["A","B"]', '["C"]'], "DESC": ["my foo", ""], + "OWNER": ["REGTEST_RL", "REGTEST_RL"], }, sort_cols=["NAME"], ) @@ -513,6 +521,10 @@ def test_create_duplicated_feature_view(self) -> None: ) fv = fs.register_feature_view(feature_view=fv, version="v1") + with self.assertWarnsRegex(UserWarning, "FeatureView .* has already been registered."): + fv = fs.register_feature_view(feature_view=fv, version="v1") + self.assertIsNotNone(fv) + fv = FeatureView( name="fv", entities=[e], @@ -635,11 +647,11 @@ def test_register_with_cron_expr(self) -> None: self.assertEqual(res[0]["state"], "started") self.assertEqual(fv.refresh_freq, "DOWNSTREAM") - fs.suspend_feature_view(fv) + fv = fs.suspend_feature_view(fv) res = self._session.sql(f"SHOW TASKS LIKE '{task_name}' IN SCHEMA {fs._config.full_schema_path}").collect() self.assertEqual(res[0]["state"], "suspended") - fs.resume_feature_view(fv) + fv = fs.resume_feature_view(fv) res = self._session.sql(f"SHOW TASKS LIKE '{task_name}' IN SCHEMA {fs._config.full_schema_path}").collect() self.assertEqual(res[0]["state"], "started") @@ -1088,6 +1100,7 @@ def test_list_feature_views(self) -> None: "WAREHOUSE", "REFRESH_MODE", "REFRESH_MODE_REASON", + "OWNER", "PHYSICAL_NAME", ], ) diff --git a/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py index de9a92cb..af082029 100644 --- a/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py +++ b/tests/integ/snowflake/ml/model/_client/model/model_impl_integ_test.py @@ -86,9 +86,20 @@ def test_default(self) -> None: self.assertEqual(self._model.default.version_name, VERSION_NAME2) @unittest.skipUnless( - test_env_utils.get_current_snowflake_version() >= version.parse("8.2.0"), - "TAG on model only available when the Snowflake Version is newer than 8.2.0", + test_env_utils.get_current_snowflake_version() >= version.parse("8.7.0"), + 
"Drop version on model only available when the Snowflake Version is newer than 8.7.0", ) + def test_delete_version(self) -> None: + model, test_features, _ = model_factory.ModelFactory.prepare_sklearn_model() + self.registry.log_model( + model=model, + model_name=MODEL_NAME, + version_name="V3", + sample_input_data=test_features, + ) + self._model.delete_version("V3") + self.assertLen(self._model.show_versions(), 2) + def test_tag(self) -> None: fq_tag_name1 = identifier.get_schema_level_object_identifier(self._test_db, self._test_schema, self._tag_name1) fq_tag_name2 = identifier.get_schema_level_object_identifier(self._test_db, self._test_schema, self._tag_name2) diff --git a/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel b/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel index 83890da1..ab6a8de9 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel +++ b/tests/integ/snowflake/ml/modeling/model_selection/BUILD.bazel @@ -9,6 +9,7 @@ py_test( shard_count = 5, deps = [ "//snowflake/ml/modeling/decomposition:pca", + "//snowflake/ml/modeling/ensemble:isolation_forest", "//snowflake/ml/modeling/ensemble:random_forest_classifier", "//snowflake/ml/modeling/model_selection:grid_search_cv", "//snowflake/ml/modeling/svm:svc", @@ -25,6 +26,7 @@ py_test( shard_count = 3, deps = [ "//snowflake/ml/modeling/decomposition:pca", + "//snowflake/ml/modeling/ensemble:isolation_forest", "//snowflake/ml/modeling/ensemble:random_forest_classifier", "//snowflake/ml/modeling/model_selection:randomized_search_cv", "//snowflake/ml/modeling/svm:svc", diff --git a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py index 10c32a70..9e6ff1d0 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/grid_search_integ_test.py @@ -8,13 +8,16 @@ from absl.testing import absltest, parameterized from sklearn.datasets import load_iris from sklearn.decomposition import PCA as SkPCA -from sklearn.ensemble import RandomForestClassifier as SkRandomForestClassifier +from sklearn.ensemble import ( + IsolationForest as SkIsolationForest, + RandomForestClassifier as SkRandomForestClassifier, +) from sklearn.model_selection import GridSearchCV as SkGridSearchCV from sklearn.svm import SVC as SkSVC, SVR as SkSVR from xgboost import XGBClassifier as SkXGBClassifier from snowflake.ml.modeling.decomposition import PCA -from snowflake.ml.modeling.ensemble import RandomForestClassifier +from snowflake.ml.modeling.ensemble import IsolationForest, RandomForestClassifier from snowflake.ml.modeling.model_selection import GridSearchCV from snowflake.ml.modeling.svm import SVC, SVR from snowflake.ml.modeling.xgboost import XGBClassifier @@ -345,6 +348,31 @@ def test_not_fitted_exception(self) -> None: ): reg.predict_proba(self._input_df) + def test_score_samples(self) -> None: + param_grid = {"max_features": [1, 2]} + sklearn_reg = SkGridSearchCV( + estimator=SkIsolationForest(random_state=0), param_grid=param_grid, scoring="accuracy" + ) + reg = GridSearchCV(estimator=IsolationForest(random_state=0), param_grid=param_grid, scoring="accuracy") + reg.set_input_cols(self._input_cols) + output_cols = ["OUTPUT_" + c for c in self._label_col] + reg.set_output_cols(output_cols) + reg.set_label_cols(self._label_col) + + reg.fit(self._input_df) + sklearn_reg.fit(X=self._input_df_pandas[self._input_cols], 
y=self._input_df_pandas[self._label_col].squeeze()) + + # Test score_samples + actual_score_samples_result = reg.score_samples(self._input_df).to_pandas().sort_values(by="INDEX") + actual_output_cols = [c for c in actual_score_samples_result.columns if c.find("score_samples") >= 0] + actual_score_samples_result = actual_score_samples_result[actual_output_cols].to_numpy() + sklearn_score_samples_array = sklearn_reg.score_samples(self._input_df_pandas[self._input_cols]) + np.testing.assert_allclose(actual_score_samples_result.flatten(), sklearn_score_samples_array.flatten()) + + actual_pandas_result = reg.score_samples(self._input_df_pandas[self._input_cols]) + actual_pandas_result = actual_pandas_result[actual_output_cols].to_numpy() + np.testing.assert_allclose(actual_pandas_result.flatten(), sklearn_score_samples_array.flatten()) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py b/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py index ac3b9ebe..1a29eb74 100644 --- a/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py +++ b/tests/integ/snowflake/ml/modeling/model_selection/randomized_search_integ_test.py @@ -8,13 +8,16 @@ from absl.testing import absltest, parameterized from sklearn.datasets import load_iris from sklearn.decomposition import PCA as SkPCA -from sklearn.ensemble import RandomForestClassifier as SkRandomForestClassifier +from sklearn.ensemble import ( + IsolationForest as SkIsolationForest, + RandomForestClassifier as SkRandomForestClassifier, +) from sklearn.model_selection import RandomizedSearchCV as SkRandomizedSearchCV from sklearn.svm import SVC as SkSVC from xgboost import XGBClassifier as SkXGBClassifier from snowflake.ml.modeling.decomposition import PCA -from snowflake.ml.modeling.ensemble import RandomForestClassifier +from snowflake.ml.modeling.ensemble import IsolationForest, RandomForestClassifier from snowflake.ml.modeling.model_selection import RandomizedSearchCV from snowflake.ml.modeling.svm import SVC from snowflake.ml.modeling.xgboost import XGBClassifier @@ -297,6 +300,33 @@ def test_not_fitted_exception(self) -> None: ): reg.predict_proba(self._input_df) + def test_score_samples(self) -> None: + params = {"max_features": range(1, 3)} + sklearn_reg = SkRandomizedSearchCV( + estimator=SkIsolationForest(random_state=0), param_distributions=params, scoring="accuracy" + ) + reg = RandomizedSearchCV( + estimator=IsolationForest(random_state=0), param_distributions=params, scoring="accuracy" + ) + reg.set_input_cols(self._input_cols) + output_cols = ["OUTPUT_" + c for c in self._label_col] + reg.set_output_cols(output_cols) + reg.set_label_cols(self._label_col) + + reg.fit(self._input_df) + sklearn_reg.fit(X=self._input_df_pandas[self._input_cols], y=self._input_df_pandas[self._label_col].squeeze()) + + # Test score_samples + actual_score_samples_result = reg.score_samples(self._input_df).to_pandas().sort_values(by="INDEX") + actual_output_cols = [c for c in actual_score_samples_result.columns if c.find("score_samples") >= 0] + actual_score_samples_result = actual_score_samples_result[actual_output_cols].to_numpy() + sklearn_score_samples_array = sklearn_reg.score_samples(self._input_df_pandas[self._input_cols]) + np.testing.assert_allclose(actual_score_samples_result.flatten(), sklearn_score_samples_array.flatten()) + + actual_pandas_result = reg.score_samples(self._input_df_pandas[self._input_cols]) + 
actual_pandas_result = actual_pandas_result[actual_output_cols].to_numpy() + np.testing.assert_allclose(actual_pandas_result.flatten(), sklearn_score_samples_array.flatten()) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/modeling/pipeline/BUILD.bazel b/tests/integ/snowflake/ml/modeling/pipeline/BUILD.bazel index c3b2af0f..972d1e14 100644 --- a/tests/integ/snowflake/ml/modeling/pipeline/BUILD.bazel +++ b/tests/integ/snowflake/ml/modeling/pipeline/BUILD.bazel @@ -12,6 +12,7 @@ py_test( srcs = ["pipeline_test.py"], shard_count = SHARD_COUNT, deps = [ + "//snowflake/ml/modeling/ensemble:isolation_forest", "//snowflake/ml/modeling/linear_model:linear_regression", "//snowflake/ml/modeling/linear_model:logistic_regression", "//snowflake/ml/modeling/pipeline", diff --git a/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py b/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py index 471504c5..4f4a766f 100644 --- a/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py +++ b/tests/integ/snowflake/ml/modeling/pipeline/pipeline_test.py @@ -13,6 +13,7 @@ from absl.testing.absltest import TestCase, main from sklearn.compose import ColumnTransformer as SkColumnTransformer from sklearn.datasets import load_diabetes, load_iris +from sklearn.ensemble import IsolationForest as SklearnIsolationForest from sklearn.linear_model import ( LinearRegression as SklearnLinearRegression, LogisticRegression as SklearnLogisticRegression, @@ -25,6 +26,7 @@ from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature from snowflake.ml.modeling import pipeline as snowml_pipeline +from snowflake.ml.modeling.ensemble import IsolationForest from snowflake.ml.modeling.linear_model import ( LinearRegression as SnowmlLinearRegression, LogisticRegression as SnowmlLogisticRegression, @@ -471,6 +473,36 @@ def test_pipeline_with_label_encoder_output_col(self) -> None: assert "TARGET_OUT" in snow_df_output.columns + def test_pipeline_score_samples(self) -> None: + input_df_pandas = load_iris(as_frame=True).frame + # Normalize column names + input_df_pandas.columns = [inflection.parameterize(c, "_").upper() for c in input_df_pandas.columns] + input_df_pandas["INDEX"] = input_df_pandas.reset_index().index + + input_df = self._session.create_dataframe(input_df_pandas) + + input_cols = [c for c in input_df_pandas.columns if not c.startswith("TARGET") and not c.startswith("INDEX")] + label_cols = ["TARGET"] + + estimator = IsolationForest(input_cols=input_cols, label_cols=label_cols, random_state=0) + + pipeline = snowml_pipeline.Pipeline(steps=[("estimator", estimator)]) + + # fit and predict + pipeline.fit(input_df) + output_df = pipeline.score_samples(input_df) + actual_results = ( + output_df.to_pandas().sort_values(by="INDEX")["SCORE_SAMPLES_"].astype(float).to_numpy().flatten() + ) + + # Do the same with SKLearn + skpipeline = SkPipeline(steps=[("estimator", SklearnIsolationForest(random_state=0))]) + + skpipeline.fit(input_df_pandas[input_cols], input_df_pandas[label_cols]) + sk_predict_results = skpipeline.score_samples(input_df_pandas[input_cols]) + + np.testing.assert_allclose(actual_results, sk_predict_results) + if __name__ == "__main__": main() diff --git a/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py index 96371706..f9aaca9e 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py +++ 
b/tests/integ/snowflake/ml/registry/model/registry_custom_model_test.py @@ -108,6 +108,23 @@ def test_custom_demo_model_sp( }, ) + def test_custom_demo_model_decimal( + self, + ) -> None: + import decimal + + lm = DemoModel(custom_model.ModelContext()) + arr = [[decimal.Decimal(1.2), 2.3, 3.4], [decimal.Decimal(4.6), 2.7, 5.5]] + sp_df = self._session.create_dataframe(arr, schema=['"c1"', '"c2"', '"c3"']) + y_df_expected = pd.DataFrame([[1.2, 2.3, 3.4, 1.2], [4.6, 2.7, 5.5, 4.6]], columns=["c1", "c2", "c3", "output"]) + self._test_registry_model( + model=lm, + sample_input=sp_df, + prediction_assert_fns={ + "predict": (sp_df, lambda res: dataframe_utils.check_sp_df_res(res, y_df_expected, check_dtype=False)) + }, + ) + def test_custom_demo_model_sp_one_query( self, ) -> None: diff --git a/third_party/rules_conda/env.bzl b/third_party/rules_conda/env.bzl index f1f293f1..0d44ae1c 100644 --- a/third_party/rules_conda/env.bzl +++ b/third_party/rules_conda/env.bzl @@ -156,7 +156,7 @@ conda_create_rule = repository_rule( "python_version": attr.string( mandatory = True, doc = "The Python version to use when creating the environment.", - values = ["3.8", "3.9", "3.10"], + values = ["3.8", "3.9", "3.10", "3.11"], ), "quiet": attr.bool( default = True,