Skip to content

Commit

Permalink
Fix converter for DecisionTreeClassifier if n_classes == 1 (#1008)
Browse files Browse the repository at this point in the history
* Fix converter for DecisionTreeClassifier if n_classes == 1

Signed-off-by: Xavier Dupre <[email protected]>

* list or np.array

Signed-off-by: Xavier Dupre <[email protected]>

* lint

Signed-off-by: Xavier Dupre <[email protected]>

* froze lightgbm version

Signed-off-by: Xavier Dupre <[email protected]>

* black

Signed-off-by: Xavier Dupre <[email protected]>

* Refactor with black (#1009)

* Refactor with black

Signed-off-by: Xavier Dupre <[email protected]>

* remove unnecessary skip condition

Signed-off-by: Xavier Dupre <[email protected]>

* freeze lightgbm version

Signed-off-by: Xavier Dupre <[email protected]>

* add ruff to github action

Signed-off-by: Xavier Dupre <[email protected]>

* update badge on README.md

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: Xavier Dupre <[email protected]>

* fix old CI

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: Xavier Dupré <[email protected]>
  • Loading branch information
xadupre authored Aug 1, 2023
1 parent 8a4a803 commit 3ef5e13
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 6 deletions.
1 change: 0 additions & 1 deletion docs/examples/plot_cast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"""
import onnxruntime
import onnx
import numpy
import os
import math
import numpy as np
Expand Down
1 change: 0 additions & 1 deletion docs/examples/plot_tfidfvectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import matplotlib.pyplot as plt
import os
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import numpy
import onnxruntime as rt
from skl2onnx.common.data_types import StringTensorType
from skl2onnx import convert_sklearn
Expand Down
56 changes: 52 additions & 4 deletions skl2onnx/operator_converters/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numbers
import numpy as np
from onnx.numpy_helper import from_array
from ..common._apply_operation import (
apply_cast,
apply_concat,
Expand Down Expand Up @@ -124,7 +125,7 @@ def predict(
[indices_name, dummy_proba_name],
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)
else:
zero_name = scope.get_unique_variable_name("zero")
Expand Down Expand Up @@ -243,7 +244,7 @@ def _append_decision_output(
dpath,
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)

if n_out is None:
Expand Down Expand Up @@ -306,6 +307,53 @@ def convert_sklearn_decision_tree_classifier(
dtype = np.float32
op = operator.raw_operator
options = scope.get_options(op, dict(decision_path=False, decision_leaf=False))
if np.asarray(op.classes_).size == 1:
# The model was trained with one label.
# There is no need to build a tree.
if op.n_outputs_ != 1:
raise RuntimeError(
f"One training class and multiple outputs is not "
f"supported yet for class {op.__class__.__name__!r}."
)
if options["decision_path"] or options["decision_leaf"]:
raise RuntimeError(
f"One training class, option 'decision_path' "
f"or 'decision_leaf' are not supported for "
f"class {op.__class__.__name__!r}."
)

zero = scope.get_unique_variable_name("zero")
one = scope.get_unique_variable_name("one")
new_shape = scope.get_unique_variable_name("new_shape")
container.add_initializer(zero, onnx_proto.TensorProto.INT64, [1], [0])
container.add_initializer(one, onnx_proto.TensorProto.INT64, [1], [1])
container.add_initializer(new_shape, onnx_proto.TensorProto.INT64, [2], [-1, 1])
shape = scope.get_unique_variable_name("shape")
container.add_node("Shape", [operator.inputs[0].full_name], [shape])
shape_sliced = scope.get_unique_variable_name("shape_sliced")
container.add_node("Slice", [shape, zero, one, zero], [shape_sliced])

# labels
container.add_node(
"ConstantOfShape",
[shape_sliced],
[operator.outputs[0].full_name],
value=from_array(np.array([op.classes_[0]], dtype=np.int64)),
)

# probabilities
probas = scope.get_unique_variable_name("probas")
container.add_node(
"ConstantOfShape",
[shape_sliced],
[probas],
value=from_array(np.array([1], dtype=dtype)),
)
container.add_node(
"Reshape", [probas, new_shape], [operator.outputs[1].full_name]
)
return

if op.n_outputs_ == 1:
attrs = get_default_tree_classifier_attribute_pairs()
attrs["name"] = scope.get_unique_operator_name(op_type)
Expand Down Expand Up @@ -355,7 +403,7 @@ def convert_sklearn_decision_tree_classifier(
[operator.outputs[0].full_name, operator.outputs[1].full_name],
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)

n_out = 2
Expand Down Expand Up @@ -510,7 +558,7 @@ def convert_sklearn_decision_tree_regressor(
operator.outputs[0].full_name,
op_domain=op_domain,
op_version=op_version,
**attrs
**attrs,
)

options = scope.get_options(op, dict(decision_path=False, decision_leaf=False))
Expand Down
45 changes: 45 additions & 0 deletions tests/test_sklearn_classifiers_extreme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0

import unittest
import numpy as np

try:
from onnx.reference import ReferenceEvaluator
except ImportError:
ReferenceEvaluator = None
from sklearn.tree import DecisionTreeClassifier
from onnxruntime import InferenceSession
from skl2onnx import to_onnx
from test_utils import TARGET_OPSET


class TestSklearnClassifiersExtreme(unittest.TestCase):
    """Conversion tests for classifiers fitted on degenerate training data."""

    def test_one_training_class(self):
        # Every sample carries the same label, so the fitted tree knows
        # exactly one class.
        features = np.eye(4, dtype=np.float32)
        labels = np.array([5, 5, 5, 5], dtype=np.int64)

        model = DecisionTreeClassifier().fit(features, labels)
        expected = [model.predict(features), model.predict_proba(features)]

        onx = to_onnx(
            model,
            features,
            target_opset=TARGET_OPSET,
            options={"zipmap": False},
        )

        # Factories turning the ONNX model into something exposing ``run``.
        # The reference evaluator is optional: its import is guarded and the
        # name is None when the installed onnx does not provide it.
        runtime_factories = []
        if ReferenceEvaluator is not None:
            runtime_factories.append(
                lambda proto: ReferenceEvaluator(proto, verbose=0)
            )
        runtime_factories.append(
            lambda proto: InferenceSession(
                proto.SerializeToString(), providers=["CPUExecutionProvider"]
            )
        )

        for make_session in runtime_factories:
            session = make_session(onx)
            outputs = session.run(None, {"X": features})
            # Same number of outputs (labels, probabilities) as scikit-learn,
            # and each output must match element for element.
            self.assertEqual(len(outputs), len(expected))
            for want, got in zip(expected, outputs):
                self.assertEqual(want.tolist(), got.tolist())


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 3ef5e13

Please sign in to comment.