Implement specialized MvNormal density based on precision matrix #7345
Conversation
Implementation checks may fail until pymc-devs/pytensor#799 is fixed.
Benchmark code:
import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt
rng = np.random.default_rng(123)
n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)
with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)
logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
pytensor.dprint(logp_fn)
dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
pytensor.dprint(dlogp_fn)
np.testing.assert_allclose(logp_fn(x_test), np.array(-1789.93662205))
np.testing.assert_allclose(np.sum(dlogp_fn(x_test) ** 2), np.array(18445204.8755109), rtol=1e-6)
# Before: 2.66 ms
# After: 1.31 ms
%timeit -n 1000 logp_fn(x_test)
# Before: 2.45 ms
# After: 72 µs
%timeit -n 1000 dlogp_fn(x_test)
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #7345      +/-   ##
==========================================
+ Coverage   92.18%   92.20%   +0.01%
==========================================
  Files         103      103
  Lines       17263    17301      +38
==========================================
+ Hits        15914    15952      +38
  Misses       1349     1349
Final question is just whether we want / can do a similar thing for the MvStudentT. Otherwise it's ready to merge on my end.
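For reference (my addition, not from the thread): the MvStudentT log-density has the same two expensive ingredients, a log-determinant and a quadratic form in the precision, so the same Cholesky-based specialization should carry over. With $\tau = \Sigma^{-1}$ and $\delta = x - \mu$:

\log p(x) = \log\Gamma\!\left(\tfrac{\nu+k}{2}\right) - \log\Gamma\!\left(\tfrac{\nu}{2}\right) - \tfrac{k}{2}\log(\nu\pi) + \tfrac{1}{2}\log\lvert\tau\rvert - \tfrac{\nu+k}{2}\log\!\left(1 + \tfrac{\delta^\top \tau\,\delta}{\nu}\right)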
pymc/distributions/multivariate.py (Outdated)
[value] = value
k = value.shape[-1]
delta = value - mean
det_sign, logdet_tau = pt.nlinalg.slogdet(tau)
I'm not sure the slogdet is necessarily a good idea here. Internally, at least in numpy, this does an LU factorization, which I think theoretically takes about twice as long as the Cholesky and should usually be less stable. (But the performance can differ a lot depending on the number of threads and the BLAS.) So for a matrix that is not constant this might be slower than the usual MvNormal right now.
So I think it is better to use the cholesky decomposition here and use that to get the logdet.
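To make that concrete, here is a minimal NumPy sketch (my illustration, not code from the PR): if tau = L @ L.T with L lower-triangular, then log|tau| = 2 * sum(log(diag(L))), so the log-determinant comes essentially for free from the same factorization used for the quadratic form.
# Sketch only: Cholesky-based log-determinant vs. slogdet on a small test matrix
import numpy as np
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
tau = A @ A.T + 5 * np.eye(5)                 # positive-definite test matrix
L_chol = np.linalg.cholesky(tau)              # tau = L_chol @ L_chol.T
logdet_chol = 2 * np.log(np.diag(L_chol)).sum()
sign, logdet_lu = np.linalg.slogdet(tau)      # LU-based reference
np.testing.assert_allclose(logdet_chol, logdet_lu)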
In the example above the matrix is not constant, so you can use it to benchmark
Here: #7345 (comment)
And it was 2x faster logp and 100x faster dlogp on my crappy PC. Are you skeptical of those numbers, or do you think we can do even better?
Are you suggesting something like this https://math.stackexchange.com/questions/2001041/logarithm-of-the-determinant-of-a-positive-definite-matrix ?
Here would be some code with a non-constant matrix, but I get a NotImplementedError for the grad of Blockwise?
import threadpoolctl
import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt
rng = np.random.default_rng(123)
n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)
v_test = rng.normal(size=n)
with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    v = pm.Normal("v", shape=n)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q + v[None, :] * v[:, None])
logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
#pytensor.dprint(logp_fn)
dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
#pytensor.dprint(dlogp_fn)
#np.testing.assert_allclose(logp_fn(x_test, v_test), np.array(-1789.93662205))
#np.testing.assert_allclose(np.sum(dlogp_fn(x_test, v_test) ** 2), np.array(18445204.8755109), rtol=1e-6)
with threadpoolctl.threadpool_limits(1):
    # Before: 2.66 ms
    # After: 1.31 ms
    %timeit logp_fn(x_test, v_test)
    # Before: 2.45 ms
    # After: 72 µs
    %timeit dlogp_fn(x_test, v_test)
Your expression moves the log above the diagonal, which makes sense, but I think the minus in -2 * is wrong?
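For context (my addition): with the Cholesky factorization $\tau = L L^\top$, the precision-parameterized density is

\log p(x) = -\tfrac{1}{2}\left[k\log(2\pi) - \log\lvert\tau\rvert + (x-\mu)^\top \tau (x-\mu)\right], \qquad \log\lvert\tau\rvert = 2\sum_i \log L_{ii},

so the diagonal term carries a factor of +2 before the outer -1/2 is applied.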
The performance of the original code looks, I think, pretty bad by the way:
#353 μs ± 4.53 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
#1.61 ms ± 32.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
If everything is implemented well I don't think there is a good reason why the gradient should be much slower than just the logp. Factoring the matrix once should be enough.
I guess this is because we compute the pullback of the cholesky or so, even though the pullback of the logdet is actually pretty easy if you already have a factorization...
There might be a nice usecase for an OpFromGraph and overwriting the forward values hiding here somewhere...
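For the record, a sketch of why the factorization should be enough (my addition): for symmetric positive-definite $Q = L L^\top$,

\frac{\partial}{\partial Q}\log\lvert Q\rvert = Q^{-1} = L^{-\top} L^{-1},

so once the forward pass has the Cholesky factor, the pullback of the log-determinant only needs two triangular solves rather than a second factorization.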
Seems to be 40 µs (1.2x faster) with the Cholesky now that I only take the log of the diagonal. Funnily enough it's fusing the log and the sum, which is the only elemwise + reduce fusion we have :)
logdet_tau = 2 * pt.log(pt.diagonal(pt.linalg.cholesky(tau), axis1=-2, axis2=-1)).sum()
%env OMP_NUM_THREADS=1
import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt
rng = np.random.default_rng(123)
n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)
with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)
logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
pytensor.dprint(logp_fn)
dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
pytensor.dprint(dlogp_fn)
np.testing.assert_allclose(logp_fn(x_test), np.array(-1789.93662205))
np.testing.assert_allclose(np.sum(dlogp_fn(x_test) ** 2), np.array(18445204.8755109), rtol=1e-6)
# With slogdet: 236 µs ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With cholesky logdet: 192 µs ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit -n 10000 logp_fn(x_test)
# With slogdet: 29.8 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With cholesky logdet: 32.5 µs ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit -n 10000 dlogp_fn(x_test)
Pushed a commit with the cholesky factorization
pymc/distributions/multivariate.py (Outdated)
logp = -0.5 * (k * pt.log(2 * pt.pi) - logdet + quadratic_form)
return check_parameters(
    logp,
    (cholesky_diagonal > 0).all(-1),
I assume this is the right check for posdef-ness? @aseyboldt
The Cholesky will simply fail (throw an exception) if the matrix is not posdef. I think we check for that in the perform method and return nan, but I don't know what, for instance, JAX will do.
Okay so we could check for nan? I'll see what JAX does
I guess all(nan) will also be false so this ends up being the same thing? Or does it evaluate to nan...?
I'll check
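A quick illustration of that point (my snippet, not about JAX specifically): comparisons against NaN evaluate to False, so a check like (cholesky_diagonal > 0).all(-1) already comes out False when the factorization produced NaNs instead of raising.
import numpy as np
diag_with_nan = np.array([1.0, np.nan, 2.0])
print(diag_with_nan > 0)          # [ True False  True]  (NaN > 0 is False)
print((diag_with_nan > 0).all())  # False, a plain bool rather than NaN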
Last benchmarks, running the following script:
%env OMP_NUM_THREADS=1
USE_TAU = True
import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt
rng = np.random.default_rng(123)
n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)
with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    if USE_TAU:
        x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)
    else:
        x = pm.MvNormal("x", mu=pt.zeros(n), cov=Q)
logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
pytensor.dprint(logp_fn)
dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
pytensor.dprint(dlogp_fn)
%timeit -n 10000 logp_fn(x_test)
%timeit -n 10000 dlogp_fn(x_test)
USE_TAU = True, without optimization:
USE_TAU = True, with optimization:
For reference: USE_TAU = False before and after (unchanged)
Before:
After:
Summary
@aseyboldt anything that should block this PR?
Co-authored-by: theorashid <[email protected]> Co-authored-by: elizavetasemenova <[email protected]> Co-authored-by: aseyboldt <[email protected]>
Looks good. I think it is possible that we could further improve the MvNormal in both parametrizations, but this is definitely an improvement as it is.
We don't recompute the Cholesky; we have rewrites to remove it and even a specific test for it: pymc/tests/distributions/test_multivariate.py, line 2368 (in b407c01).
Description
This PR explores a specialized logp for an MvNormal (and possibly MvStudentT) parametrized directly in terms of tau. A common model implementation looks like:
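Purely as a hedged sketch of the kind of model this targets (the names and the particular precision structure are illustrative, not taken from the PR): an MvNormal specified through its precision matrix rather than its covariance, e.g. a GMRF-style prior with a banded tau.
import numpy as np
import pymc as pm
import pytensor.tensor as pt
n = 50
# Illustrative tridiagonal precision matrix; diagonally dominant, hence positive definite
Q = 2.0 * np.eye(n) - 0.5 * np.eye(n, k=1) - 0.5 * np.eye(n, k=-1)
with pm.Model() as model:
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)
    y = pm.Normal("y", mu=x, sigma=1.0, observed=np.zeros(n))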
TODO (some are optional for this PR)
Sparse implementation? May need some ideas like: https://stackoverflow.com/questions/19107617/how-to-compute-scipy-sparse-matrix-determinant-without-turning-it-to-dense (investigate in a follow-up PR)
Related Issue
Checklist
Type of change
CC @theorashid @elizavetasemenova
📚 Documentation preview 📚: https://pymc--7345.org.readthedocs.build/en/7345/