@@ -12,8 +12,8 @@ import numpy as np
12
12
{transform.estimator_imports}
13
13
from sklearn.utils.metaestimators import available_if
14
14
15
- from snowflake.ml.framework.base import BaseEstimator, BaseTransformer
16
- from snowflake.ml.utils import telemetry
15
+ from snowflake.ml.framework.base import BaseTransformer
16
+ from snowflake.ml._internal import telemetry
17
17
from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
18
18
from snowflake.ml._internal.utils import pkg_version_utils, identifier
19
19
from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get_temp_file_path
@@ -98,7 +98,7 @@ def _validate_sklearn_args(args: Dict[str, Any], klass: type) -> Dict[str, Any]:
98
98
return result
99
99
100
100
101
- class {transform.original_class_name}(BaseEstimator, BaseTransformer):
101
+ class {transform.original_class_name}(BaseTransformer):
102
102
r"""{transform.estimator_class_docstring}
103
103
"""
104
104
@@ -203,9 +203,6 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
203
203
local_result_file_name = get_temp_file_path()
204
204
stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
205
205
206
- # Put locally serialized transform on stage.
207
- session.file.put(local_transform_file_name, stage_transform_file_name, auto_compress=False, overwrite=True)
208
-
209
206
fit_sproc_name = "SNOWML_FIT_{{safe_id}}".format(safe_id=self.id)
210
207
statement_params = telemetry.get_function_usage_statement_params(
211
208
project=_PROJECT,
@@ -216,6 +213,8 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
216
213
api_calls=[sproc],
217
214
custom_tags=dict([("autogen", True)]),
218
215
)
216
+ # Put locally serialized transform on stage.
217
+ session.file.put(local_transform_file_name, stage_transform_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
219
218
220
219
@sproc(
221
220
is_permanent=False,
@@ -244,13 +243,13 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
244
243
245
244
# Execute snowpark query and obtain the results as pandas dataframe
246
245
# NB: this implies that the result data must fit into memory.
247
- df = session.sql(sql_query).to_pandas()
246
+ df = session.sql(sql_query).to_pandas(statement_params=statement_params )
248
247
249
248
local_transform_file = tempfile.NamedTemporaryFile(delete=True)
250
249
local_transform_file_name = local_transform_file.name
251
250
local_transform_file.close()
252
251
253
- session.file.get(stage_transform_file_name, local_transform_file_name)
252
+ session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params )
254
253
255
254
estimator = joblib.load(os.path.join(local_transform_file_name, os.listdir(local_transform_file_name)[0]))
256
255
@@ -270,7 +269,7 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
270
269
local_result_file.close()
271
270
272
271
joblib_dump_files = joblib.dump(estimator, local_result_file_name)
273
- session.file.put(local_result_file_name, stage_result_file_name, auto_compress = False, overwrite = True)
272
+ session.file.put(local_result_file_name, stage_result_file_name, auto_compress = False, overwrite = True, statement_params=statement_params )
274
273
275
274
# Note: you can add something like + "|" + str(df) to the return string
276
275
# to pass debug information to the caller.
@@ -303,7 +302,7 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
303
302
if len(fields) > 1:
304
303
print("\n".join(fields[1:]))
305
304
306
- session.file.get(os.path.join(stage_result_file_name, sproc_export_file_name), local_result_file_name)
305
+ session.file.get(os.path.join(stage_result_file_name, sproc_export_file_name), local_result_file_name, statement_params=statement_params )
307
306
self._sklearn_object = joblib.load(os.path.join(local_result_file_name, sproc_export_file_name))
308
307
309
308
cleanup_temp_files([local_transform_file_name, local_result_file_name])
@@ -543,6 +542,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
543
542
subproject=_SUBPROJECT,
544
543
custom_tags=dict([("autogen", True)]),
545
544
)
545
+ @telemetry.add_stmt_params_to_df(
546
+ project=_PROJECT,
547
+ subproject=_SUBPROJECT,
548
+ custom_tags=dict([("autogen", True)]),
549
+ )
546
550
def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
547
551
"""Predict lable values for each example in the input dataset.
548
552
@@ -578,6 +582,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
578
582
subproject=_SUBPROJECT,
579
583
custom_tags=dict([("autogen", True)]),
580
584
)
585
+ @telemetry.add_stmt_params_to_df(
586
+ project=_PROJECT,
587
+ subproject=_SUBPROJECT,
588
+ custom_tags=dict([("autogen", True)]),
589
+ )
581
590
def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
582
591
"""Transform the dataset.
583
592
@@ -647,6 +656,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
647
656
subproject=_SUBPROJECT,
648
657
custom_tags=dict([("autogen", True)]),
649
658
)
659
+ @telemetry.add_stmt_params_to_df(
660
+ project=_PROJECT,
661
+ subproject=_SUBPROJECT,
662
+ custom_tags=dict([("autogen", True)]),
663
+ )
650
664
def predict_proba(
651
665
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
652
666
) -> Union[DataFrame, pd.DataFrame]:
@@ -688,6 +702,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
688
702
subproject=_SUBPROJECT,
689
703
custom_tags=dict([("autogen", True)]),
690
704
)
705
+ @telemetry.add_stmt_params_to_df(
706
+ project=_PROJECT,
707
+ subproject=_SUBPROJECT,
708
+ custom_tags=dict([("autogen", True)]),
709
+ )
691
710
def predict_log_proba(
692
711
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
693
712
) -> Union[DataFrame, pd.DataFrame]:
@@ -729,6 +748,11 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
729
748
subproject=_SUBPROJECT,
730
749
custom_tags=dict([("autogen", True)]),
731
750
)
751
+ @telemetry.add_stmt_params_to_df(
752
+ project=_PROJECT,
753
+ subproject=_SUBPROJECT,
754
+ custom_tags=dict([("autogen", True)]),
755
+ )
732
756
def decision_function(
733
757
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
734
758
) -> Union[DataFrame, pd.DataFrame]:
@@ -837,9 +861,6 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
837
861
).validate()
838
862
839
863
stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
840
- # Put locally serialized score on stage.
841
- session.file.put(local_score_file_name, stage_score_file_name, auto_compress=False, overwrite=True)
842
-
843
864
score_sproc_name = "SNOWML_SCORE_{{safe_id}}".format(safe_id=self.id)
844
865
statement_params = telemetry.get_function_usage_statement_params(
845
866
project=_PROJECT,
@@ -850,6 +871,9 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
850
871
api_calls=[sproc],
851
872
custom_tags=dict([("autogen", True)]),
852
873
)
874
+ # Put locally serialized score on stage.
875
+ session.file.put(local_score_file_name, stage_score_file_name, auto_compress=False, overwrite=True, statement_params=statement_params)
876
+
853
877
@sproc(
854
878
is_permanent=False,
855
879
name=score_sproc_name,
@@ -874,13 +898,13 @@ class {transform.original_class_name}(BaseEstimator, BaseTransformer):
874
898
import inspect
875
899
{transform.fit_sproc_imports}
876
900
877
- df = session.sql(sql_query).to_pandas()
901
+ df = session.sql(sql_query).to_pandas(statement_params=statement_params )
878
902
879
903
local_score_file = tempfile.NamedTemporaryFile(delete=True)
880
904
local_score_file_name = local_score_file.name
881
905
local_score_file.close()
882
906
883
- session.file.get(stage_score_file_name, local_score_file_name)
907
+ session.file.get(stage_score_file_name, local_score_file_name, statement_params=statement_params )
884
908
estimator = joblib.load(os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0]))
885
909
argspec = inspect.getfullargspec(estimator.score)
886
910
if "X" in argspec.args:
0 commit comments