
Pytensor-native interpolation functions #1141

Merged: 5 commits, Dec 30, 2024


jessegrabowski (Member) commented Dec 28, 2024

Description

Adds a new file, pytensor/tensor/interpolate.py. The objective is to match a useful subset of scipy.interpolate. So far it provides 1D interpolation with several strategies. Example:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.interpolate import interpolate1d
import matplotlib.pyplot as plt

x = np.linspace(-2, 6, 10)
y = np.sin(x)

f_linear = interpolate1d(x, y, method='linear')
f_nearest = interpolate1d(x, y, method='nearest')
f_first = interpolate1d(x, y, method='first')
f_last = interpolate1d(x, y, method='last')
f_mean = interpolate1d(x, y, method='mean')

x_hat = pt.dvector('x_hat')
ops = [f_linear, f_nearest, f_first, f_last, f_mean]
MODE = 'FAST_RUN'
funcs = [
    pytensor.function([x_hat], op(x_hat), mode=MODE) for op in ops
]

inputs = np.linspace(-2.5, 6.5, 1000)
labels = ['linear', 'nearest', 'first', 'last', 'mean']
fig, ax = plt.subplots()
ax.scatter(x, y, color='k')
for f, label in zip(funcs, labels):
    ax.plot(inputs, f(inputs), label=label)
ax.legend()
plt.show()
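The description does not define what each method returns between data points, so here is a NumPy reference sketch of one reasonable reading of the five strategies, which could also serve as a test oracle. It assumes x is sorted ascending, clamps queries to [x[0], x[-1]] instead of extrapolating, and reads 'first', 'last' and 'mean' as the left value, right value and average of the bracketing interval. Those readings, the tie-breaking rule for 'nearest', and the helper name reference_interp1d are all assumptions, not the PR's confirmed behavior.

```python
import numpy as np


def reference_interp1d(x, y, x_new, method="linear"):
    """NumPy reference for 1D interpolation strategies (hypothetical helper).

    Assumes `x` is sorted ascending with no duplicates. Queries outside
    [x[0], x[-1]] are clamped to the data range (no extrapolation); this is
    an assumption of the sketch, not necessarily what interpolate1d does.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_new = np.clip(np.asarray(x_new, dtype=float), x[0], x[-1])

    # Right-neighbour index with x[idx - 1] <= x_new < x[idx], clamped so
    # that both neighbours always exist.
    idx = np.clip(np.searchsorted(x, x_new, side="right"), 1, len(x) - 1)
    x_left, x_right = x[idx - 1], x[idx]
    y_left, y_right = y[idx - 1], y[idx]

    if method == "linear":
        weight = (x_new - x_left) / (x_right - x_left)
        return y_left + weight * (y_right - y_left)
    if method == "nearest":
        # Ties go to the left neighbour (assumption).
        return np.where(x_new - x_left <= x_right - x_new, y_left, y_right)
    if method == "first":
        # Assumed meaning: value at the left end of the bracketing interval.
        return y_left
    if method == "last":
        # Assumed meaning: value at the right end of the bracketing interval.
        return y_right
    if method == "mean":
        # Assumed meaning: average of the two bracketing values.
        return (y_left + y_right) / 2
    raise ValueError(f"Unknown method: {method!r}")
```

Under this reading every method except 'linear' is piecewise constant between data points, and only 'linear' can be checked directly against np.interp.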

[Image: plot of the five interpolants evaluated on inputs, with the original data points in black]

Everything is composed of PyTensor primitives, so it ought to compile to JAX/Numba without additional changes. There seem to be some bugs, though, so those backends don't work yet. Specifically, Numba raises the following error:

TypeError: pytensor.link.numba.dispatch.basic.numba_funcify() got multiple values for keyword argument 'parent_node'

It also emits a warning about a missing Supervisor, which may or may not be related.

In addition to the more feature-rich interpolate1d API, I also added pt.interp, which matches the NumPy np.interp signature one-to-one:

from pytensor.tensor.interpolate import interp

xp = [1., 2., 3.]
fp = [3., 2., 0.]
# interp returns a symbolic graph; .eval() compiles and evaluates it
interp(2.5, xp, fp).eval()
# 1.0
interp([0, 1, 1.5, 2.72, 3.14], xp, fp).eval()
# array([3.  , 3.  , 2.5 , 0.56, 0.  ])

UNDEF = -99.0
interp(3.14, xp, fp, right=UNDEF).eval()
# -99.0
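Assuming pt.interp follows np.interp's semantics as well as its signature, each output in the example comes from the piecewise-linear rule below, where xp is sorted ascending, n is its length, and left and right default to fp_0 and fp_{n-1}:

```latex
\operatorname{interp}(x) =
\begin{cases}
\mathrm{left}, & x < \mathrm{xp}_0,\\[4pt]
\mathrm{fp}_i + \dfrac{\mathrm{fp}_{i+1} - \mathrm{fp}_i}{\mathrm{xp}_{i+1} - \mathrm{xp}_i}\,(x - \mathrm{xp}_i), & \mathrm{xp}_i \le x \le \mathrm{xp}_{i+1},\\[4pt]
\mathrm{right}, & x > \mathrm{xp}_{n-1}.
\end{cases}

% Worked example with xp = (1, 2, 3), fp = (3, 2, 0):
\operatorname{interp}(2.72) = 2 + \frac{0 - 2}{3 - 2}\,(2.72 - 2) = 2 - 1.44 = 0.56
```

By the same rule, 3.14 > xp_2 = 3 falls in the right branch, giving 0 by default or -99.0 when right=UNDEF.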

I plan to also add a function for polynomial interpolation (fit via least squares). The holy grail here is B-splines, but how to implement them is somewhat beyond me at the moment. I think getting these simple methods into the library would already be useful; we can then open issues to add more interesting methods.
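The polynomial method is only planned, not part of this PR. As a sketch of what "fit via least squares" involves, the NumPy code below builds a Vandermonde matrix, solves the least-squares problem, and returns a callable, mirroring the way interpolate1d returns a function of the query points. The helper name fit_polynomial_interpolant and its degree parameter are hypothetical, and this is plain NumPy, not a PyTensor graph.

```python
import numpy as np
from numpy.polynomial import polynomial as P


def fit_polynomial_interpolant(x, y, degree):
    """Least-squares polynomial fit returning a callable (hypothetical helper).

    Solves min_c ||V c - y||_2, where V is the Vandermonde matrix whose
    columns are x**0, x**1, ..., x**degree.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) < degree + 1:
        raise ValueError("need at least degree + 1 points for a unique fit")

    vander = np.vander(x, degree + 1, increasing=True)
    coefs, *_ = np.linalg.lstsq(vander, y, rcond=None)

    def evaluate(x_new):
        # polyval expects coefficients in increasing order of power,
        # matching the increasing=True Vandermonde columns above.
        return P.polyval(np.asarray(x_new, dtype=float), coefs)

    return evaluate


# Usage on the same data as the interpolate1d example.
x = np.linspace(-2, 6, 10)
f_cubic = fit_polynomial_interpolant(x, np.sin(x), degree=3)
values = f_cubic(np.linspace(-2.5, 6.5, 5))
```

Unlike the interpolation methods above, a low-degree least-squares fit does not pass through the data points in general; it only does so when the number of points equals degree + 1.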


📚 Documentation preview 📚: https://pytensor--1141.org.readthedocs.build/en/1141/

ricardoV94 (Member) commented:

Can you add labels to the PR?


codecov bot commented Dec 28, 2024

Codecov Report

Attention: Patch coverage is 95.00000% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.14%. Comparing base (83c6b44) to head (d5daef1).
Report is 1 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/tensor/interpolate.py | 94.59% | 2 Missing and 2 partials ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1141      +/-   ##
==========================================
+ Coverage   82.12%   82.14%   +0.02%     
==========================================
  Files         185      186       +1     
  Lines       48130    48210      +80     
  Branches     8669     8678       +9     
==========================================
+ Hits        39527    39603      +76     
- Misses       6438     6440       +2     
- Partials     2165     2167       +2     
Files with missing lines | Coverage Δ
pytensor/link/jax/dispatch/extra_ops.py | 82.35% <100.00%> (+1.70%) ⬆️
pytensor/tensor/interpolate.py | 94.59% <94.59%> (ø)
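The headline percentages in the report can be reproduced from the coverage diff. Codecov counts coverage as hits divided by tracked lines (hits + misses + partials) and by default reports two decimals rounded down; the sketch assumes this project keeps that default, and codecov_percent is a hypothetical helper.

```python
def codecov_percent(hits: int, tracked: int) -> str:
    """Return hits / tracked as a percentage truncated to two decimals.

    Truncation (rather than rounding to nearest) assumes Codecov's default
    "round: down" setting; integer arithmetic avoids float rounding issues.
    """
    hundredths = hits * 10_000 // tracked  # hundredths of a percent
    return f"{hundredths // 100}.{hundredths % 100:02d}%"


# Counts copied from the coverage diff above.
base = {"hits": 39527, "misses": 6438, "partials": 2165}  # main
head = {"hits": 39603, "misses": 6440, "partials": 2167}  # #1141

for name, counts in (("main", base), ("#1141", head)):
    tracked = sum(counts.values())  # hits + misses + partials
    print(name, tracked, codecov_percent(counts["hits"], tracked))

# Tracked lines and hits added by the PR (head minus base).
new_tracked = sum(head.values()) - sum(base.values())
new_hits = head["hits"] - base["hits"]
print("patch", new_tracked, codecov_percent(new_hits, new_tracked))
```

This gives 82.12% for main, 82.14% for the PR head, and 95.00% for the 80 added tracked lines (76 hit), consistent with the reported patch coverage and its 4 lines (2 missing, 2 partial) lacking coverage.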

jessegrabowski marked this pull request as ready for review December 30, 2024 12:08
jessegrabowski merged commit 4e85676 into pymc-devs:main Dec 30, 2024
60 of 62 checks passed
jessegrabowski deleted the interp1d branch December 30, 2024 22:53
Development

Successfully merging this pull request may close these issues.

Implement 1d interpolation
2 participants