Skip to content

Commit

Permalink
Add jax dispatch for searchsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Dec 28, 2024
1 parent 6d3a2a4 commit d2eb992
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pytensor/link/jax/dispatch/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
FillDiagonalOffset,
RavelMultiIndex,
Repeat,
SearchsortedOp,
Unique,
UnravelIndex,
)
Expand Down Expand Up @@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
# return filldiagonaloffset

raise NotImplementedError("flatiter not implemented in JAX")


@jax_funcify.register(SearchsortedOp)
def jax_funcify_SearchsortedOp(op, **kwargs):
side = op.side

def searchsorted(x, v, side=side, sorter=None):
return jnp.searchsorted(x=x, v=v, side=side, sorter=sorter)

return searchsorted
8 changes: 8 additions & 0 deletions tests/link/jax/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def test_extra_ops():
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)

values = np.arange(10)
query = np.array(6)
out = pt_extra_ops.searchsorted(values, query)
fgraph = FunctionGraph([], out)
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)


@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape():
Expand Down
18 changes: 18 additions & 0 deletions tests/tensor/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
InterpolationMethod,
interp,
interpolate1d,
polynomial_interpolate1d,
valid_methods,
)

Expand Down Expand Up @@ -105,3 +106,20 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
# and last should take the right.
interior_point = x[3] + 0.1
assert f(interior_point) == (y[4] if method == "last" else y[3])


def test_polynomial_interpolate1d():
x = np.linspace(-2, 6, 10)
y = np.sin(x)

f_op = polynomial_interpolate1d(x, y)
x_hat_pt = pt.dvector("x_hat")
degree = pt.iscalar("degree")

f = pytensor.function(
[x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN"
)
x_grid = np.linspace(-2, 6, 100)
y_hat = f(x_grid, 0)

assert_allclose(y_hat, np.mean(y))

0 comments on commit d2eb992

Please sign in to comment.