diff --git a/docs/examples/plot_cast_transformer.py b/docs/examples/plot_cast_transformer.py index 34efc74f8..82cfef35e 100644 --- a/docs/examples/plot_cast_transformer.py +++ b/docs/examples/plot_cast_transformer.py @@ -30,7 +30,6 @@ """ import onnxruntime import onnx -import numpy import os import math import numpy as np diff --git a/docs/examples/plot_tfidfvectorizer.py b/docs/examples/plot_tfidfvectorizer.py index 96321bfbf..290eb5e36 100644 --- a/docs/examples/plot_tfidfvectorizer.py +++ b/docs/examples/plot_tfidfvectorizer.py @@ -24,7 +24,6 @@ import matplotlib.pyplot as plt import os from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer -import numpy import onnxruntime as rt from skl2onnx.common.data_types import StringTensorType from skl2onnx import convert_sklearn diff --git a/skl2onnx/operator_converters/decision_tree.py b/skl2onnx/operator_converters/decision_tree.py index e92fe285c..e0a6326f8 100644 --- a/skl2onnx/operator_converters/decision_tree.py +++ b/skl2onnx/operator_converters/decision_tree.py @@ -3,6 +3,7 @@ import numbers import numpy as np +from onnx.numpy_helper import from_array from ..common._apply_operation import ( apply_cast, apply_concat, @@ -124,7 +125,7 @@ def predict( [indices_name, dummy_proba_name], op_domain=op_domain, op_version=op_version, - **attrs + **attrs, ) else: zero_name = scope.get_unique_variable_name("zero") @@ -243,7 +244,7 @@ def _append_decision_output( dpath, op_domain=op_domain, op_version=op_version, - **attrs + **attrs, ) if n_out is None: @@ -306,6 +307,53 @@ def convert_sklearn_decision_tree_classifier( dtype = np.float32 op = operator.raw_operator options = scope.get_options(op, dict(decision_path=False, decision_leaf=False)) + if np.asarray(op.classes_).size == 1: + # The model was trained with one label. + # There is no need to build a tree. + if op.n_outputs_ != 1: + raise RuntimeError( + f"One training class and multiple outputs is not " + f"supported yet for class {op.__class__.__name__!r}." + ) + if options["decision_path"] or options["decision_leaf"]: + raise RuntimeError( + f"One training class, option 'decision_path' " + f"or 'decision_leaf' are not supported for " + f"class {op.__class__.__name__!r}." + ) + + zero = scope.get_unique_variable_name("zero") + one = scope.get_unique_variable_name("one") + new_shape = scope.get_unique_variable_name("new_shape") + container.add_initializer(zero, onnx_proto.TensorProto.INT64, [1], [0]) + container.add_initializer(one, onnx_proto.TensorProto.INT64, [1], [1]) + container.add_initializer(new_shape, onnx_proto.TensorProto.INT64, [2], [-1, 1]) + shape = scope.get_unique_variable_name("shape") + container.add_node("Shape", [operator.inputs[0].full_name], [shape]) + shape_sliced = scope.get_unique_variable_name("shape_sliced") + container.add_node("Slice", [shape, zero, one, zero], [shape_sliced]) + + # labels + container.add_node( + "ConstantOfShape", + [shape_sliced], + [operator.outputs[0].full_name], + value=from_array(np.array([op.classes_[0]], dtype=np.int64)), + ) + + # probabilities + probas = scope.get_unique_variable_name("probas") + container.add_node( + "ConstantOfShape", + [shape_sliced], + [probas], + value=from_array(np.array([1], dtype=dtype)), + ) + container.add_node( + "Reshape", [probas, new_shape], [operator.outputs[1].full_name] + ) + return + if op.n_outputs_ == 1: attrs = get_default_tree_classifier_attribute_pairs() attrs["name"] = scope.get_unique_operator_name(op_type) @@ -355,7 +403,7 @@ def convert_sklearn_decision_tree_classifier( [operator.outputs[0].full_name, operator.outputs[1].full_name], op_domain=op_domain, op_version=op_version, - **attrs + **attrs, ) n_out = 2 @@ -510,7 +558,7 @@ def convert_sklearn_decision_tree_regressor( operator.outputs[0].full_name, op_domain=op_domain, op_version=op_version, - **attrs + **attrs, ) options = scope.get_options(op, dict(decision_path=False, decision_leaf=False)) diff --git a/tests/test_sklearn_classifiers_extreme.py b/tests/test_sklearn_classifiers_extreme.py new file mode 100644 index 000000000..d26ca31a5 --- /dev/null +++ b/tests/test_sklearn_classifiers_extreme.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 + +import unittest +import numpy as np + +try: + from onnx.reference import ReferenceEvaluator +except ImportError: + ReferenceEvaluator = None +from sklearn.tree import DecisionTreeClassifier +from onnxruntime import InferenceSession +from skl2onnx import to_onnx +from test_utils import TARGET_OPSET + + +class TestSklearnClassifiersExtreme(unittest.TestCase): + def test_one_training_class(self): + x = np.eye(4, dtype=np.float32) + y = np.array([5, 5, 5, 5], dtype=np.int64) + + cl = DecisionTreeClassifier() + cl = cl.fit(x, y) + + expected = [cl.predict(x), cl.predict_proba(x)] + onx = to_onnx(cl, x, target_opset=TARGET_OPSET, options={"zipmap": False}) + + for cls in [ + (lambda onx: ReferenceEvaluator(onx, verbose=0)) + if ReferenceEvaluator is not None + else None, + lambda onx: InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ), + ]: + if cls is None: + continue + sess = cls(onx) + res = sess.run(None, {"X": x}) + self.assertEqual(len(res), len(expected)) + for e, g in zip(expected, res): + self.assertEqual(e.tolist(), g.tolist()) + + +if __name__ == "__main__": + unittest.main(verbosity=2)