@@ -369,17 +369,11 @@ class WrapperGeneratorBase:
369
369
370
370
original_class_name INFERRED Class name for the given scikit-learn
371
371
estimator.
372
- estimator_class_name GENERATED Name for the new estimator class.
373
- transformer_class_name GENERATED [TODO] Name for the new transformer
374
- class.
375
372
module_name INFERRED Name of the module that given class is
376
373
is contained in.
377
374
estimator_imports GENERATED Imports needed for the estimator / fit()
378
375
call.
379
376
fit_sproc_imports GENERATED Imports needed for the fit sproc call.
380
- transform_function_name INFERRED Name for the transformer function. This
381
- will be one of "transform" or
382
- "predict()" depending on the class.
383
377
------------------------------------------------------------------------------------
384
378
SIGNATURES AND ARGUMENTS
385
379
------------------------------------------------------------------------------------
@@ -444,9 +438,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
444
438
445
439
# Naming of the class.
446
440
self .original_class_name = ""
447
- self .estimator_class_name = ""
448
- self .transformer_class_name = ""
449
- self .transform_function_name = ""
450
441
451
442
# The signature and argument passing the __init__ functions.
452
443
self .original_init_signature = inspect .Signature ()
@@ -456,33 +447,32 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
456
447
self .sklearn_init_args_dict = ""
457
448
self .estimator_init_member_args = ""
458
449
450
+ # Doc strings
459
451
self .original_class_docstring = ""
460
452
self .estimator_class_docstring = ""
461
453
self .transformer_class_docstring = ""
462
-
463
- self .estimator_imports = ""
464
- self .estimator_imports_list : List [str ] = []
465
-
466
454
self .original_fit_docstring = ""
467
455
self .fit_docstring = ""
468
456
self .original_transform_docstring = ""
469
457
self .transform_docstring = ""
470
458
459
+ # Import strings
460
+ self .estimator_imports = ""
461
+ self .estimator_imports_list : List [str ] = []
462
+ self .additional_import_statements = ""
463
+
464
+ # Test strings
471
465
self .test_dataset_func = ""
472
466
self .test_estimator_input_args = ""
473
467
self .test_estimator_input_args_list : List [str ] = []
474
468
self .test_class_name = ""
475
469
self .test_estimator_imports = ""
476
470
self .test_estimator_imports_list : List [str ] = []
477
471
478
- self .additional_import_statements = ""
479
-
472
+ # Dependencies
480
473
self .predict_udf_deps = ""
481
474
self .fit_sproc_deps = ""
482
475
483
- # TODO(amauser): Make fit a no-op if there is no internal state
484
- # TODO(amauser): handling sparse input and output (LabelBinarizer)
485
-
486
476
def _format_default_value (self , default_value : Any ) -> str :
487
477
if isinstance (default_value , str ):
488
478
return f'"{ default_value } "'
@@ -561,26 +551,13 @@ def split_long_lines(line: str) -> str:
561
551
self .estimator_class_docstring = class_docstring
562
552
563
553
def _populate_class_names (self ) -> None :
564
- # TODO(snandamuri): All the 3 fields have exact same value. Do we really need these
565
- # 3 separate fields?
566
554
self .original_class_name = self .class_object [0 ]
567
- self .estimator_class_name = self .original_class_name
568
- self .transformer_class_name = self .estimator_class_name
569
-
570
555
self .test_class_name = f"{ self .original_class_name } Test"
571
556
572
557
def _populate_function_names_and_signatures (self ) -> None :
573
558
for member in inspect .getmembers (self .class_object [1 ]):
574
559
if member [0 ] == "__init__" :
575
560
self .original_init_signature = inspect .signature (member [1 ])
576
- elif member [0 ] == "predict" or member [0 ] == "transform" :
577
- if self .transform_function_name != "" :
578
- print ("ERROR: Class has both transform() and predict() methods." )
579
- # TODO(snandamuri): Add support for both transform() and predict() methods in estimators.
580
- # For now, resolve to predict() method when both predict() and transform() are available.
581
- self .transform_function_name = "predict"
582
- else :
583
- self .transform_function_name = member [0 ]
584
561
585
562
signature_lines = []
586
563
sklearn_init_lines = []
@@ -642,6 +619,7 @@ def _populate_function_names_and_signatures(self) -> None:
642
619
self .estimator_init_member_args = "\n " .join (init_member_args )
643
620
self .estimator_args_transform_calls = "\n " .join (arg_transform_calls )
644
621
622
+ # TODO(snandamuri): Implement type inference for classifiers.
645
623
self .udf_datatype = "float" if self ._from_data_py or self ._is_regressor else ""
646
624
647
625
def _populate_file_paths (self ) -> None :
@@ -825,7 +803,7 @@ def generate(self) -> "SklearnWrapperGenerator":
825
803
self .test_estimator_input_args_list .extend (["min_samples_leaf=1" , "max_leaf_nodes=100" ])
826
804
827
805
self .fit_sproc_deps = self .predict_udf_deps = (
828
- "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}',"
806
+ "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}', "
829
807
"f'xgboost=={xgboost.__version__}', f'joblib=={joblib.__version__}'"
830
808
)
831
809
self ._construct_string_from_lists ()
@@ -842,7 +820,7 @@ def generate(self) -> "XGBoostWrapperGenerator":
842
820
self .test_estimator_input_args_list .extend (["random_state=0" , "subsample=1.0" , "colsample_bynode=1.0" ])
843
821
self .fit_sproc_imports = "import xgboost"
844
822
self .fit_sproc_deps = self .predict_udf_deps = (
845
- "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}',"
823
+ "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}', "
846
824
"f'joblib=={joblib.__version__}'"
847
825
)
848
826
self ._construct_string_from_lists ()
@@ -859,7 +837,7 @@ def generate(self) -> "LightGBMWrapperGenerator":
859
837
self .test_estimator_input_args_list .extend (["random_state=0" ])
860
838
self .fit_sproc_imports = "import lightgbm"
861
839
self .fit_sproc_deps = self .predict_udf_deps = (
862
- "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}',"
840
+ "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}', "
863
841
"f'joblib=={joblib.__version__}'"
864
842
)
865
843
self ._construct_string_from_lists ()
0 commit comments