-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement converter for _ConstantPredictor (#979)
* Implement converter for _ConstantPredictor Signed-off-by: xadupre <[email protected]> * add test for double Signed-off-by: xadupre <[email protected]> * fix shape Signed-off-by: xadupre <[email protected]> * remove unused variable Signed-off-by: xadupre <[email protected]> --------- Signed-off-by: xadupre <[email protected]>
- Loading branch information
Showing
6 changed files
with
118 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,18 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from ..common._registration import register_shape_calculator | ||
from ..common.data_types import Int64TensorType | ||
from ..common.shape_calculator import calculate_linear_classifier_output_shapes | ||
|
||
|
||
def calculate_constant_predictor_output_shapes(operator): | ||
N = operator.inputs[0].get_first_dimension() | ||
operator.outputs[0].type = Int64TensorType([N]) | ||
operator.outputs[1].type.shape = [N, 2] | ||
|
||
|
||
register_shape_calculator('Sklearn_ConstantPredictor', | ||
calculate_constant_predictor_output_shapes) | ||
|
||
register_shape_calculator('SklearnOneVsRestClassifier', | ||
calculate_linear_classifier_output_shapes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Tests scikit-learn's SGDClassifier converter.""" | ||
|
||
import unittest | ||
import numpy as np | ||
from sklearn.multiclass import _ConstantPredictor | ||
from onnxruntime import __version__ as ort_version | ||
from skl2onnx import to_onnx | ||
|
||
from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType | ||
|
||
from test_utils import ( | ||
dump_data_and_model, | ||
TARGET_OPSET | ||
) | ||
|
||
ort_version = ".".join(ort_version.split(".")[:2]) | ||
|
||
|
||
class TestConstantPredictorConverter(unittest.TestCase): | ||
def test_constant_predictor_float(self): | ||
model = _ConstantPredictor() | ||
X = np.array([[1, 2]]) | ||
y = np.array([0]) | ||
model.fit(X, y) | ||
test_x = np.array([[1, 0], [2, 8]]) | ||
|
||
model_onnx = to_onnx( | ||
model, "scikit-learn ConstantPredictor", | ||
initial_types=[("input", FloatTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET, | ||
options={'zipmap': False}) | ||
|
||
self.assertIsNotNone(model_onnx is not None) | ||
dump_data_and_model(test_x.astype(np.float32), model, model_onnx, | ||
basename="SklearnConstantPredictorFloat") | ||
|
||
def test_constant_predictor_double(self): | ||
model = _ConstantPredictor() | ||
X = np.array([[1, 2]]) | ||
y = np.array([0]) | ||
model.fit(X, y) | ||
test_x = np.array([[1, 0], [2, 8]]) | ||
|
||
model_onnx = to_onnx( | ||
model, "scikit-learn ConstantPredictor", | ||
initial_types=[("input", DoubleTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET, | ||
options={'zipmap': False}) | ||
|
||
self.assertIsNotNone(model_onnx is not None) | ||
dump_data_and_model(test_x.astype(np.float64), model, model_onnx, | ||
basename="SklearnConstantPredictorDouble") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main(verbosity=3) |