Skip to content

Commit

Permalink
Refactor out OpFromGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Dec 28, 2024
1 parent 0e03119 commit b7a23d7
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 56 deletions.
79 changes: 41 additions & 38 deletions pytensor/tensor/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from difflib import get_close_matches
from typing import Literal, get_args

from pytensor.compile.builders import OpFromGraph
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import as_tensor_variable, switch
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.extra_ops import searchsorted
from pytensor.tensor.functional import vectorize
from pytensor.tensor.math import clip, eq, le
from pytensor.tensor.sort import argsort
from pytensor.tensor.type import scalar


InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"]
Expand Down Expand Up @@ -122,41 +120,41 @@ def interpolate1d(
else:
right_pad = as_tensor_variable(right_pad)

x_hat = scalar("x_hat", dtype=x.dtype)
idx = searchsorted(x, x_hat)

if x.ndim != 1 or y.ndim != 1:
raise ValueError("Inputs must be 1d")

if method == "linear":
y_hat = _linear_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "nearest":
y_hat = _nearest_neighbor_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "first":
y_hat = _stepwise_first_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "mean":
y_hat = _stepwise_mean_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "last":
y_hat = _stepwise_last_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
else:
raise NotImplementedError(
f"Unknown interpolation method: {method}. "
f"Did you mean {get_close_matches(method, valid_methods)}?"
)

return Blockwise(
OpFromGraph(inputs=[x_hat], outputs=[y_hat], inline=False), signature="()->()"
)
def _scalar_interpolate1d(x_hat):
idx = searchsorted(x, x_hat)

if x.ndim != 1 or y.ndim != 1:
raise ValueError("Inputs must be 1d")

Check warning on line 127 in pytensor/tensor/interpolate.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/interpolate.py#L127

Added line #L127 was not covered by tests

if method == "linear":
y_hat = _linear_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "nearest":
y_hat = _nearest_neighbor_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "first":
y_hat = _stepwise_first_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "mean":
y_hat = _stepwise_mean_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "last":
y_hat = _stepwise_last_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
else:
raise NotImplementedError(

Check warning on line 150 in pytensor/tensor/interpolate.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/interpolate.py#L150

Added line #L150 was not covered by tests
f"Unknown interpolation method: {method}. "
f"Did you mean {get_close_matches(method, valid_methods)}?"
)

return y_hat

return vectorize(_scalar_interpolate1d, signature="()->()")


def interp(x, xp, fp, left=None, right=None, period=None):
Expand Down Expand Up @@ -191,7 +189,12 @@ def interp(x, xp, fp, left=None, right=None, period=None):
The interpolated values, same shape as `x`.
"""

xp = as_tensor_variable(xp)
fp = as_tensor_variable(fp)
x = as_tensor_variable(x)

f = interpolate1d(
xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False
)

return f(x)
18 changes: 0 additions & 18 deletions tests/tensor/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
InterpolationMethod,
interp,
interpolate1d,
polynomial_interpolate1d,
valid_methods,
)

Expand Down Expand Up @@ -106,20 +105,3 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
# and last should take the right.
interior_point = x[3] + 0.1
assert f(interior_point) == (y[4] if method == "last" else y[3])


def test_polynomial_interpolate1d():
x = np.linspace(-2, 6, 10)
y = np.sin(x)

f_op = polynomial_interpolate1d(x, y)
x_hat_pt = pt.dvector("x_hat")
degree = pt.iscalar("degree")

f = pytensor.function(
[x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN"
)
x_grid = np.linspace(-2, 6, 100)
y_hat = f(x_grid, 0)

assert_allclose(y_hat, np.mean(y))

0 comments on commit b7a23d7

Please sign in to comment.