Skip to content

Commit 0cd258b

Browse files
snowflake-provisionerSnowflake Authors
and
Snowflake Authors
authored
Project import generated by Copybara. (#11)
GitOrigin-RevId: bb45da353e3d3cc7b84b11926c2c294eb1ca6359 Co-authored-by: Snowflake Authors <[email protected]>
1 parent 46c50a4 commit 0cd258b

10 files changed

+740
-872
lines changed

ci/conda_recipe/meta.yaml

+10-2
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,22 @@ requirements:
2020
run:
2121
- python
2222
- absl-py>=0.15,<2
23+
- anyio>=3.5.0,<4
2324
- fsspec>=2022.11,<=2023.1
2425
- numpy>=1.23,<1.24
2526
- pyyaml>=6.0,<7
2627
- scipy>=1.9,<2
27-
- scikit-learn==1.2.1
2828
- snowflake-connector-python
29-
- snowflake-snowpark-python>=1.0.0,<=1.3
29+
- snowflake-snowpark-python>=1.3.0,<=2
3030
- sqlparse>=0.4,<1
31+
32+
# TODO(snandamuri): Versions of these packages must be exactly same between user's workspace and
33+
# snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 wont work because snowflake conda channel
34+
# only has a few allowlisted versions of scikit-learn available, so we must force users to use scikit-learn
35+
# versions that are available in the snowflake conda channel. Since there is no way to specify allow list of
36+
# versions in the requirements file, we are pinning the versions here.
37+
- joblib>=1.0.0,<=1.1.1
38+
- scikit-learn==1.2.1
3139
- xgboost==1.7.3
3240
about:
3341
home: https://github.com/snowflakedb/snowflake-ml-python

codegen/sklearn_wrapper_generator.py

+12-34
Original file line numberDiff line numberDiff line change
@@ -369,17 +369,11 @@ class WrapperGeneratorBase:
369369
370370
original_class_name INFERRED Class name for the given scikit-learn
371371
estimator.
372-
estimator_class_name GENERATED Name for the new estimator class.
373-
transformer_class_name GENERATED [TODO] Name for the new transformer
374-
class.
375372
module_name INFERRED Name of the module that given class is
376373
is contained in.
377374
estimator_imports GENERATED Imports needed for the estimator / fit()
378375
call.
379376
fit_sproc_imports GENERATED Imports needed for the fit sproc call.
380-
transform_function_name INFERRED Name for the transformer function. This
381-
will be one of "transform" or
382-
"predict()" depending on the class.
383377
------------------------------------------------------------------------------------
384378
SIGNATURES AND ARGUMENTS
385379
------------------------------------------------------------------------------------
@@ -444,9 +438,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
444438

445439
# Naming of the class.
446440
self.original_class_name = ""
447-
self.estimator_class_name = ""
448-
self.transformer_class_name = ""
449-
self.transform_function_name = ""
450441

451442
# The signature and argument passing the __init__ functions.
452443
self.original_init_signature = inspect.Signature()
@@ -456,33 +447,32 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
456447
self.sklearn_init_args_dict = ""
457448
self.estimator_init_member_args = ""
458449

450+
# Doc strings
459451
self.original_class_docstring = ""
460452
self.estimator_class_docstring = ""
461453
self.transformer_class_docstring = ""
462-
463-
self.estimator_imports = ""
464-
self.estimator_imports_list: List[str] = []
465-
466454
self.original_fit_docstring = ""
467455
self.fit_docstring = ""
468456
self.original_transform_docstring = ""
469457
self.transform_docstring = ""
470458

459+
# Import strings
460+
self.estimator_imports = ""
461+
self.estimator_imports_list: List[str] = []
462+
self.additional_import_statements = ""
463+
464+
# Test strings
471465
self.test_dataset_func = ""
472466
self.test_estimator_input_args = ""
473467
self.test_estimator_input_args_list: List[str] = []
474468
self.test_class_name = ""
475469
self.test_estimator_imports = ""
476470
self.test_estimator_imports_list: List[str] = []
477471

478-
self.additional_import_statements = ""
479-
472+
# Dependencies
480473
self.predict_udf_deps = ""
481474
self.fit_sproc_deps = ""
482475

483-
# TODO(amauser): Make fit a no-op if there is no internal state
484-
# TODO(amauser): handling sparse input and output (LabelBinarizer)
485-
486476
def _format_default_value(self, default_value: Any) -> str:
487477
if isinstance(default_value, str):
488478
return f'"{default_value}"'
@@ -561,26 +551,13 @@ def split_long_lines(line: str) -> str:
561551
self.estimator_class_docstring = class_docstring
562552

563553
def _populate_class_names(self) -> None:
564-
# TODO(snandamuri): All the 3 fields have exact same value. Do we really need these
565-
# 3 separate fields?
566554
self.original_class_name = self.class_object[0]
567-
self.estimator_class_name = self.original_class_name
568-
self.transformer_class_name = self.estimator_class_name
569-
570555
self.test_class_name = f"{self.original_class_name}Test"
571556

572557
def _populate_function_names_and_signatures(self) -> None:
573558
for member in inspect.getmembers(self.class_object[1]):
574559
if member[0] == "__init__":
575560
self.original_init_signature = inspect.signature(member[1])
576-
elif member[0] == "predict" or member[0] == "transform":
577-
if self.transform_function_name != "":
578-
print("ERROR: Class has both transform() and predict() methods.")
579-
# TODO(snandamuri): Add support for both transform() and predict() methods in estimators.
580-
# For now, resolve to predict() method when both predict() and transform() are available.
581-
self.transform_function_name = "predict"
582-
else:
583-
self.transform_function_name = member[0]
584561

585562
signature_lines = []
586563
sklearn_init_lines = []
@@ -642,6 +619,7 @@ def _populate_function_names_and_signatures(self) -> None:
642619
self.estimator_init_member_args = "\n ".join(init_member_args)
643620
self.estimator_args_transform_calls = "\n ".join(arg_transform_calls)
644621

622+
# TODO(snandamuri): Implement type inference for classifiers.
645623
self.udf_datatype = "float" if self._from_data_py or self._is_regressor else ""
646624

647625
def _populate_file_paths(self) -> None:
@@ -825,7 +803,7 @@ def generate(self) -> "SklearnWrapperGenerator":
825803
self.test_estimator_input_args_list.extend(["min_samples_leaf=1", "max_leaf_nodes=100"])
826804

827805
self.fit_sproc_deps = self.predict_udf_deps = (
828-
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}',"
806+
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}', "
829807
"f'xgboost=={xgboost.__version__}', f'joblib=={joblib.__version__}'"
830808
)
831809
self._construct_string_from_lists()
@@ -842,7 +820,7 @@ def generate(self) -> "XGBoostWrapperGenerator":
842820
self.test_estimator_input_args_list.extend(["random_state=0", "subsample=1.0", "colsample_bynode=1.0"])
843821
self.fit_sproc_imports = "import xgboost"
844822
self.fit_sproc_deps = self.predict_udf_deps = (
845-
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}',"
823+
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}', "
846824
"f'joblib=={joblib.__version__}'"
847825
)
848826
self._construct_string_from_lists()
@@ -859,7 +837,7 @@ def generate(self) -> "LightGBMWrapperGenerator":
859837
self.test_estimator_input_args_list.extend(["random_state=0"])
860838
self.fit_sproc_imports = "import lightgbm"
861839
self.fit_sproc_deps = self.predict_udf_deps = (
862-
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}',"
840+
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}', "
863841
"f'joblib=={joblib.__version__}'"
864842
)
865843
self._construct_string_from_lists()

0 commit comments

Comments
 (0)