Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Kve Op and Kv helper #1081

Merged
merged 2 commits into from
Nov 13, 2024
Merged

Implement Kve Op and Kv helper #1081

merged 2 commits into from
Nov 13, 2024

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Nov 11, 2024

Needed for pymc-devs/pymc-extras#389

Related to #1038

CC @AuguB


📚 Documentation preview 📚: https://pytensor--1081.org.readthedocs.build/en/1081/

pytensor/scalar/math.py Outdated Show resolved Hide resolved
Copy link

codecov bot commented Nov 11, 2024

Codecov Report

Attention: Patch coverage is 88.88889% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.11%. Comparing base (fdbf3aa) to head (301187e).
Report is 8 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/scalar/math.py 88.23% 2 Missing ⚠️
pytensor/link/jax/dispatch/scalar.py 75.00% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1081   +/-   ##
=======================================
  Coverage   82.10%   82.11%           
=======================================
  Files         183      183           
  Lines       47930    47959   +29     
  Branches     8633     8635    +2     
=======================================
+ Hits        39354    39381   +27     
- Misses       6410     6411    +1     
- Partials     2166     2167    +1     
Files with missing lines Coverage Δ
pytensor/tensor/math.py 91.32% <100.00%> (+0.03%) ⬆️
pytensor/tensor/rewriting/math.py 89.85% <100.00%> (+0.01%) ⬆️
pytensor/link/jax/dispatch/scalar.py 94.64% <75.00%> (+0.66%) ⬆️
pytensor/scalar/math.py 87.10% <88.23%> (+0.02%) ⬆️

... and 2 files with indirect coverage changes

@dehorsley
Copy link
Contributor

dehorsley commented Nov 12, 2024

Is it worth exposing Kve as well? I think you can use it to prevent some underflow for large argument in your logp in pymc-devs/pymc-extras#389

@ricardoV94
Copy link
Member Author

ricardoV94 commented Nov 12, 2024

Is it worth exposing Kve as well? I think you can use it to prevent some underflow for large argument in your log in pymc-devs/pymc-extras#389

You think so? Do we even need kv then or should it just be a helper around kve? Reducing core Ops is always nice

@dehorsley
Copy link
Contributor

dehorsley commented Nov 12, 2024

You think so?

exp(x) rounds to 0 around x=-700, any values less than that will give you nans (edit: -inf of course) in your logp.

Do we even need kv then or should it just be a helper around kve? Reducing core Ops is always nice

Maybe not! Include a stabilising rewrite for Kv -> Kve * exp(-|x|)

@dehorsley
Copy link
Contributor

I notice TFP also has log_bessel_kve which should be more numerically stable again in this use case. That's probably a good candidate as a rewrite target.

@ricardoV94
Copy link
Member Author

Maybe not! Include a stabilising rewrite for Kv -> Kve * exp(-|x|)

Kv = Kve * exp(x) not -x. I'm not sure it's like the Iv/Ive combo.

There must be cases where kv is preferred otherwise why would it be offered? If it's never preferred there's no reason to have it offered.

Re log, I've also came across a paper that implements a stable form iteratively (see linked discussion), but seems like quite some complexity. I would wait until there's demand for it.

@ricardoV94
Copy link
Member Author

The scipy page shows an example of underflowing earlier, so you're probably right. https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.kve.html

Is kv more stable for smaller x, or does kve win across the whole range. If so I'll just ditch kv altogether

@dehorsley
Copy link
Contributor

Maybe not! Include a stabilising rewrite for Kv -> Kve * exp(-|x|)

Kv = Kve * exp(x) not -x. I'm not sure it's like the Iv/Ive combo.

Other way around is it not? Kv ~ exp(-x)/sqrt(z) for large z, so Kve = Kv * exp(x) is the "non exponential" part

There must be cases where kv is preferred otherwise why would it be offered? If it's never preferred there's no reason to have it offered.

I'm sure there must be but at least since we're mostly dealing with logps, I'd say we'll want to factor out the exponential term most of the time for stability and kve will probably get much more use.

Maybe a good argument for having Kve and not Kv is someone without experience with numerical analysis is more likely to write something stable if they only have Kve 😀

Re log, I've also came across a paper that implements a stable form iteratively (see linked discussion), but seems like quite some complexity. I would wait until there's demand for it.

For a log(kve) specialisation, we could use the TFP implementation (which claims to be more stable) in jax and fallback to the naive implementation in scipy/c. That said, for the range of values I tried with the TFP log(kve), I didn't see any difference between it and the naive implementation, so maybe not worth it.

@ricardoV94
Copy link
Member Author

ricardoV94 commented Nov 12, 2024

Other way around is it not? Kv ~ exp(-x)/sqrt(z) for large z, so Kve = Kv * exp(x) is the "non exponential" part

I didn't think it through with 1/exp(x) = exp(-x) :P

For a log(kve) specialisation, we could use the TFP implementation (which claims to be more stable) in jax and fallback to the naive implementation in scipy/c. That said, for the range of values I tried with the TFP log(kve), I didn't see any difference between it and the naive implementation, so maybe not worth it.

Those are different things. JAX/TFP cannot be used instead of scipy, they are different backends.

But yes we can offer kve, and kv as kve * pt.exp(-x). Gonna try to run a small benchmark to see if it's indeed more stable across the mark

@ricardoV94
Copy link
Member Author

ricardoV94 commented Nov 12, 2024

@dehorsley I switched so we have a core kve, and a helper kv based on kve. I also added a stabilization rewrite for log(kv). What do you think?

A specialization for log(kve) can be left for later if there's a demand for more stability. Good to know that there are options.

Should we do something similar for Iv/Ive (and remove the core Iv Op)? If so, I'll open an issue for this, it need not block this PR

@ricardoV94 ricardoV94 changed the title Implement Kv Op Implement Kve Op and Kv helper Nov 12, 2024
@ricardoV94
Copy link
Member Author

ricardoV94 commented Nov 12, 2024

Btw @dehorsley if you want to hang out in the devs discord server here is a link: https://discord.gg/vYkmxNuF

Say hi when / if you join so we can grant you access to the private channels

@dehorsley
Copy link
Contributor

@dehorsley I switched so we have a core kve, and a helper kv based on kve. I also added a stabilization rewrite for log(kv). What do you think?

Sounds good to me!

A specialization for log(kve) can be left for later if there's a demand for more stability. Good to know that there are options.

Sure. For reference, it looks like the TFP implementation actually calculates log(Kve)/log(Ive), then applies exp to get the Kve/Ive implementations. I'm not sure how good jax/XLA's optimisation pass is, but maybe it is able to simplify the log(exp) expression too. That might explain why I was seeing bit-for-bit identical outputs between log(tfp.math.bessel_kve(nu, x)) and tfp.math.log_bessel_kve(nu, x).

Incidentally, it seems like TFP's Kve implementation is more numerically stable than SciPy's: scipy.special.kve(1,x) gives NaNs around x=1e10, whereas tfp.math.bessel_kve(1,x) is happy for x all the way up to the float limit.

Should we do something similar for Iv/Ive (and remove the core Iv Op)? If so, I'll open an issue for this, it need not block this PR

Yeah I think so.

Copy link
Contributor

@dehorsley dehorsley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, just a couple of nit-picks 😁

pytensor/tensor/rewriting/math.py Outdated Show resolved Hide resolved
tests/link/jax/test_scalar.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 merged commit 33a4d48 into pymc-devs:main Nov 13, 2024
61 of 62 checks passed
@ricardoV94
Copy link
Member Author

Thanks @dehorsley, nits addressed

@ricardoV94 ricardoV94 added the enhancement New feature or request label Nov 15, 2024
@ricardoV94 ricardoV94 deleted the add-kv branch November 15, 2024 08:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants