Skip to content

Commit

Permalink
Implement converter for _ConstantPredictor (#979)
Browse files Browse the repository at this point in the history
* Implement converter for _ConstantPredictor

Signed-off-by: xadupre <[email protected]>

* add test for double

Signed-off-by: xadupre <[email protected]>

* fix shape

Signed-off-by: xadupre <[email protected]>

* remove unused variable

Signed-off-by: xadupre <[email protected]>

---------

Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Mar 18, 2023
1 parent 1313eb9 commit 2432f70
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 2 deletions.
3 changes: 2 additions & 1 deletion skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,8 @@ def _apply_zipmap(zipmap_options, scope, model, input_type,
zipmap_operator.inputs = probability_tensor
label_type = Int64TensorType([None])

if (isinstance(model.classes_, list) and
if (hasattr(model, "classes_") and
isinstance(model.classes_, list) and
isinstance(model.classes_[0], np.ndarray)):
# multi-label problem
pass
Expand Down
8 changes: 7 additions & 1 deletion skl2onnx/_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@
)

# Multi-class
from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier
from sklearn.multiclass import (
_ConstantPredictor,
OneVsRestClassifier,
OneVsOneClassifier
)

# Tree-based models
from sklearn.ensemble import (
Expand Down Expand Up @@ -269,6 +273,7 @@
# included in the following list and one output for everything not in
# the list.
sklearn_classifier_list = list(filter(lambda m: m is not None, [
_ConstantPredictor,
AdaBoostClassifier,
BaggingClassifier,
BernoulliNB,
Expand Down Expand Up @@ -316,6 +321,7 @@
# equivalent in terms of conversion.
def build_sklearn_operator_name_map():
res = {k: "Sklearn" + k.__name__ for k in [
_ConstantPredictor,
AdaBoostClassifier,
AdaBoostRegressor,
BaggingClassifier,
Expand Down
3 changes: 3 additions & 0 deletions skl2onnx/common/utils_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def get_label_classes(scope, op, node_names=False):
classes = op.classes_
elif hasattr(op, 'intercept_'):
classes = len(op.intercept_)
elif hasattr(op, "y_"):
# _ConstantPredictor
classes = np.array(list(sorted(set(op.y_))))
else:
raise RuntimeError(
"No known ways to retrieve the number of classes for class %r."
Expand Down
38 changes: 38 additions & 0 deletions skl2onnx/operator_converters/one_vs_rest_classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import numpy as np
from sklearn.base import is_regressor
from sklearn.svm import LinearSVC
from ..proto import onnx_proto
Expand All @@ -13,6 +14,8 @@
apply_slice, apply_sub, apply_cast, apply_abs, apply_add, apply_div)
from ..common.utils_classifier import _finalize_converter_classes
from ..common.data_types import guess_proto_type, Int64TensorType
from ..algebra.onnx_ops import (
OnnxReshape, OnnxShape, OnnxSlice, OnnxTile)
from .._supported_operators import sklearn_operator_name_map


Expand Down Expand Up @@ -172,9 +175,44 @@ def convert_one_vs_rest_classifier(scope: Scope, operator: Operator,
op.classes_, proto_dtype)


def convert_constant_predictor_classifier(scope: Scope, operator: Operator,
container: ModelComponentContainer):
"""
Converts a *_ConstantPredictor* into *ONNX* format.
"""
op_version = container.target_opset
proto_dtype = guess_proto_type(operator.inputs[0].type)
if proto_dtype != onnx_proto.TensorProto.DOUBLE:
proto_dtype = onnx_proto.TensorProto.FLOAT
op = operator.raw_operator
dtype = {onnx_proto.TensorProto.DOUBLE: np.float64,
onnx_proto.TensorProto.FLOAT: np.float32}
shape = OnnxShape(operator.inputs[0].full_name, op_version=op_version)
first = OnnxSlice(shape, np.array([0], dtype=np.int64),
np.array([1], dtype=np.int64), op_version=op_version)
y = op.y_.astype(dtype[proto_dtype]).ravel()
labels = OnnxTile(y.astype(np.int64),
first, op_version=op_version,
output_names=[operator.outputs[0].full_name])

cst = np.hstack([(1 - y).astype(y.dtype), y])
proba_flat = OnnxTile(cst, first, op_version=op_version)
proba_reshape = OnnxReshape(
proba_flat, np.array([-1, 2], dtype=np.int64),
output_names=[operator.outputs[1].full_name])

labels.add_to(scope, container)
proba_reshape.add_to(scope, container)


register_converter('SklearnOneVsRestClassifier',
convert_one_vs_rest_classifier,
options={'zipmap': [True, False, 'columns'],
'nocl': [True, False],
'output_class_labels': [False, True],
'raw_scores': [True, False]})

register_converter('Sklearn_ConstantPredictor',
convert_constant_predictor_classifier,
options={'zipmap': [True, False, 'columns'],
'nocl': [True, False]})
10 changes: 10 additions & 0 deletions skl2onnx/shape_calculators/one_vs_rest_classifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
# SPDX-License-Identifier: Apache-2.0

from ..common._registration import register_shape_calculator
from ..common.data_types import Int64TensorType
from ..common.shape_calculator import calculate_linear_classifier_output_shapes


def calculate_constant_predictor_output_shapes(operator):
N = operator.inputs[0].get_first_dimension()
operator.outputs[0].type = Int64TensorType([N])
operator.outputs[1].type.shape = [N, 2]


register_shape_calculator('Sklearn_ConstantPredictor',
calculate_constant_predictor_output_shapes)

register_shape_calculator('SklearnOneVsRestClassifier',
calculate_linear_classifier_output_shapes)
58 changes: 58 additions & 0 deletions tests/test_sklearn_constant_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0

"""Tests scikit-learn's SGDClassifier converter."""

import unittest
import numpy as np
from sklearn.multiclass import _ConstantPredictor
from onnxruntime import __version__ as ort_version
from skl2onnx import to_onnx

from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType

from test_utils import (
dump_data_and_model,
TARGET_OPSET
)

ort_version = ".".join(ort_version.split(".")[:2])


class TestConstantPredictorConverter(unittest.TestCase):
def test_constant_predictor_float(self):
model = _ConstantPredictor()
X = np.array([[1, 2]])
y = np.array([0])
model.fit(X, y)
test_x = np.array([[1, 0], [2, 8]])

model_onnx = to_onnx(
model, "scikit-learn ConstantPredictor",
initial_types=[("input", FloatTensorType([None, X.shape[1]]))],
target_opset=TARGET_OPSET,
options={'zipmap': False})

self.assertIsNotNone(model_onnx is not None)
dump_data_and_model(test_x.astype(np.float32), model, model_onnx,
basename="SklearnConstantPredictorFloat")

def test_constant_predictor_double(self):
model = _ConstantPredictor()
X = np.array([[1, 2]])
y = np.array([0])
model.fit(X, y)
test_x = np.array([[1, 0], [2, 8]])

model_onnx = to_onnx(
model, "scikit-learn ConstantPredictor",
initial_types=[("input", DoubleTensorType([None, X.shape[1]]))],
target_opset=TARGET_OPSET,
options={'zipmap': False})

self.assertIsNotNone(model_onnx is not None)
dump_data_and_model(test_x.astype(np.float64), model, model_onnx,
basename="SklearnConstantPredictorDouble")


if __name__ == "__main__":
unittest.main(verbosity=3)

0 comments on commit 2432f70

Please sign in to comment.