Showing 4 changed files with 1,232 additions and 136 deletions.
pymc_experimental/inference/pathfinder/importance_sampling.py
128 additions, 0 deletions
import logging
import warnings

from typing import Literal

import arviz as az
import numpy as np
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
from pytensor.tensor.variable import TensorVariable

logger = logging.getLogger(__name__)
class PSIS(Op):
    """PyTensor Op wrapping ArviZ's Pareto Smoothed Importance Sampling (az.psislw)."""

    __props__ = ()

    def make_node(self, inputs):
        logweights = pt.as_tensor(inputs)
        psislw = pt.dvector()
        pareto_k = pt.dscalar()
        return Apply(self, [logweights], [psislw, pareto_k])

    def perform(self, node: Apply, inputs, outputs) -> None:
        with warnings.catch_warnings():
            # az.psislw can overflow in exp for extreme log-weights; silence that warning.
            warnings.filterwarnings(
                "ignore", category=RuntimeWarning, message="overflow encountered in exp"
            )
            logweights = inputs[0]
            psislw, pareto_k = az.psislw(logweights)
            outputs[0][0] = psislw
            outputs[1][0] = pareto_k
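A minimal usage sketch of the Op above, assuming only that arviz and pytensor are installed; the standard-normal log-weights are hypothetical stand-ins for real log importance weights:

rng = np.random.default_rng(0)
toy_logiw = rng.normal(size=1_000)  # hypothetical log importance weights
smoothed, k = PSIS()(toy_logiw)     # symbolic outputs of the Op
print(smoothed.eval().shape)        # (1000,)
print(float(k.eval()))              # Pareto shape diagnostic; < 0.7 is usually acceptable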
def importance_sampling(
    samples: TensorVariable,
    logiw: TensorVariable,
    num_draws: int,
    method: Literal["psis", "psir", "identity", "none"],
    random_seed: int | None = None,
) -> np.ndarray:
    """Pareto Smoothed Importance Resampling (PSIR).

    This implements the Pareto Smoothed Importance Resampling (PSIR) method, as described in
    Algorithm 5 of Zhang et al. (2022). PSIR follows an approach similar to the Algorithm 1
    PSIS diagnostic from Yao et al. (2018); however, before the importance ratio r_s is
    computed, logP and logQ are adjusted to account for the use of multiple estimators (or
    paths). The process involves resampling from the original sample with replacement, with
    probabilities proportional to the importance weights computed by PSIS.

    Parameters
    ----------
    samples : np.ndarray
        samples from the proposal distribution
    logiw : TensorVariable
        log importance weights (log probability of the target minus log probability of the
        proposal)
    num_draws : int
        number of draws to return, where num_draws <= samples.shape[0]
    method : str, optional
        importance sampling method to use. Options are "psis" (default), "psir", "identity"
        and "none". Pareto Smoothed Importance Sampling ("psis") is recommended in many cases
        for more stable results than Pareto Smoothed Importance Resampling ("psir").
        "identity" applies the log importance weights directly without smoothing or
        resampling. "none" applies no importance sampling weights and returns the samples as
        is, of size num_draws_per_path * num_paths.
    random_seed : int | None
        seed for the random number generator used when resampling

    Returns
    -------
    np.ndarray
        importance sampled draws

    Future work!
    ------------
    - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
    - Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
    - Incorporate these various diagnostics, sampling approaches and weighting functions into
      VI algorithms.

    References
    ----------
    Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple
    Importance Sampling. Statistical Science, 34(1), 129-155.
    https://doi.org/10.1214/18-STS668

    Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating
    Variational Inference. arXiv:1802.02538 [stat]. http://arxiv.org/abs/1802.02538

    Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel
    quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
    """
    if method == "psis":
        replace = False
        logiw, pareto_k = PSIS()(logiw)
    elif method == "psir":
        replace = True
        logiw, pareto_k = PSIS()(logiw)
    elif method == "identity":
        replace = False
        pareto_k = None
    elif method == "none":
        logger.warning(
            "Importance sampling is disabled. The samples are returned as is, which may "
            "include samples from failed paths with non-finite logP or logQ values. It is "
            "recommended to use importance_sampling='psis' for better stability."
        )
        return samples
    else:
        raise ValueError(f"Invalid importance sampling method: {method}")
    # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the
    # NUTS posterior, or closer to NUTS than ADVI. Pareto k may not be a good diagnostic for
    # Pathfinder.
    if pareto_k is not None:
        pareto_k = pareto_k.eval()
        if pareto_k < 0.5:
            pass
        elif 0.5 <= pareto_k < 0.70:
            logger.info(
                f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7, which indicates "
                "an imperfect but still useful approximation."
            )
            logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
        elif pareto_k >= 0.7:
            logger.info(
                f"Pareto k value ({pareto_k:.2f}) exceeds 0.7, which indicates a bad "
                "approximation."
            )
            logger.info(
                "Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the "
                "model."
            )
        else:
            # Reached only for non-finite (e.g. NaN) Pareto k values.
            logger.info(
                f"Received an invalid Pareto k value of {pareto_k:.2f}, which indicates the "
                "model is seriously flawed."
            )
            logger.info(
                "Consider reparametrising the model altogether or ensuring the input data "
                "are correct."
            )

        logger.warning(f"Pareto k value: {pareto_k:.2f}")
    p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
    rng = np.random.default_rng(random_seed)
    return rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
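As a hedged end-to-end sketch of the function above (the arrays below are synthetic stand-ins for real samples and logP - logQ values, not part of the original code):

rng = np.random.default_rng(42)
proposal_draws = rng.normal(size=(1_000, 2))      # hypothetical draws from the proposal q
toy_logiw = pt.as_tensor(rng.normal(size=1_000))  # hypothetical log importance weights
resampled = importance_sampling(
    proposal_draws, logiw=toy_logiw, num_draws=100, method="psis", random_seed=1
)
print(resampled.shape)  # (100, 2)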
126 additions, 0 deletions
import logging

from collections.abc import Callable
from dataclasses import dataclass, field

import numpy as np
import pytensor.tensor as pt

from numpy.typing import NDArray
from pytensor.graph import Apply, Op
from scipy.optimize import minimize

logger = logging.getLogger(__name__)
@dataclass(slots=True)
class LBFGSHistory:
    """Container for the position (x) and gradient (g) history of an L-BFGS run."""

    x: NDArray[np.float64]
    g: NDArray[np.float64]

    def __post_init__(self):
        self.x = np.ascontiguousarray(self.x, dtype=np.float64)
        self.g = np.ascontiguousarray(self.g, dtype=np.float64)
@dataclass(slots=True)
class LBFGSHistoryManager:
    """scipy.optimize callback that records finite iterates and gradients of L-BFGS."""

    fn: Callable[[NDArray[np.float64]], np.float64]
    grad_fn: Callable[[NDArray[np.float64]], NDArray[np.float64]]
    x0: NDArray[np.float64]
    maxiter: int
    x_history: NDArray[np.float64] = field(init=False)
    g_history: NDArray[np.float64] = field(init=False)
    count: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        # Preallocate one slot per iteration plus one for the initial point.
        self.x_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
        self.g_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)

        # Record the initial point only if both its value and gradient are finite.
        value = self.fn(self.x0)
        grad = self.grad_fn(self.x0)
        if np.all(np.isfinite(grad)) and np.isfinite(value):
            self.x_history[0] = self.x0
            self.g_history[0] = grad
            self.count = 1

    def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
        self.x_history[self.count] = x
        self.g_history[self.count] = g
        self.count += 1

    def get_history(self) -> LBFGSHistory:
        # Return only the filled portion of the preallocated arrays.
        return LBFGSHistory(x=self.x_history[: self.count], g=self.g_history[: self.count])

    def __call__(self, x: NDArray[np.float64]) -> None:
        value = self.fn(x)
        grad = self.grad_fn(x)
        if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
            self.add_entry(x, grad)
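A minimal standalone sketch of the callback pattern above, using a hypothetical toy quadratic objective (toy_fn and toy_grad are illustrative, not part of the original file):

def toy_fn(x):
    return 0.5 * np.sum(x**2)  # hypothetical objective

def toy_grad(x):
    return x  # its gradient

x0 = np.array([3.0, -2.0])
manager = LBFGSHistoryManager(fn=toy_fn, grad_fn=toy_grad, x0=x0, maxiter=50)
minimize(toy_fn, x0, method="L-BFGS-B", jac=toy_grad, callback=manager)
history = manager.get_history()
print(history.x.shape, history.g.shape)  # (count, 2) each, with count <= 51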
class LBFGSInitFailed(Exception):
    DEFAULT_MESSAGE = "LBFGS failed to initialise."

    def __init__(self, message=None):
        if message is None:
            message = self.DEFAULT_MESSAGE
        super().__init__(message)
class LBFGSOp(Op):
    __props__ = ("fn", "grad_fn", "maxcor", "maxiter", "ftol", "gtol", "maxls")

    def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
        self.fn = fn
        self.grad_fn = grad_fn
        self.maxcor = maxcor
        self.maxiter = maxiter
        self.ftol = ftol
        self.gtol = gtol
        self.maxls = maxls

    def make_node(self, x0):
        x0 = pt.as_tensor_variable(x0)
        x_history = pt.dmatrix()
        g_history = pt.dmatrix()
        return Apply(self, [x0], [x_history, g_history])
    def perform(self, node, inputs, outputs):
        x0 = inputs[0]
        x0 = np.array(x0, dtype=np.float64)

        history_manager = LBFGSHistoryManager(
            fn=self.fn, grad_fn=self.grad_fn, x0=x0, maxiter=self.maxiter
        )

        result = minimize(
            self.fn,
            x0,
            method="L-BFGS-B",
            jac=self.grad_fn,
            callback=history_manager,
            options={
                "maxcor": self.maxcor,
                "maxiter": self.maxiter,
                "ftol": self.ftol,
                "gtol": self.gtol,
                "maxls": self.maxls,
            },
        )

        # scipy status 1: iteration limit reached; status 2: stopped for another reason
        # (e.g. line search failure).
        if result.status == 1:
            logger.info("LBFGS reached the maximum number of iterations. Consider increasing maxiter.")
        elif (result.status == 2) or (history_manager.count <= 1):
            if result.nit <= 1:
                logger.info(
                    "LBFGS failed to initialise. The model might be degenerate or the jitter might be too large."
                )
                raise LBFGSInitFailed
            elif result.fun == np.inf:
                logger.info(
                    "LBFGS diverged to infinity. The model might be degenerate or require reparameterisation."
                )

        history = history_manager.get_history()
        outputs[0][0] = history.x
        outputs[1][0] = history.g
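A hedged usage sketch of the Op, reusing the hypothetical toy callables from the earlier sketch (module-level functions, since __props__ hashes them):

op = LBFGSOp(fn=toy_fn, grad_fn=toy_grad, maxcor=5)
x_hist, g_hist = op(np.array([3.0, -2.0]))  # symbolic matrix outputs
print(x_hist.eval().shape, g_hist.eval().shape)  # (count, 2) each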