PyMC/PyTensor Implementation of Pathfinder VI #387

Open · wants to merge 22 commits into main

Changes shown from 11 commits

Commits (22)
4540b84
renamed samples argument name and pathfinder variables to avoid confu…
aphc14 Oct 3, 2024
0c880d2
Minor changes made to the `fit_pathfinder` function and added test
aphc14 Oct 19, 2024
8835cd5
extract additional pathfinder objects from high level API for debugging
aphc14 Sep 17, 2024
663a60a
changed pathfinder samples argument to num_draws
aphc14 Oct 26, 2024
05aeeaf
Merge branch 'replicate_pathfinder_w_pytensor' into scipy_lbfgs
aphc14 Oct 26, 2024
0db91fe
feat(pathfinder): add PyMC-based Pathfinder VI implementation
aphc14 Oct 31, 2024
cb4436c
Multipath Pathfinder VI implementation in pymc-experimental
aphc14 Nov 4, 2024
2efb511
Added type hints and epsilon parameter to fit_pathfinder
aphc14 Nov 7, 2024
fdc3f38
Removed initial point values (l=0) to reduce iterations. Simplified …
aphc14 Nov 7, 2024
1fd7a11
Added placeholder/reminder to remove jax dependency when converting t…
aphc14 Nov 7, 2024
ef2956f
Sync updates with draft PR #386. Added pytensor.function for bfgs…
aphc14 Nov 7, 2024
8b134b7
Reduced size of compute graph with pathfinder_body_fn
aphc14 Nov 11, 2024
6484b3d
- Added TODO comments for implementing Taylor approximation methods: …
aphc14 Nov 14, 2024
aa765fb
fix: correct posterior approximations in Pathfinder VI
aphc14 Nov 21, 2024
4299a58
feat: Add dense BFGS sampling for Pathfinder VI
aphc14 Nov 21, 2024
f1a54c6
feat: improve Pathfinder performance and compatibility
aphc14 Nov 24, 2024
ea802fc
minor: improve error handling in Pathfinder VI
aphc14 Nov 25, 2024
a77f2c8
Progress bar and other minor changes
aphc14 Nov 27, 2024
9faaa72
set maxcor to max(5, floor(N / 1.9)). max=1 will cause error
aphc14 Nov 27, 2024
2815c4f
Merge branch 'main' into pathfinder_w_pytensor_symbolic
aphc14 Dec 7, 2024
e4b8996
Refactor Pathfinder VI: Default to PSIS, Add Concurrency, and Improve…
aphc14 Dec 7, 2024
885afaa
Improvements to Importance Sampling and InferenceData shape
aphc14 Dec 8, 2024
2 changes: 2 additions & 0 deletions pymc_experimental/inference/fit.py
@@ -31,11 +31,13 @@ def fit(method, **kwargs):
arviz.InferenceData
"""
if method == "pathfinder":
# TODO: Remove this once we have a pure PyMC implementation
Member: This PR will provide that, no?

Author: the latest commit addresses this

if find_spec("blackjax") is None:
raise RuntimeError("Need BlackJAX to use `pathfinder`")

from pymc_experimental.inference.pathfinder import fit_pathfinder

# TODO: edit **kwargs to be more consistent with fit_pathfinder with blackjax and pymc backends.
return fit_pathfinder(**kwargs)

if method == "laplace":
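To show how this dispatch is exercised from user code, here is a minimal sketch; the toy model and the `num_draws`/`random_seed` arguments are assumptions based on the commit messages above, not part of this diff:

```python
import pymc as pm

from pymc_experimental.inference.fit import fit

# Toy model; fit_pathfinder is assumed to pick the model up from context.
with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # Dispatches to fit_pathfinder(**kwargs); at this commit the BlackJAX
    # import check above still applies.
    idata = fit(method="pathfinder", num_draws=1000, random_seed=0)
```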
134 changes: 0 additions & 134 deletions pymc_experimental/inference/pathfinder.py

This file was deleted.

3 changes: 3 additions & 0 deletions pymc_experimental/inference/pathfinder/__init__.py
@@ -0,0 +1,3 @@
from pymc_experimental.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit_pathfinder"]
73 changes: 73 additions & 0 deletions pymc_experimental/inference/pathfinder/importance_sampling.py
@@ -0,0 +1,73 @@
import logging

import arviz as az
import numpy as np

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


def psir(
samples: np.ndarray,
logP: np.ndarray,
logQ: np.ndarray,
num_draws: int = 1000,
random_seed: int | None = None,
) -> np.ndarray:
"""Pareto Smoothed Importance Resampling (PSIR)
This implements the Pareto Smoothed Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). PSIR follows an approach similar to the Algorithm 1 PSIS diagnostic of Yao et al. (2018). However, before the importance ratio r_s is computed, logP and logQ are adjusted to account for the multiple estimators (or paths). The procedure then resamples from the original samples with replacement, with probabilities proportional to the importance weights computed by PSIS.

Parameters
----------
samples : np.ndarray
samples from proposal distribution
logP : np.ndarray
log probability of target distribution
logQ : np.ndarray
log probability of proposal distribution
num_draws : int
number of draws to return where num_draws <= samples.shape[0]
random_seed : int | None
seed for the random number generator used in resampling

Returns
-------
np.ndarray
importance sampled draws

Future work!
----------
- Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
- Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
- Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.

References
----------
Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668

Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538

Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

def logsumexp(x):
c = x.max()
return c + np.log(np.sum(np.exp(x - c)))

logiw = np.reshape(logP - logQ, -1, order="F")
psislw, pareto_k = az.psislw(logiw)

# FIXME: pareto_k is mostly bad, find out why!
if pareto_k <= 0.70:
pass
elif 0.70 < pareto_k <= 1:
logger.warning("pareto_k is bad: %f", pareto_k)
logger.info("consider increasing ftol, gtol or maxcor parameters")
else:
logger.warning("pareto_k is very bad: %f", pareto_k)
logger.info(
"consider reparametrising the model, increasing ftol, gtol or maxcor parameters"
)

p = np.exp(psislw - logsumexp(psislw))
rng = np.random.default_rng(random_seed)
return rng.choice(samples, size=num_draws, p=p, shuffle=False, axis=0)
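As a usage sketch, psir can be exercised on synthetic inputs; the shapes below are illustrative (in the real algorithm, logP and logQ come from the target and each path's normal approximation, and the row order of samples must match the order="F" flattening of logP - logQ):

```python
import numpy as np

rng = np.random.default_rng(0)

num_paths, draws_per_path, dim = 4, 500, 2
samples = rng.normal(size=(num_paths * draws_per_path, dim))  # stacked proposal draws
logP = rng.normal(size=(num_paths, draws_per_path))  # target log-densities (synthetic)
logQ = rng.normal(size=(num_paths, draws_per_path))  # proposal log-densities (synthetic)

draws = psir(samples, logP=logP, logQ=logQ, num_draws=1000, random_seed=0)
print(draws.shape)  # (1000, 2)
```

With random inputs the Pareto-k warning is meaningless; it only carries information when the log-densities come from a fitted approximation.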
92 changes: 92 additions & 0 deletions pymc_experimental/inference/pathfinder/lbfgs.py
@@ -0,0 +1,92 @@
from collections.abc import Callable
from typing import NamedTuple

import numpy as np

from scipy.optimize import minimize


class LBFGSHistory(NamedTuple):
x: np.ndarray
f: np.ndarray
g: np.ndarray


class LBFGSHistoryManager:
Member: Cleaner to use a data class? Don't know.

Author: yep, I agree. dataclass now added

def __init__(self, fn: Callable, grad_fn: Callable, x0: np.ndarray, maxiter: int):
dim = x0.shape[0]
maxiter_add_one = maxiter + 1
# Pre-allocate arrays to save memory and improve speed
self.x_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
self.f_history = np.empty(maxiter_add_one, dtype=np.float64)
self.g_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
self.count = 0
self.fn = fn
self.grad_fn = grad_fn
self.add_entry(x0, fn(x0), grad_fn(x0))

def add_entry(self, x, f, g=None):
self.x_history[self.count] = x
self.f_history[self.count] = f
if self.g_history is not None and g is not None:
self.g_history[self.count] = g
self.count += 1

def get_history(self):
# Return trimmed arrays up to the number of entries actually used
x = self.x_history[: self.count]
f = self.f_history[: self.count]
g = self.g_history[: self.count] if self.g_history is not None else None
return LBFGSHistory(
x=x,
f=f,
g=g,
)

def __call__(self, x):
self.add_entry(x, self.fn(x), self.grad_fn(x))
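
The thread above asks whether a dataclass would be cleaner, and the author says one was added in a later commit that is not part of this diff. A hypothetical sketch of that refactor, mirroring the behaviour of the class shown here, might look like:

```python
from collections.abc import Callable
from dataclasses import dataclass, field

import numpy as np


@dataclass
class LBFGSHistoryManager:
    """Dataclass variant (illustrative, not the committed code)."""

    fn: Callable
    grad_fn: Callable
    x0: np.ndarray
    maxiter: int
    count: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        dim = self.x0.shape[0]
        # Pre-allocate maxiter + 1 slots, as in the version above.
        self.x_history = np.empty((self.maxiter + 1, dim), dtype=np.float64)
        self.f_history = np.empty(self.maxiter + 1, dtype=np.float64)
        self.g_history = np.empty((self.maxiter + 1, dim), dtype=np.float64)
        self.add_entry(self.x0, self.fn(self.x0), self.grad_fn(self.x0))

    def add_entry(self, x, f, g=None):
        self.x_history[self.count] = x
        self.f_history[self.count] = f
        if g is not None:
            self.g_history[self.count] = g
        self.count += 1

    def __call__(self, x):
        self.add_entry(x, self.fn(x), self.grad_fn(x))
```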


def lbfgs(
fn,
grad_fn,
x0: np.ndarray,
maxcor: int | None = None,
maxiter=1000,
ftol=1e-5,
gtol=1e-8,
maxls=1000,
**lbfgs_kwargs,
) -> LBFGSHistory:
def callback(xk):
lbfgs_history_manager(xk)

lbfgs_history_manager = LBFGSHistoryManager(
fn=fn,
grad_fn=grad_fn,
x0=x0,
maxiter=maxiter,
)

default_lbfgs_options = dict(
maxcor=maxcor,
maxiter=maxiter,
ftol=ftol,
gtol=gtol,
maxls=maxls,
)
options = lbfgs_kwargs.pop("options", {})
options = default_lbfgs_options | options

# TODO: return the status of the lbfgs optimisation to handle the case where the optimisation fails. More details in the _single_pathfinder function.

minimize(
fn,
x0,
method="L-BFGS-B",
jac=grad_fn,
options=options,
callback=callback,
**lbfgs_kwargs,
)
return lbfgs_history_manager.get_history()
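
A quick sanity-check sketch of lbfgs on a toy quadratic (function and starting point are illustrative):

```python
import numpy as np


def fn(x):
    # f(x) = 0.5 * ||x||^2, minimised at the origin
    return 0.5 * float(np.dot(x, x))


def grad_fn(x):
    return x


history = lbfgs(fn, grad_fn, x0=np.array([3.0, -2.0]), maxcor=5, maxiter=100)
# history.x, history.f, history.g hold every iterate the optimiser visited;
# Pathfinder builds its quasi-Newton approximations from this trajectory.
print(history.x.shape, history.f[-1])
```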