PyMC/PyTensor Implementation of Pathfinder VI #387
Changes from 12 commits
@@ -0,0 +1,3 @@
from pymc_experimental.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit_pathfinder"]
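For context, a minimal usage sketch of the re-exported entry point. fit_pathfinder's signature is not part of this diff, so the model= and random_seed= keywords below are assumptions:

# Hypothetical usage sketch; keyword names are assumed, not taken from this diff.
import pymc as pm
from pymc_experimental.inference.pathfinder import fit_pathfinder

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

idata = fit_pathfinder(model=model, random_seed=0)  # expected: an az.InferenceData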
@@ -0,0 +1,89 @@
import logging

import arviz as az
import numpy as np
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
from pytensor.tensor.variable import TensorVariable

logger = logging.getLogger(__name__)


class PSIS(Op):
    """Wrap arviz.psislw (Pareto smoothed importance sampling log weights) as a PyTensor Op."""

    __props__ = ()

    def make_node(self, inputs):
        logweights = pt.as_tensor(inputs)
        psislw = pt.dvector()
        pareto_k = pt.dscalar()
        return Apply(self, [logweights], [psislw, pareto_k])

    def perform(self, node: Apply, inputs, outputs) -> None:
        logweights = inputs[0]
        psislw, pareto_k = az.psislw(logweights)
        outputs[0][0] = psislw
        outputs[1][0] = pareto_k
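For illustration, a quick sanity check of the PSIS Op with synthetic log weights (the Gumbel draws are made up, not part of the PR):

# Sketch: evaluate the PSIS Op on stand-in log importance weights.
import numpy as np

rng = np.random.default_rng(0)
logw = rng.gumbel(size=4000)        # illustrative stand-in log weights
psislw, pareto_k = PSIS()(logw)     # symbolic outputs
print(psislw.eval().shape, float(pareto_k.eval()))  # (4000,) and the khat estimate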
def psir(
    samples: TensorVariable,
    # logP: TensorVariable,
    # logQ: TensorVariable,
    logiw: TensorVariable,
    num_draws: int = 1000,
    random_seed: int | None = None,
) -> np.ndarray:
    """Pareto Smoothed Importance Resampling (PSIR).

    This implements the Pareto Smoothed Importance Resampling (PSIR) method, as described in
    Algorithm 5 of Zhang et al. (2022). PSIR follows a similar approach to the Algorithm 1 PSIS
    diagnostic from Yao et al. (2018). However, before computing the importance ratio r_s, logP
    and logQ are adjusted to account for the multiple estimators (or paths). The process involves
    resampling from the original sample with replacement, with probabilities proportional to the
    importance weights computed by PSIS.

    Parameters
    ----------
    samples : np.ndarray
        samples from the proposal distribution
    logiw : TensorVariable
        log importance weights, e.g. logP - logQ, where logP is the log probability of the
        target distribution and logQ is the log probability of the proposal distribution
    num_draws : int
        number of draws to return, where num_draws <= samples.shape[0]
    random_seed : int | None
        seed for the resampling RNG

    Returns
    -------
    np.ndarray
        importance-resampled draws

    Future work!
    ------------
    - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
    - Implement Algorithm 2, the VSBC marginal diagnostic, from Yao et al. (2018)
    - Incorporate these diagnostics, sampling approaches and weighting functions into VI algorithms.

    References
    ----------
    Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance
    Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668

    Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating
    Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538

    Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton
    variational inference. Journal of Machine Learning Research, 23(306), 1-49.
    """
    # logiw = np.reshape(logP - logQ, (-1,), order="F")
    # logiw = (logP - logQ).ravel()
    psislw, pareto_k = PSIS()(logiw)
    pareto_k = pareto_k.eval()
    # FIXME: pareto_k is mostly bad, find out why!
    if pareto_k <= 0.70:
        pass
    elif 0.70 < pareto_k <= 1:
        logger.warning("pareto_k is bad: %f", pareto_k)
        logger.info("consider increasing ftol, gtol or maxcor parameters")
    else:
        logger.warning("pareto_k is very bad: %f", pareto_k)
        logger.info(
            "consider reparametrising the model or increasing ftol, gtol or maxcor parameters"
        )

    # normalise the smoothed log weights into resampling probabilities
    p = pt.exp(psislw - pt.logsumexp(psislw)).eval()
    rng = np.random.default_rng(random_seed)
    return rng.choice(samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0)
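An end-to-end sketch of psir with synthetic proposal draws and log weights (all inputs are made up for illustration; in Pathfinder they come from the normal approximations along the L-BFGS trajectory):

# Sketch: importance-resample synthetic proposal draws with psir.
import numpy as np

rng = np.random.default_rng(1)
samples = rng.normal(size=(4000, 2))       # stand-in proposal draws
logiw = -0.5 * (samples**2).sum(axis=1)    # stand-in logP - logQ
draws = psir(samples, logiw=logiw, num_draws=1000, random_seed=1)
print(draws.shape)  # (1000, 2)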
@@ -0,0 +1,142 @@
from collections.abc import Callable
from typing import NamedTuple

import numpy as np
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
from scipy.optimize import minimize


class LBFGSHistory(NamedTuple):
    x: np.ndarray
    g: np.ndarray
Review comment (on LBFGSHistoryManager): Cleaner to use a data class? Don't know.
Reply: yep, I agree. dataclass now added


class LBFGSHistoryManager:
    def __init__(self, grad_fn: Callable, x0: np.ndarray, maxiter: int):
        dim = x0.shape[0]
        maxiter_add_one = maxiter + 1
        # Pre-allocate arrays to save memory and improve speed
        self.x_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
        self.g_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
        self.count = 0
        self.grad_fn = grad_fn
        self.add_entry(x0, grad_fn(x0))

    def add_entry(self, x, g):
        self.x_history[self.count] = x
        self.g_history[self.count] = g
        self.count += 1

    def get_history(self):
        # Return trimmed arrays up to L << L^max
        x = self.x_history[: self.count]
        g = self.g_history[: self.count]
        return LBFGSHistory(
            x=x,
            g=g,
        )

    def __call__(self, x):
        # scipy.optimize callback: record the iterate only if its gradient is finite
        grad = self.grad_fn(x)
        if np.all(np.isfinite(grad)):
            self.add_entry(x, grad)
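The review thread above says a dataclass version was added in a later commit; that commit is not in this 12-commit snapshot, so the following is only a sketch of what it might look like (the name LBFGSHistoryManagerDC is hypothetical):

# Hypothetical dataclass variant of LBFGSHistoryManager, sketched from the
# review exchange above; not part of this diff.
from dataclasses import dataclass, field


@dataclass
class LBFGSHistoryManagerDC:
    grad_fn: Callable
    x0: np.ndarray
    maxiter: int
    x_history: np.ndarray = field(init=False)
    g_history: np.ndarray = field(init=False)
    count: int = field(init=False, default=0)

    def __post_init__(self):
        dim = self.x0.shape[0]
        self.x_history = np.empty((self.maxiter + 1, dim), dtype=np.float64)
        self.g_history = np.empty((self.maxiter + 1, dim), dtype=np.float64)
        self.add_entry(self.x0, self.grad_fn(self.x0))

    def add_entry(self, x, g):
        self.x_history[self.count] = x
        self.g_history[self.count] = g
        self.count += 1

    def __call__(self, x):
        grad = self.grad_fn(x)
        if np.all(np.isfinite(grad)):
            self.add_entry(x, grad)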
def lbfgs(
    fn,
    grad_fn,
    x0: np.ndarray,
    maxcor: int | None = None,
    maxiter=1000,
    ftol=1e-5,
    gtol=1e-8,
    maxls=1000,
    **lbfgs_kwargs,
) -> tuple[np.ndarray, np.ndarray]:
    lbfgs_history_manager = LBFGSHistoryManager(
        grad_fn=grad_fn,
        x0=x0,
        maxiter=maxiter,
    )

    default_lbfgs_options = dict(
        maxcor=maxcor,
        maxiter=maxiter,
        ftol=ftol,
        gtol=gtol,
        maxls=maxls,
    )
    # user-supplied options take precedence over the defaults
    options = lbfgs_kwargs.pop("options", {})
    options = default_lbfgs_options | options

    # TODO: return the status of the lbfgs optimisation to handle the case where
    # the optimisation fails. More details in the _single_pathfinder function.

    minimize(
        fn,
        x0,
        method="L-BFGS-B",
        jac=grad_fn,
        options=options,
        callback=lbfgs_history_manager,
        **lbfgs_kwargs,
    )
    lbfgs_history = lbfgs_history_manager.get_history()
    return lbfgs_history.x, lbfgs_history.g
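A minimal sketch exercising lbfgs on a quadratic (the objective and starting point are made up for illustration):

# Sketch: run lbfgs on f(x) = 0.5 * x.x and inspect the recorded trajectory.
fn = lambda x: 0.5 * np.dot(x, x)
grad_fn = lambda x: x

x_hist, g_hist = lbfgs(fn, grad_fn, x0=np.array([3.0, -2.0]), maxcor=5)
print(x_hist.shape, g_hist.shape)  # (L, 2) each, with L = recorded iterates + 1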
Review comment (on LBFGSOp): Add type hints throughout (and docstrings, ideally).


class LBFGSOp(Op):
    def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
        self.fn = fn
        self.grad_fn = grad_fn
        self.maxcor = maxcor
        self.maxiter = maxiter
        self.ftol = ftol
        self.gtol = gtol
        self.maxls = maxls

    def make_node(self, x0):
        x0 = pt.as_tensor_variable(x0)
        x_history = pt.dmatrix()
        g_history = pt.dmatrix()
        return Apply(self, [x0], [x_history, g_history])

    def perform(self, node, inputs, outputs):
        x0 = inputs[0]
        x0 = np.array(x0, dtype=np.float64)

        history_manager = LBFGSHistoryManager(grad_fn=self.grad_fn, x0=x0, maxiter=self.maxiter)

        minimize(
            self.fn,
            x0,
            method="L-BFGS-B",
            jac=self.grad_fn,
            callback=history_manager,
            options={
                "maxcor": self.maxcor,
                "maxiter": self.maxiter,
                "ftol": self.ftol,
                "gtol": self.gtol,
                "maxls": self.maxls,
            },
        )

        # fmin_l_bfgs_b(
        #     func=self.fn,
        #     fprime=self.grad_fn,
        #     x0=x0,
        #     pgtol=self.gtol,
        #     factr=self.ftol / np.finfo(float).eps,
        #     maxls=self.maxls,
        #     maxiter=self.maxiter,
        #     m=self.maxcor,
        #     callback=history_manager,
        # )

        outputs[0][0] = history_manager.get_history().x
        outputs[1][0] = history_manager.get_history().g
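A hedged sketch of using LBFGSOp symbolically, with the same made-up quadratic objective as above:

# Sketch: wrap the optimisation in a PyTensor graph via LBFGSOp.
op = LBFGSOp(fn=lambda x: 0.5 * np.dot(x, x), grad_fn=lambda x: x, maxcor=5)
x_hist, g_hist = op(np.array([3.0, -2.0]))  # symbolic dmatrix outputs
print(x_hist.eval().shape)                  # (L, 2) trajectory of iterates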
Review comment: This PR will provide that, no?
Reply: the latest commit addresses this