
Implement MvNormal as cholesky(cov) @ normal #1115

Open
ricardoV94 opened this issue Dec 11, 2024 · 3 comments

Comments

@ricardoV94
Member

ricardoV94 commented Dec 11, 2024

Description

This is much faster, and the gain is even larger in PyMC models, which are usually parametrized with a direct prior on the Cholesky factor.

import pytensor
import pytensor.tensor as pt

srng = pt.random.RandomStream()

x = srng.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]])
fn = pytensor.function([], x)
%timeit fn()  # 510 µs ± 81.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# Decompose with Cholesky in the graph (NumPy's multivariate_normal also factorizes cov under the hood, using SVD by default)
A = pt.linalg.cholesky([[1, 0.5], [0.5, 1]])
x = A @ srng.normal(size=(2,))
fn = pytensor.function([], x)
%timeit fn()  # 27.4 µs ± 3.27 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
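A NumPy-only sketch of the same sampling idea with a nonzero mean, for readers who want to check the math outside PyTensor. The helper name mvnormal_via_cholesky is hypothetical, not a PyTensor or NumPy API. It draws z ~ N(0, I) and returns mean + L @ z, where L is the lower Cholesky factor of cov, so the draws have covariance L @ L.T = cov.

```python
import numpy as np


def mvnormal_via_cholesky(rng, mean, cov, size):
    """Draw `size` samples of N(mean, cov) as mean + L @ z (hypothetical helper)."""
    mean = np.asarray(mean, dtype=float)
    L = np.linalg.cholesky(cov)  # lower-triangular, L @ L.T == cov
    z = rng.standard_normal((size, mean.shape[-1]))  # z ~ N(0, I)
    # Batched form of L @ z: row i of (z @ L.T) equals L @ z[i]
    return mean + z @ L.T


rng = np.random.default_rng(0)
mean = np.array([0.0, 0.0])
cov = np.array([[1.0, 0.5], [0.5, 1.0]])
draws = mvnormal_via_cholesky(rng, mean, cov, size=100_000)
sample_cov = np.cov(draws, rowvar=False)  # approaches cov as size grows
```

Only the Cholesky factorization and one matrix product are needed per batch, which is the same structure as the A @ srng.normal(...) graph above.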

In general, we should probably reduce the number of pure RV Ops we have. This would allow more optimizations and make it easier to implement different backends.

We should implement MvNormal as an OpFromGraph that gets inlined after canonicalization (not as early as the ones with inline=True).

@jessegrabowski
Member

One downside of using Cholesky is that it doesn't allow singular covariance matrices. Any square-root-type decomposition works for generating fast forward samples, for example SVD:

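cov = pt.as_tensor_variable([[1, 0.5], [0.5, 1]])
# cov is symmetric PSD, so cov = U @ diag(S) @ U.T and hence L @ L.T == cov
U, S, V_T = pt.linalg.svd(cov, compute_uv=True)
L = U @ pt.diag(pt.sqrt(S))
x = L @ srng.normal(size=(2,))
fn = pytensor.function([], x)
%timeit fn()  # 8.17 μs ± 76 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)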
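A NumPy sketch of the singular case the SVD approach handles, assuming a rank-1 covariance chosen purely for illustration: Cholesky raises on it, while the SVD square root still satisfies L @ L.T == cov and produces valid (perfectly correlated) draws.

```python
import numpy as np

# Singular (rank-1) covariance: the two components are perfectly correlated.
cov = np.array([[1.0, 1.0], [1.0, 1.0]])

# Cholesky requires a positive-definite matrix, so it fails here.
try:
    np.linalg.cholesky(cov)
    cholesky_failed = False
except np.linalg.LinAlgError:
    cholesky_failed = True

# SVD-based square root: for symmetric PSD cov, cov = U @ diag(S) @ U.T,
# so L = U @ diag(sqrt(S)) satisfies L @ L.T == cov even when S has zeros.
U, S, Vt = np.linalg.svd(cov)
L = U @ np.diag(np.sqrt(S))

rng = np.random.default_rng(0)
draws = rng.standard_normal((10_000, 2)) @ L.T  # each row has equal components
```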

Cholesky is the best choice in general because it gives access to solve_triangular when computing the log-likelihood. I wish we exposed an allow_singular option like scipy/JAX, letting the user accept a performance hit to handle this case.
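A sketch of why the Cholesky factor pays off in the log-likelihood: with the lower factor L, the Mahalanobis term needs a single triangular solve and the log-determinant is read off diag(L). The helper name mvnormal_logp_chol is hypothetical; this uses SciPy's solve_triangular on NumPy arrays rather than the PyTensor Op.

```python
import numpy as np
from scipy.linalg import solve_triangular


def mvnormal_logp_chol(x, mean, L):
    """Log-density of N(mean, L @ L.T) at x, given the lower Cholesky factor L."""
    k = mean.shape[-1]
    # Whiten the residual: r = L^{-1} (x - mean), so r @ r is the Mahalanobis term
    r = solve_triangular(L, x - mean, lower=True)
    # For cov = L @ L.T, log|cov| = 2 * sum(log(diag(L)))
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (k * np.log(2.0 * np.pi) + log_det + r @ r)


cov = np.array([[1.0, 0.5], [0.5, 1.0]])
mean = np.array([0.1, -0.2])
x = np.array([0.3, 0.4])
logp = mvnormal_logp_chol(x, mean, np.linalg.cholesky(cov))
```

No explicit inverse of cov is formed, and the same L can be shared with the sampler.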

@ricardoV94
Member Author

No reason why we can't have both.

@jessegrabowski
Member

JAX has a method argument that accepts svd, eigh, or cholesky (default cholesky). I think we should go this route.
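A hypothetical NumPy sketch of a JAX-style method switch: cov_sqrt returns a square-root factor L with L @ L.T == cov using "cholesky", "eigh", or "svd". The function name and signature are illustrative only, not an existing PyTensor API.

```python
import numpy as np


def cov_sqrt(cov, method="cholesky"):
    """Return L with L @ L.T == cov (hypothetical helper mirroring JAX's method names)."""
    cov = np.asarray(cov, dtype=float)
    if method == "cholesky":
        # Lower-triangular and cheapest; requires cov to be positive definite.
        return np.linalg.cholesky(cov)
    if method == "eigh":
        # Symmetric eigendecomposition; clip tiny negative eigenvalues from round-off
        # so positive semi-definite (singular) cov is accepted.
        w, V = np.linalg.eigh(cov)
        return V * np.sqrt(np.clip(w, 0.0, None))  # == V @ diag(sqrt(w))
    if method == "svd":
        # For symmetric PSD cov, U @ diag(sqrt(S)) is a valid square root.
        U, S, _ = np.linalg.svd(cov)
        return U * np.sqrt(S)  # == U @ diag(sqrt(S))
    raise ValueError(f"unknown method: {method!r}")


rng = np.random.default_rng(0)
cov = np.array([[1.0, 0.5], [0.5, 1.0]])
draws = rng.standard_normal((1_000, 2)) @ cov_sqrt(cov, method="eigh").T
```

Cholesky stays the default, matching the thread's point that it also enables solve_triangular in the log-likelihood; eigh and svd trade speed for tolerance of singular covariances.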
