diff --git a/skl2onnx/_parse.py b/skl2onnx/_parse.py index dd2d989fa..7910beb49 100644 --- a/skl2onnx/_parse.py +++ b/skl2onnx/_parse.py @@ -439,7 +439,8 @@ def _apply_zipmap(zipmap_options, scope, model, input_type, zipmap_operator.inputs = probability_tensor label_type = Int64TensorType([None]) - if (isinstance(model.classes_, list) and + if (hasattr(model, "classes_") and + isinstance(model.classes_, list) and isinstance(model.classes_[0], np.ndarray)): # multi-label problem pass diff --git a/skl2onnx/_supported_operators.py b/skl2onnx/_supported_operators.py index a539ac129..1653d076b 100644 --- a/skl2onnx/_supported_operators.py +++ b/skl2onnx/_supported_operators.py @@ -75,7 +75,11 @@ ) # Multi-class -from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier +from sklearn.multiclass import ( + _ConstantPredictor, + OneVsRestClassifier, + OneVsOneClassifier +) # Tree-based models from sklearn.ensemble import ( @@ -269,6 +273,7 @@ # included in the following list and one output for everything not in # the list. sklearn_classifier_list = list(filter(lambda m: m is not None, [ + _ConstantPredictor, AdaBoostClassifier, BaggingClassifier, BernoulliNB, @@ -316,6 +321,7 @@ # equivalent in terms of conversion. def build_sklearn_operator_name_map(): res = {k: "Sklearn" + k.__name__ for k in [ + _ConstantPredictor, AdaBoostClassifier, AdaBoostRegressor, BaggingClassifier, diff --git a/skl2onnx/common/utils_classifier.py b/skl2onnx/common/utils_classifier.py index 26992dbf0..dceb59753 100644 --- a/skl2onnx/common/utils_classifier.py +++ b/skl2onnx/common/utils_classifier.py @@ -37,6 +37,9 @@ def get_label_classes(scope, op, node_names=False): classes = op.classes_ elif hasattr(op, 'intercept_'): classes = len(op.intercept_) + elif hasattr(op, "y_"): + # _ConstantPredictor + classes = np.array(list(sorted(set(op.y_)))) else: raise RuntimeError( "No known ways to retrieve the number of classes for class %r." diff --git a/skl2onnx/operator_converters/one_vs_rest_classifier.py b/skl2onnx/operator_converters/one_vs_rest_classifier.py index 6486d7242..48d0fb3b4 100644 --- a/skl2onnx/operator_converters/one_vs_rest_classifier.py +++ b/skl2onnx/operator_converters/one_vs_rest_classifier.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import numpy as np from sklearn.base import is_regressor from sklearn.svm import LinearSVC from ..proto import onnx_proto @@ -13,6 +14,8 @@ apply_slice, apply_sub, apply_cast, apply_abs, apply_add, apply_div) from ..common.utils_classifier import _finalize_converter_classes from ..common.data_types import guess_proto_type, Int64TensorType +from ..algebra.onnx_ops import ( + OnnxReshape, OnnxShape, OnnxSlice, OnnxTile) from .._supported_operators import sklearn_operator_name_map @@ -172,9 +175,44 @@ def convert_one_vs_rest_classifier(scope: Scope, operator: Operator, op.classes_, proto_dtype) +def convert_constant_predictor_classifier(scope: Scope, operator: Operator, + container: ModelComponentContainer): + """ + Converts a *_ConstantPredictor* into *ONNX* format. + """ + op_version = container.target_opset + proto_dtype = guess_proto_type(operator.inputs[0].type) + if proto_dtype != onnx_proto.TensorProto.DOUBLE: + proto_dtype = onnx_proto.TensorProto.FLOAT + op = operator.raw_operator + dtype = {onnx_proto.TensorProto.DOUBLE: np.float64, + onnx_proto.TensorProto.FLOAT: np.float32} + shape = OnnxShape(operator.inputs[0].full_name, op_version=op_version) + first = OnnxSlice(shape, np.array([0], dtype=np.int64), + np.array([1], dtype=np.int64), op_version=op_version) + y = op.y_.astype(dtype[proto_dtype]).ravel() + labels = OnnxTile(y.astype(np.int64), + first, op_version=op_version, + output_names=[operator.outputs[0].full_name]) + + cst = np.hstack([(1 - y).astype(y.dtype), y]) + proba_flat = OnnxTile(cst, first, op_version=op_version) + proba_reshape = OnnxReshape( + proba_flat, np.array([-1, 2], dtype=np.int64), + output_names=[operator.outputs[1].full_name]) + + labels.add_to(scope, container) + proba_reshape.add_to(scope, container) + + register_converter('SklearnOneVsRestClassifier', convert_one_vs_rest_classifier, options={'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}) + +register_converter('Sklearn_ConstantPredictor', + convert_constant_predictor_classifier, + options={'zipmap': [True, False, 'columns'], + 'nocl': [True, False]}) diff --git a/skl2onnx/shape_calculators/one_vs_rest_classifier.py b/skl2onnx/shape_calculators/one_vs_rest_classifier.py index 0e2c5b8d4..579db17d3 100644 --- a/skl2onnx/shape_calculators/one_vs_rest_classifier.py +++ b/skl2onnx/shape_calculators/one_vs_rest_classifier.py @@ -1,8 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 from ..common._registration import register_shape_calculator +from ..common.data_types import Int64TensorType from ..common.shape_calculator import calculate_linear_classifier_output_shapes +def calculate_constant_predictor_output_shapes(operator): + N = operator.inputs[0].get_first_dimension() + operator.outputs[0].type = Int64TensorType([N]) + operator.outputs[1].type.shape = [N, 2] + + +register_shape_calculator('Sklearn_ConstantPredictor', + calculate_constant_predictor_output_shapes) + register_shape_calculator('SklearnOneVsRestClassifier', calculate_linear_classifier_output_shapes) diff --git a/tests/test_sklearn_constant_predictor.py b/tests/test_sklearn_constant_predictor.py new file mode 100644 index 000000000..f2846b654 --- /dev/null +++ b/tests/test_sklearn_constant_predictor.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Tests scikit-learn's SGDClassifier converter.""" + +import unittest +import numpy as np +from sklearn.multiclass import _ConstantPredictor +from onnxruntime import __version__ as ort_version +from skl2onnx import to_onnx + +from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType + +from test_utils import ( + dump_data_and_model, + TARGET_OPSET +) + +ort_version = ".".join(ort_version.split(".")[:2]) + + +class TestConstantPredictorConverter(unittest.TestCase): + def test_constant_predictor_float(self): + model = _ConstantPredictor() + X = np.array([[1, 2]]) + y = np.array([0]) + model.fit(X, y) + test_x = np.array([[1, 0], [2, 8]]) + + model_onnx = to_onnx( + model, "scikit-learn ConstantPredictor", + initial_types=[("input", FloatTensorType([None, X.shape[1]]))], + target_opset=TARGET_OPSET, + options={'zipmap': False}) + + self.assertIsNotNone(model_onnx is not None) + dump_data_and_model(test_x.astype(np.float32), model, model_onnx, + basename="SklearnConstantPredictorFloat") + + def test_constant_predictor_double(self): + model = _ConstantPredictor() + X = np.array([[1, 2]]) + y = np.array([0]) + model.fit(X, y) + test_x = np.array([[1, 0], [2, 8]]) + + model_onnx = to_onnx( + model, "scikit-learn ConstantPredictor", + initial_types=[("input", DoubleTensorType([None, X.shape[1]]))], + target_opset=TARGET_OPSET, + options={'zipmap': False}) + + self.assertIsNotNone(model_onnx is not None) + dump_data_and_model(test_x.astype(np.float64), model, model_onnx, + basename="SklearnConstantPredictorDouble") + + +if __name__ == "__main__": + unittest.main(verbosity=3)