Skip to content

Commit

Permalink
Refactor out OpFromGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Dec 28, 2024
1 parent 0e03119 commit a0c02d0
Showing 1 changed file with 41 additions and 38 deletions.
79 changes: 41 additions & 38 deletions pytensor/tensor/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from difflib import get_close_matches
from typing import Literal, get_args

from pytensor.compile.builders import OpFromGraph
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import as_tensor_variable, switch
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.extra_ops import searchsorted
from pytensor.tensor.functional import vectorize
from pytensor.tensor.math import clip, eq, le
from pytensor.tensor.sort import argsort
from pytensor.tensor.type import scalar


InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"]
Expand Down Expand Up @@ -122,41 +120,41 @@ def interpolate1d(
else:
right_pad = as_tensor_variable(right_pad)

x_hat = scalar("x_hat", dtype=x.dtype)
idx = searchsorted(x, x_hat)

if x.ndim != 1 or y.ndim != 1:
raise ValueError("Inputs must be 1d")

if method == "linear":
y_hat = _linear_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "nearest":
y_hat = _nearest_neighbor_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "first":
y_hat = _stepwise_first_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "mean":
y_hat = _stepwise_mean_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "last":
y_hat = _stepwise_last_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
else:
raise NotImplementedError(
f"Unknown interpolation method: {method}. "
f"Did you mean {get_close_matches(method, valid_methods)}?"
)

return Blockwise(
OpFromGraph(inputs=[x_hat], outputs=[y_hat], inline=False), signature="()->()"
)
def _scalar_interpolate1d(x_hat):
idx = searchsorted(x, x_hat)

if x.ndim != 1 or y.ndim != 1:
raise ValueError("Inputs must be 1d")

if method == "linear":
y_hat = _linear_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "nearest":
y_hat = _nearest_neighbor_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "first":
y_hat = _stepwise_first_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "mean":
y_hat = _stepwise_mean_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "last":
y_hat = _stepwise_last_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
else:
raise NotImplementedError(
f"Unknown interpolation method: {method}. "
f"Did you mean {get_close_matches(method, valid_methods)}?"
)

return y_hat

return vectorize(_scalar_interpolate1d, signature="()->()")


def interp(x, xp, fp, left=None, right=None, period=None):
Expand Down Expand Up @@ -191,7 +189,12 @@ def interp(x, xp, fp, left=None, right=None, period=None):
The interpolated values, same shape as `x`.
"""

xp = as_tensor_variable(xp)
fp = as_tensor_variable(fp)
x = as_tensor_variable(x)

f = interpolate1d(
xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False
)

return f(x)

0 comments on commit a0c02d0

Please sign in to comment.