@@ -83,24 +83,6 @@ class {transform.original_class_name}(BaseTransformer):
83
83
"""
84
84
return str(uuid4()).replace("-", "_").upper()
85
85
86
- def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
87
- """
88
- Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
89
-
90
- Args:
91
- dataset: Input dataset.
92
- """
93
- if not self.input_cols:
94
- cols = [
95
- c for c in dataset.columns
96
- if c not in self.get_label_cols() and c != self.sample_weight_col
97
- ]
98
- self.set_input_cols(input_cols=cols)
99
-
100
- if not self.output_cols:
101
- cols = [identifier.concat_names(ids=['OUTPUT_', c]) for c in self.label_cols]
102
- self.set_output_cols(output_cols=cols)
103
-
104
86
def set_input_cols(self, input_cols: Optional[Union[str, Iterable[str]]]) -> "{transform.original_class_name}":
105
87
"""
106
88
Input columns setter.
@@ -737,12 +719,22 @@ class {transform.original_class_name}(BaseTransformer):
737
719
self._model_signature_dict["predict"] = ModelSignature(inputs,
738
720
([] if self._drop_input_cols else inputs)
739
721
+ outputs)
722
+ # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
723
+ # For outlier models, returns -1 for outliers and 1 for inliers.
724
+ # Clusterer returns int64 cluster labels.
725
+ elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
726
+ outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
727
+ self._model_signature_dict["predict"] = ModelSignature(inputs,
728
+ ([] if self._drop_input_cols else inputs)
729
+ + outputs)
730
+
740
731
# For regressor, the type of predict is float64
741
732
elif self._sklearn_object._estimator_type == 'regressor':
742
733
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
743
734
self._model_signature_dict["predict"] = ModelSignature(inputs,
744
735
([] if self._drop_input_cols else inputs)
745
736
+ outputs)
737
+
746
738
for prob_func in PROB_FUNCTIONS:
747
739
if hasattr(self, prob_func):
748
740
output_cols_prefix: str = f"{{prob_func}}_"
0 commit comments