-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix converter for DecisionTreeClassifier if n_classses == 1 (#1008)
* Fix converter for DecisionTreeClassifier if n_classses == 1 Signed-off-by: Xavier Dupre <[email protected]> * list or np.array Signed-off-by: Xavier Dupre <[email protected]> * lint Signed-off-by: Xavier Dupre <[email protected]> * froze lightgbm version Signed-off-by: Xavier Dupre <[email protected]> * black Signed-off-by: Xavier Dupre <[email protected]> * Refactor with black (#1009) * Refactor with black Signed-off-by: Xavier Dupre <[email protected]> * remove unnecessary skip condition Signed-off-by: Xavier Dupre <[email protected]> * freeze lightgbm version Signed-off-by: Xavier Dupre <[email protected]> * add ruff to github action Signed-off-by: Xavier Dupre <[email protected]> * update badge on README.md Signed-off-by: Xavier Dupre <[email protected]> --------- Signed-off-by: Xavier Dupre <[email protected]> * fix old CI Signed-off-by: Xavier Dupre <[email protected]> --------- Signed-off-by: Xavier Dupre <[email protected]> Signed-off-by: Xavier Dupré <[email protected]>
- Loading branch information
Showing
4 changed files
with
97 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,6 @@ | |
""" | ||
import onnxruntime | ||
import onnx | ||
import numpy | ||
import os | ||
import math | ||
import numpy as np | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import unittest | ||
import numpy as np | ||
|
||
try: | ||
from onnx.reference import ReferenceEvaluator | ||
except ImportError: | ||
ReferenceEvaluator = None | ||
from sklearn.tree import DecisionTreeClassifier | ||
from onnxruntime import InferenceSession | ||
from skl2onnx import to_onnx | ||
from test_utils import TARGET_OPSET | ||
|
||
|
||
class TestSklearnClassifiersExtreme(unittest.TestCase): | ||
def test_one_training_class(self): | ||
x = np.eye(4, dtype=np.float32) | ||
y = np.array([5, 5, 5, 5], dtype=np.int64) | ||
|
||
cl = DecisionTreeClassifier() | ||
cl = cl.fit(x, y) | ||
|
||
expected = [cl.predict(x), cl.predict_proba(x)] | ||
onx = to_onnx(cl, x, target_opset=TARGET_OPSET, options={"zipmap": False}) | ||
|
||
for cls in [ | ||
(lambda onx: ReferenceEvaluator(onx, verbose=0)) | ||
if ReferenceEvaluator is not None | ||
else None, | ||
lambda onx: InferenceSession( | ||
onx.SerializeToString(), providers=["CPUExecutionProvider"] | ||
), | ||
]: | ||
if cls is None: | ||
continue | ||
sess = cls(onx) | ||
res = sess.run(None, {"X": x}) | ||
self.assertEqual(len(res), len(expected)) | ||
for e, g in zip(expected, res): | ||
self.assertEqual(e.tolist(), g.tolist()) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main(verbosity=2) |