Skip to content

Commit e9aac2d

Browse files
snowflake-provisioner
and
Snowflake Authors
authored
Project import generated by Copybara. (#14)
GitOrigin-RevId: 09d233ee68cf95998a2373d91e0a91da685df8e6 Co-authored-by: Snowflake Authors <[email protected]>
1 parent 4044414 commit e9aac2d

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

63 files changed

+1441
-938
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ repos:
3131
# The first two lines of meta.yaml does not work with check-yaml
3232
exclude: >
3333
(?x)^(
34-
recipe/meta.yaml|
34+
ci/conda_recipe/meta.yaml|
3535
.github/repo_meta.yaml|
3636
)$
3737
- id: debug-statements

ci/conda_recipe/meta.yaml

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
# (dependencies, version number) from a common place. We also need to define that
33
# common place, as currently it's a BUILD rule.
44
# See https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#templating-with-jinja
5+
{% set version_match = load_file_regex(load_file='../../snowflake/ml/version.bzl', regex_pattern='VERSION = "(\d\.\d\.\d*)"\s.*') %}
56

67
package:
78
name: snowflake-ml-python
8-
version: 0.3.0 # this has to be in sync with snowflake/ml/BUILD.bazel and snowflake/ml/version.py
9+
version: {{ version_match.group(1) }}
910

1011
source:
1112
path: ../../
@@ -26,7 +27,7 @@ requirements:
2627
- pyyaml>=6.0,<7
2728
- scipy>=1.9,<2
2829
- snowflake-connector-python
29-
- snowflake-snowpark-python>=1.3.0,<=2
30+
- snowflake-snowpark-python>=1.4.0,<=2
3031
- sqlparse>=0.4,<1
3132

3233
# TODO(snandamuri): Versions of these packages must be exactly same between user's workspace and
@@ -35,7 +36,7 @@ requirements:
3536
# versions that are available in the snowflake conda channel. Since there is no way to specify allow list of
3637
# versions in the requirements file, we are pinning the versions here.
3738
- joblib>=1.0.0,<=1.1.1
38-
- scikit-learn==1.2.1
39+
- scikit-learn>=1.2.1,<2
3940
- xgboost==1.7.3
4041
about:
4142
home: https://github.com/snowflakedb/snowflake-ml-python

codegen/codegen_rules.bzl

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def autogen_estimators(module, estimator_info_list):
8383
deps = [
8484
":init",
8585
"//snowflake/ml/framework:framework",
86-
"//snowflake/ml/utils:telemetry",
86+
"//snowflake/ml/_internal:telemetry",
8787
"//snowflake/ml/_internal/utils:temp_file_utils",
8888
"//snowflake/ml/_internal/utils:query_result_checker",
8989
"//snowflake/ml/_internal/utils:pkg_version_utils",

codegen/sklearn_wrapper_template.py_template

+39-15
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import numpy as np
1212
{transform.estimator_imports}
1313
from sklearn.utils.metaestimators import available_if
1414

15-
from snowflake.ml.framework.base import BaseEstimator, BaseTransformer
16-
from snowflake.ml.utils import telemetry
15+
from snowflake.ml.framework.base import BaseTransformer
16+
from snowflake.ml._internal import telemetry
1717
from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
1818
from snowflake.ml._internal.utils import pkg_version_utils, identifier
1919
from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get_temp_file_path
@@ -98,7 +98,7 @@ def _validate_sklearn_args(args: Dict[str, Any], klass: type) -> Dict[str, Any]:
9898
return result
9999

100100

101-
class {transform.original_class_name}(BaseEstimator, BaseTransformer):
101+
class {transform.original_class_name}(BaseTransformer):
102102
r"""{transform.estimator_class_docstring}
103103
"""
104104

@@ -203,9 +203,6 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
203203
local_result_file_name = get_temp_file_path()
204204
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
205205

206-
# Put locally serialized transform on stage.
207-
session.file.put(local_transform_file_name, stage_transform_file_name, auto_compress=False, overwrite=True)
208-
209206
fit_sproc_name = "SNOWML_FIT_{{safe_id}}".format(safe_id=self.id)
210207
statement_params = telemetry.get_function_usage_statement_params(
211208
project=_PROJECT,
@@ -216,6 +213,8 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
216213
api_calls=[sproc],
217214
custom_tags=dict([("autogen", True)]),
218215
)
216+
# Put locally serialized transform on stage.
217+
session.file.put(local_transform_file_name, stage_transform_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
219218

220219
@sproc(
221220
is_permanent=False,
@@ -244,13 +243,13 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
244243

245244
# Execute snowpark query and obtain the results as pandas dataframe
246245
# NB: this implies that the result data must fit into memory.
247-
df = session.sql(sql_query).to_pandas()
246+
df = session.sql(sql_query).to_pandas(statement_params=statement_params)
248247

249248
local_transform_file = tempfile.NamedTemporaryFile(delete=True)
250249
local_transform_file_name = local_transform_file.name
251250
local_transform_file.close()
252251

253-
session.file.get(stage_transform_file_name, local_transform_file_name)
252+
session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
254253

255254
estimator = joblib.load(os.path.join(local_transform_file_name, os.listdir(local_transform_file_name)[0]))
256255

@@ -270,7 +269,7 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
270269
local_result_file.close()
271270

272271
joblib_dump_files = joblib.dump(estimator, local_result_file_name)
273-
session.file.put(local_result_file_name, stage_result_file_name, auto_compress = False, overwrite = True)
272+
session.file.put(local_result_file_name, stage_result_file_name, auto_compress = False, overwrite = True, statement_params=statement_params)
274273

275274
# Note: you can add something like + "|" + str(df) to the return string
276275
# to pass debug information to the caller.
@@ -303,7 +302,7 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
303302
if len(fields) > 1:
304303
print("\n".join(fields[1:]))
305304

306-
session.file.get(os.path.join(stage_result_file_name, sproc_export_file_name), local_result_file_name)
305+
session.file.get(os.path.join(stage_result_file_name, sproc_export_file_name), local_result_file_name, statement_params=statement_params)
307306
self._sklearn_object = joblib.load(os.path.join(local_result_file_name, sproc_export_file_name))
308307

309308
cleanup_temp_files([local_transform_file_name, local_result_file_name])
@@ -543,6 +542,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
543542
subproject=_SUBPROJECT,
544543
custom_tags=dict([("autogen", True)]),
545544
)
545+
@telemetry.add_stmt_params_to_df(
546+
project=_PROJECT,
547+
subproject=_SUBPROJECT,
548+
custom_tags=dict([("autogen", True)]),
549+
)
546550
def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
547551
"""Predict label values for each example in the input dataset.
548552

@@ -578,6 +582,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
578582
subproject=_SUBPROJECT,
579583
custom_tags=dict([("autogen", True)]),
580584
)
585+
@telemetry.add_stmt_params_to_df(
586+
project=_PROJECT,
587+
subproject=_SUBPROJECT,
588+
custom_tags=dict([("autogen", True)]),
589+
)
581590
def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
582591
"""Transform the dataset.
583592

@@ -647,6 +656,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
647656
subproject=_SUBPROJECT,
648657
custom_tags=dict([("autogen", True)]),
649658
)
659+
@telemetry.add_stmt_params_to_df(
660+
project=_PROJECT,
661+
subproject=_SUBPROJECT,
662+
custom_tags=dict([("autogen", True)]),
663+
)
650664
def predict_proba(
651665
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
652666
) -> Union[DataFrame, pd.DataFrame]:
@@ -688,6 +702,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
688702
subproject=_SUBPROJECT,
689703
custom_tags=dict([("autogen", True)]),
690704
)
705+
@telemetry.add_stmt_params_to_df(
706+
project=_PROJECT,
707+
subproject=_SUBPROJECT,
708+
custom_tags=dict([("autogen", True)]),
709+
)
691710
def predict_log_proba(
692711
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
693712
) -> Union[DataFrame, pd.DataFrame]:
@@ -729,6 +748,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
729748
subproject=_SUBPROJECT,
730749
custom_tags=dict([("autogen", True)]),
731750
)
751+
@telemetry.add_stmt_params_to_df(
752+
project=_PROJECT,
753+
subproject=_SUBPROJECT,
754+
custom_tags=dict([("autogen", True)]),
755+
)
732756
def decision_function(
733757
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
734758
) -> Union[DataFrame, pd.DataFrame]:
@@ -837,9 +861,6 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
837861
).validate()
838862

839863
stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
840-
# Put locally serialized score on stage.
841-
session.file.put(local_score_file_name, stage_score_file_name, auto_compress=False, overwrite=True)
842-
843864
score_sproc_name = "SNOWML_SCORE_{{safe_id}}".format(safe_id=self.id)
844865
statement_params = telemetry.get_function_usage_statement_params(
845866
project=_PROJECT,
@@ -850,6 +871,9 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
850871
api_calls=[sproc],
851872
custom_tags=dict([("autogen", True)]),
852873
)
874+
# Put locally serialized score on stage.
875+
session.file.put(local_score_file_name, stage_score_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
876+
853877
@sproc(
854878
is_permanent=False,
855879
name=score_sproc_name,
@@ -874,13 +898,13 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
874898
import inspect
875899
{transform.fit_sproc_imports}
876900

877-
df = session.sql(sql_query).to_pandas()
901+
df = session.sql(sql_query).to_pandas(statement_params=statement_params)
878902

879903
local_score_file = tempfile.NamedTemporaryFile(delete=True)
880904
local_score_file_name = local_score_file.name
881905
local_score_file.close()
882906

883-
session.file.get(stage_score_file_name, local_score_file_name)
907+
session.file.get(stage_score_file_name, local_score_file_name, statement_params=statement_params)
884908
estimator = joblib.load(os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0]))
885909
argspec = inspect.getfullargspec(estimator.score)
886910
if "X" in argspec.args:

conda-env-snowflake.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ dependencies:
3232
- ruamel.yaml==0.17.21
3333
- s3fs==2022.10.0
3434
- scipy==1.9.3
35-
- scikit-learn==1.2.1
36-
- snowflake-snowpark-python==1.3.0
35+
- scikit-learn==1.2.2
36+
- snowflake-snowpark-python==1.4.0
3737
- sqlparse==0.4.3
3838
- typing-extensions==4.3.0
3939
- xgboost==1.7.3

conda-env.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ dependencies:
2828
- pytorch==1.12.1
2929
- ruamel.yaml==0.17.21
3030
- s3fs==2022.10.0
31-
- scikit-learn==1.2.1
31+
- scikit-learn==1.2.2
3232
- scipy==1.9.3
33-
- snowflake-snowpark-python==1.3.0
33+
- snowflake-snowpark-python==1.4.0
3434
- sqlparse==0.4.3
3535
- tensorflow==2.9.1
3636
- torchdata==0.4.1

snowflake/ml/BUILD.bazel

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("//bazel:py_rules.bzl", "py_library", "snowml_wheel")
2+
load(":version.bzl", "VERSION")
23

34
package(default_visibility = ["//visibility:public"])
45

@@ -10,6 +11,12 @@ _LIGHTGBM_REQUIRES = ["lightgbm==3.3.5"]
1011

1112
_ALL_REQUIRES = _TENSORFLOW_REQUIRES + _PYTORCH_REQUIRES + _LIGHTGBM_REQUIRES
1213

14+
genrule(
15+
name = "generate_version",
16+
outs = ["version.py"],
17+
cmd = "echo 'VERSION=\"" + VERSION + "\"'> $@",
18+
)
19+
1320
py_library(
1421
name = "version",
1522
srcs = ["version.py"],
@@ -36,19 +43,19 @@ snowml_wheel(
3643
"pyyaml>=6.0,<7",
3744
"scipy>=1.9,<2",
3845
"snowflake-connector-python[pandas]",
39-
"snowflake-snowpark-python>=1.3.0,<2",
46+
"snowflake-snowpark-python>=1.4.0,<2",
4047
"sqlparse>=0.4,<1",
4148

4249
# TODO(snandamuri): Versions of these packages must be exactly same between user's workspace and
4350
# snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 wont work because snowflake conda channel
4451
# only has a few allowlisted versions of scikit-learn available, so we must force users to use scikit-learn
4552
# versions that are available in the snowflake conda channel. Since there is no way to specify allow list of
4653
# versions in the requirements file, we are pinning the versions here.
47-
"scikit-learn==1.2.1",
54+
"scikit-learn>=1.2.1,<2",
4855
"xgboost==1.7.3",
4956
"joblib>=1.0.0,<=1.1.1", # All the release versions between 1.0.0 and 1.1.1 are available in SF Conda channel.
5057
],
51-
version = "0.3.0", # this has to be in sync with version.py and ci/conda_recipe/meta.yaml
58+
version = VERSION,
5259
deps = [
5360
"//snowflake/ml/metrics:metrics_pkg",
5461
"//snowflake/ml/preprocessing:preprocessing_pkg",

snowflake/ml/_internal/BUILD.bazel

+17
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,20 @@ py_test(
5050
"//snowflake/ml/test_utils:mock_session",
5151
],
5252
)
53+
54+
py_library(
55+
name = "telemetry",
56+
srcs = ["telemetry.py"],
57+
deps = [
58+
"//snowflake/ml/_internal:env",
59+
],
60+
)
61+
62+
py_test(
63+
name = "telemetry_test",
64+
srcs = ["telemetry_test.py"],
65+
deps = [
66+
":telemetry",
67+
"//snowflake/ml/_internal:env",
68+
],
69+
)

snowflake/ml/_internal/env.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
from snowflake.ml import version
44

55
SOURCE = "SnowML"
6-
VERSION = version.get_version()
6+
VERSION = version.VERSION
77
PYTHON_VERSION = platform.python_version()
88
OS = platform.system()

snowflake/ml/_internal/env_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def validate_requirements_in_snowflake_conda_channel(
275275
)
276276
sql = textwrap.dedent(
277277
f"""
278-
SELECT *
278+
SELECT PACKAGE_NAME, VERSION
279279
FROM information_schema.packages
280280
WHERE ({pkg_names_str})
281281
AND language = 'python';
@@ -289,7 +289,7 @@ def validate_requirements_in_snowflake_conda_channel(
289289
query=sql,
290290
)
291291
.has_column("VERSION")
292-
.has_dimensions(expected_rows=None, expected_cols=3)
292+
.has_dimensions(expected_rows=None, expected_cols=2)
293293
.validate()
294294
)
295295
for row in result:

0 commit comments

Comments
 (0)