Skip to content

Commit

Permalink
Add kernel Matern for gaussian process (#978)
Browse files Browse the repository at this point in the history
* Add kernel Matern

Signed-off-by: xadupre <[email protected]>

* fix matern kernel

Signed-off-by: Xavier Dupre <[email protected]>

* fix diag

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: xadupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
Co-authored-by: xiaowuhu <[email protected]>
  • Loading branch information
xadupre and xiaowuhu authored May 28, 2024
1 parent 8a09ebc commit c4a9de3
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 92 deletions.
2 changes: 2 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.17.0 (development)

* Support kernel Matern in Gaussian Process
[#978](https://github.com/onnx/sklearn-onnx/pull/978)
* Fix for multidimensional gaussian process
[#1097](https://github.com/onnx/sklearn-onnx/pull/1097)
* Minor fixes to support scikit-learn==1.5.0
Expand Down
2 changes: 1 addition & 1 deletion skl2onnx/algebra/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def run(self):
expected_outputs = None

logger.debug(
"[State.run] id=%d op_name=%r is_model=%r " "expected_outputs=%r",
"[State.run] id=%d op_name=%r is_model=%r expected_outputs=%r",
id(self),
self.operator_name,
self.is_model,
Expand Down
204 changes: 119 additions & 85 deletions skl2onnx/operator_converters/_gp_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import numpy as np
from onnx.numpy_helper import from_array
from sklearn.gaussian_process.kernels import (
Sum,
Product,
ConstantKernel,
RBF,
DotProduct,
ExpSineSquared,
RationalQuadratic,
Matern,
PairwiseKernel,
Product,
RationalQuadratic,
RBF,
Sum,
WhiteKernel,
)
from ..algebra.complex_functions import onnx_squareform_pdist, onnx_cdist
Expand All @@ -22,15 +23,17 @@
OnnxTranspose,
OnnxDiv,
OnnxExp,
OnnxNeg,
OnnxShape,
OnnxSin,
OnnxPow,
OnnxReduceSumApi11,
OnnxSqueezeApi11,
OnnxIdentity,
OnnxReduceSumSquareApi18,
OnnxReduceL2_typed,
OnnxEyeLike,
OnnxGather,
OnnxConcat,
)
from ..algebra.custom_ops import OnnxCDist

Expand All @@ -45,7 +48,7 @@ def convert_kernel_diag(
):
if op_version is None:
raise RuntimeError("op_version must not be None.")
if isinstance(kernel, Sum):
if type(kernel) is Sum:
return OnnxAdd(
convert_kernel_diag(
kernel.k1, X, dtype=dtype, optim=optim, op_version=op_version
Expand All @@ -57,7 +60,7 @@ def convert_kernel_diag(
op_version=op_version,
)

if isinstance(kernel, Product):
if type(kernel) is Product:
return OnnxMul(
convert_kernel_diag(
kernel.k1, X, dtype=dtype, optim=optim, op_version=op_version
Expand All @@ -69,7 +72,7 @@ def convert_kernel_diag(
op_version=op_version,
)

if isinstance(kernel, ConstantKernel):
if type(kernel) is ConstantKernel:
onnx_zeros = _zero_vector_of_size(
X, keepdims=0, dtype=dtype, op_version=op_version
)
Expand All @@ -80,7 +83,7 @@ def convert_kernel_diag(
op_version=op_version,
)

if isinstance(kernel, (RBF, ExpSineSquared, RationalQuadratic)):
if type(kernel) in {RBF, ExpSineSquared, RationalQuadratic, Matern}:
onnx_zeros = _zero_vector_of_size(
X, keepdims=0, dtype=dtype, op_version=op_version
)
Expand All @@ -91,15 +94,14 @@ def convert_kernel_diag(
output_names=output_names,
op_version=op_version,
)
else:
return OnnxAdd(
onnx_zeros,
np.array([1], dtype=dtype),
output_names=output_names,
op_version=op_version,
)
return OnnxAdd(
onnx_zeros,
np.array([1], dtype=dtype),
output_names=output_names,
op_version=op_version,
)

if isinstance(kernel, DotProduct):
if type(kernel) is DotProduct:
t_sigma_0 = py_make_float_array(kernel.sigma_0**2, dtype=dtype)
return OnnxSqueezeApi11(
OnnxAdd(
Expand Down Expand Up @@ -141,7 +143,7 @@ def _convert_exp_sine_squared(
dtype=None,
optim=None,
op_version=None,
**kwargs
**kwargs,
):
"""
Implements the kernel ExpSineSquared.
Expand Down Expand Up @@ -230,7 +232,7 @@ def _convert_pairwise_kernel(
degree=None,
coef0=None,
optim=None,
**kwargs
**kwargs,
):
"""
Implements the kernel PairwiseKernel.
Expand Down Expand Up @@ -265,9 +267,9 @@ def convert_kernel(
):
if op_version is None:
raise RuntimeError("op_version must not be None.")
if isinstance(kernel, Sum):
if type(kernel) is Sum:
clop = OnnxAdd
elif isinstance(kernel, Product):
elif type(kernel) is Product:
clop = OnnxMul
else:
clop = None
Expand All @@ -293,7 +295,7 @@ def convert_kernel(
op_version=op_version,
)

if isinstance(kernel, ConstantKernel):
if type(kernel) is ConstantKernel:
# X and x_train should have the same number of features.
onnx_zeros_x = _zero_vector_of_size(
X, keepdims=1, dtype=dtype, op_version=op_version
Expand All @@ -314,38 +316,27 @@ def convert_kernel(
op_version=op_version,
)

if isinstance(kernel, RBF):
if type(kernel) in (RBF, Matern):
# length_scale = np.squeeze(length_scale).astype(float)
zeroh = _zero_vector_of_size(
X, axis=1, keepdims=0, dtype=dtype, op_version=op_version
)
zerov = _zero_vector_of_size(
X, axis=0, keepdims=1, dtype=dtype, op_version=op_version
)

if isinstance(kernel.length_scale, np.ndarray) and len(kernel.length_scale) > 0:
const = kernel.length_scale.astype(dtype)
else:
tensor_value = py_make_float_array(
kernel.length_scale, dtype=dtype, as_tensor=True
)
const = OnnxConstantOfShape(
OnnxShape(zeroh, op_version=op_version),
value=tensor_value,
op_version=op_version,
)
const = np.array([kernel.length_scale], dtype=dtype)
X_scaled = OnnxDiv(X, const, op_version=op_version)
if x_train is None:
dist = onnx_squareform_pdist(
X_scaled, metric="sqeuclidean", dtype=dtype, op_version=op_version
X_scaled,
metric="sqeuclidean" if type(kernel) is RBF else "euclidean",
dtype=dtype,
op_version=op_version,
)
else:
x_train_scaled = OnnxDiv(x_train, const, op_version=op_version)
if optim is None:
dist = onnx_cdist(
X_scaled,
x_train_scaled,
metric="sqeuclidean",
metric="sqeuclidean" if type(kernel) is RBF else "euclidean",
dtype=dtype,
op_version=op_version,
)
Expand All @@ -359,26 +350,69 @@ def convert_kernel(
else:
raise ValueError("Unknown optimization '{}'.".format(optim))

tensor_value = py_make_float_array(-0.5, dtype=dtype, as_tensor=True)
cst5 = OnnxConstantOfShape(
OnnxShape(zerov, op_version=op_version),
value=tensor_value,
op_version=op_version,
)
# see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/gaussian_process/kernels.py#L1719
if type(kernel) is RBF:
K = OnnxMul(dist, np.array([0.5], dtype=dtype), op_version=op_version)
return OnnxExp(
OnnxNeg(K, op_version=op_version),
op_version=op_version,
output_names=output_names,
)
# Matern
if kernel.nu == 0.5:
        # K = np.exp(-dists)
return OnnxExp(
OnnxNeg(dist, op_version=op_version),
op_version=op_version,
output_names=output_names,
)

# K = np.exp(-.5 * dists)
exp = OnnxExp(
OnnxMul(dist, cst5, op_version=op_version),
output_names=output_names,
op_version=op_version,
)
if kernel.nu == 1.5:
# K = dists * math.sqrt(3)
# K = (1.0 + K) * np.exp(-K)
K = OnnxMul(dist, np.array([3**0.5], dtype=dtype), op_version=op_version)
exp_k = OnnxExp(OnnxNeg(K, op_version=op_version), op_version=op_version)
k_1 = OnnxAdd(K, np.array([1], dtype=dtype), op_version=op_version)
return OnnxMul(k_1, exp_k, output_names=output_names, op_version=op_version)

if kernel.nu == 2.5:
# K = dists * math.sqrt(5)
# K = (1.0 + K + K**2 / 3.0) * np.exp(-K)
K = OnnxMul(dist, np.array([5**0.5], dtype=dtype), op_version=op_version)
exp_k = OnnxExp(OnnxNeg(K, op_version=op_version), op_version=op_version)
k_12 = OnnxAdd(
OnnxAdd(K, np.array([1], dtype=type), op_version=op_version),
OnnxDiv(
OnnxMul(K, K, op_version=op_version),
np.array([3], dtype=dtype),
op_version=op_version,
),
op_version=op_version,
)
return OnnxMul(
k_12, exp_k, output_names=output_names, op_version=op_version
)

if kernel.nu == np.inf:
# K = np.exp(-(dists**2) / 2.0)
return OnnxExp(
OnnxNeg(
OnnxDiv(
OnnxMul(dist, dist, op_version=op_version),
np.array([2], dtype=dtype),
op_version=op_version,
),
op_version=op_version,
),
op_version=op_version,
output_names=output_names,
)

# This should not be needed.
# K = squareform(K)
# np.fill_diagonal(K, 1)
return exp
raise RuntimeError(
f"The converter is not implemented for Matern(nu={kernel.nu}, ...)."
)

if isinstance(kernel, ExpSineSquared):
if type(kernel) is ExpSineSquared:
if not isinstance(kernel.length_scale, (float, int)):
raise NotImplementedError(
"length_scale should be float not {}.".format(type(kernel.length_scale))
Expand All @@ -395,7 +429,7 @@ def convert_kernel(
op_version=op_version,
)

if isinstance(kernel, DotProduct):
if type(kernel) is DotProduct:
if not isinstance(kernel.sigma_0, (float, int)):
raise NotImplementedError(
"sigma_0 should be float not {}.".format(type(kernel.sigma_0))
Expand Down Expand Up @@ -424,7 +458,7 @@ def convert_kernel(
op_version=op_version,
)

if isinstance(kernel, RationalQuadratic):
if type(kernel) is RationalQuadratic:
if x_train is None:
return _convert_rational_quadratic(
X,
Expand All @@ -448,7 +482,7 @@ def convert_kernel(
op_version=op_version,
)

if isinstance(kernel, PairwiseKernel):
if type(kernel) is PairwiseKernel:
if x_train is None:
return _convert_pairwise_kernel(
X,
Expand All @@ -470,7 +504,7 @@ def convert_kernel(
op_version=op_version,
)

if isinstance(kernel, WhiteKernel):
if type(kernel) is WhiteKernel:
# X and x_train should have the same number of features.
onnx_zeros_x = _zero_vector_of_size(
X, keepdims=1, dtype=dtype, op_version=op_version
Expand Down Expand Up @@ -508,30 +542,30 @@ def _zero_vector_of_size(
raise RuntimeError("op_version must not be None.")
if keepdims is None:
raise ValueError("Default for keepdims is not allowed.")
if dtype == np.float32:
res = OnnxReduceSumApi11(
OnnxConstantOfShape(
OnnxShape(X, op_version=op_version), op_version=op_version
),
axes=[1 - axis],
keepdims=keepdims,
output_names=output_names,
op_version=op_version,
)
elif dtype in (np.float64, np.int32, np.int64):
res = OnnxReduceSumApi11(
OnnxConstantOfShape(
OnnxShape(X, op_version=op_version),
value=py_make_float_array(0, dtype=dtype, as_tensor=True),
op_version=op_version,
),
axes=[1 - axis],
keepdims=keepdims,
output_names=output_names,
op_version=op_version,

shape = OnnxShape(X, op_version=op_version)
if axis == 0:
dim = OnnxGather(shape, np.array([0], dtype=np.int64), op_version=op_version)
new_shape = (
OnnxConcat(
dim, np.array([1], dtype=np.int64), axis=0, op_version=op_version
)
if keepdims
else dim
)
else:
raise NotImplementedError(
"Unable to create zero vector of type {}".format(dtype)
dim = OnnxGather(shape, np.array([1], dtype=np.int64), op_version=op_version)
new_shape = (
OnnxConcat(
np.array([1], dtype=np.int64), dim, axis=0, op_version=op_version
)
if keepdims
else dim
)
return res

return OnnxConstantOfShape(
new_shape,
output_names=output_names,
op_version=op_version,
value=from_array(np.array([0], dtype=dtype)),
)
Loading

0 comments on commit c4a9de3

Please sign in to comment.