
Commit

Recovered branch
aphc14 committed Dec 7, 2024
2 parents f1cd343 + f45f800 commit a164296
Showing 4 changed files with 1,232 additions and 136 deletions.
136 changes: 0 additions & 136 deletions pymc_experimental/inference/pathfinder.py

This file was deleted.

128 changes: 128 additions & 0 deletions pymc_experimental/inference/pathfinder/importance_sampling.py
@@ -0,0 +1,128 @@
import logging
import warnings

from typing import Literal

import arviz as az
import numpy as np
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
from pytensor.tensor.variable import TensorVariable

logger = logging.getLogger(__name__)


class PSIS(Op):
__props__ = ()

def make_node(self, inputs):
logweights = pt.as_tensor(inputs)
psislw = pt.dvector()
pareto_k = pt.dscalar()
return Apply(self, [logweights], [psislw, pareto_k])

def perform(self, node: Apply, inputs, outputs) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
)
logweights = inputs[0]
psislw, pareto_k = az.psislw(logweights)
outputs[0][0] = psislw
outputs[1][0] = pareto_k
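

# --- Illustrative usage sketch (editor addition; not part of the committed file) ---
# A minimal example of evaluating the PSIS Op on synthetic log-weights; the Op wraps
# az.psislw so that Pareto smoothing can sit inside a PyTensor graph. All names
# below are hypothetical.
def _psis_example() -> None:
    rng = np.random.default_rng(0)
    raw_logiw = rng.normal(size=1000)  # synthetic log importance weights
    smoothed_lw, k_hat = PSIS()(pt.as_tensor(raw_logiw))
    # .eval() triggers PSIS.perform, which defers to az.psislw
    print(smoothed_lw.eval().shape, float(k_hat.eval()))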


def importance_sampling(
samples: TensorVariable,
# logP: TensorVariable,
# logQ: TensorVariable,
logiw: TensorVariable,
num_draws: int,
method: Literal["psis", "psir", "identity", "none"],
random_seed: int | None = None,
) -> np.ndarray:
"""Pareto Smoothed Importance Resampling (PSIR)
This implements the Pareto Smooth Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). The PSIR follows a similar approach to Algorithm 1 PSIS diagnostic from Yao et al., (2018). However, before computing the the importance ratio r_s, the logP and logQ are adjusted to account for the number multiple estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the computed importance weights from PSIS.
Parameters
----------
samples : np.ndarray
samples from proposal distribution
logP : np.ndarray
log probability of target distribution
logQ : np.ndarray
log probability of proposal distribution
num_draws : int
number of draws to return where num_draws <= samples.shape[0]
method : str, optional
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
random_seed : int | None
Returns
-------
np.ndarray
importance sampled draws
Future work!
----------
- Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
- Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
- Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.
References
----------
Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668
Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

if method == "psis":
replace = False
logiw, pareto_k = PSIS()(logiw)
elif method == "psir":
replace = True
logiw, pareto_k = PSIS()(logiw)
    elif method == "identity":
        replace = False
        pareto_k = None
elif method == "none":
        logger.warning(
            "Importance sampling is disabled. Samples are returned as-is and may include draws "
            "from failed paths with non-finite logP or logQ values. Use importance_sampling='psis' "
            "for more stable results."
        )
return samples

    # NOTE: Pareto k is often high for Pathfinder even when its posterior is close to the
    # NUTS posterior, or closer to NUTS than ADVI's. Pareto k may therefore not be a
    # reliable diagnostic for Pathfinder.
if pareto_k is not None:
pareto_k = pareto_k.eval()
if pareto_k < 0.5:
pass
elif 0.5 <= pareto_k < 0.70:
            logger.info(
                f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7, which indicates an "
                "imperfect but still useful approximation."
            )
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
elif pareto_k >= 0.7:
            logger.info(
                f"Pareto k value ({pareto_k:.2f}) exceeds 0.7, which indicates a bad approximation."
            )
logger.info(
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
)
else:
            logger.info(
                f"Received an invalid Pareto k value of {pareto_k:.2f}, which indicates the "
                "model is seriously flawed."
            )
            logger.info(
                "Consider reparametrising the model altogether or checking that the input data "
                "are correct."
            )

logger.warning(f"Pareto k value: {pareto_k:.2f}")

p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
rng = np.random.default_rng(random_seed)
return rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
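

# --- Illustrative usage sketch (editor addition; not part of the committed file) ---
# End-to-end example of importance_sampling on synthetic draws, assuming the log
# importance weights logiw = logP - logQ have already been computed. All values
# and names below are hypothetical.
def _importance_sampling_example() -> None:
    rng = np.random.default_rng(0)
    samples = rng.normal(size=(2000, 3))  # draws from a proposal q
    logP = -0.5 * (samples**2).sum(axis=1)  # unnormalised target log-density
    logQ = -0.5 * ((samples - 0.1) ** 2).sum(axis=1)  # proposal log-density
    draws = importance_sampling(
        samples=samples,
        logiw=pt.as_tensor(logP - logQ),
        num_draws=1000,
        method="psis",
        random_seed=0,
    )
    print(draws.shape)  # (1000, 3)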
126 changes: 126 additions & 0 deletions pymc_experimental/inference/pathfinder/lbfgs.py
@@ -0,0 +1,126 @@
import logging

from collections.abc import Callable
from dataclasses import dataclass, field

import numpy as np
import pytensor.tensor as pt

from numpy.typing import NDArray
from pytensor.graph import Apply, Op
from scipy.optimize import minimize

logger = logging.getLogger(__name__)


@dataclass(slots=True)
class LBFGSHistory:
x: NDArray[np.float64]
g: NDArray[np.float64]

def __post_init__(self):
self.x = np.ascontiguousarray(self.x, dtype=np.float64)
self.g = np.ascontiguousarray(self.g, dtype=np.float64)


@dataclass(slots=True)
class LBFGSHistoryManager:
fn: Callable[[NDArray[np.float64]], np.float64]
grad_fn: Callable[[NDArray[np.float64]], NDArray[np.float64]]
x0: NDArray[np.float64]
maxiter: int
x_history: NDArray[np.float64] = field(init=False)
g_history: NDArray[np.float64] = field(init=False)
count: int = field(init=False, default=0)

def __post_init__(self) -> None:
self.x_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
self.g_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)

value = self.fn(self.x0)
grad = self.grad_fn(self.x0)
if np.all(np.isfinite(grad)) and np.isfinite(value):
self.x_history[0] = self.x0
self.g_history[0] = grad
self.count = 1

def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
self.x_history[self.count] = x
self.g_history[self.count] = g
self.count += 1

def get_history(self) -> LBFGSHistory:
return LBFGSHistory(x=self.x_history[: self.count], g=self.g_history[: self.count])

def __call__(self, x: NDArray[np.float64]) -> None:
value = self.fn(x)
grad = self.grad_fn(x)
if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
self.add_entry(x, grad)
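

# --- Illustrative usage sketch (editor addition; not part of the committed file) ---
# The manager is designed to be passed as the `callback` of scipy.optimize.minimize:
# each iterate x is recorded together with its gradient, but only while both the
# value and gradient are finite. A hypothetical standalone use:
def _history_manager_example() -> None:
    def fn(x):
        return np.sum((x - 1.0) ** 2)

    def grad_fn(x):
        return 2.0 * (x - 1.0)

    x0 = np.zeros(3)
    manager = LBFGSHistoryManager(fn=fn, grad_fn=grad_fn, x0=x0, maxiter=50)
    minimize(fn, x0, method="L-BFGS-B", jac=grad_fn, callback=manager)
    history = manager.get_history()
    print(history.x.shape, history.g[-1])  # recorded iterates; final gradient ~ 0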


class LBFGSInitFailed(Exception):
DEFAULT_MESSAGE = "LBFGS failed to initialise."

def __init__(self, message=None):
if message is None:
message = self.DEFAULT_MESSAGE
super().__init__(message)


class LBFGSOp(Op):
__props__ = ("fn", "grad_fn", "maxcor", "maxiter", "ftol", "gtol", "maxls")

def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
self.fn = fn
self.grad_fn = grad_fn
self.maxcor = maxcor
self.maxiter = maxiter
self.ftol = ftol
self.gtol = gtol
self.maxls = maxls

def make_node(self, x0):
x0 = pt.as_tensor_variable(x0)
x_history = pt.dmatrix()
g_history = pt.dmatrix()
return Apply(self, [x0], [x_history, g_history])

def perform(self, node, inputs, outputs):
x0 = inputs[0]
x0 = np.array(x0, dtype=np.float64)

history_manager = LBFGSHistoryManager(
fn=self.fn, grad_fn=self.grad_fn, x0=x0, maxiter=self.maxiter
)

result = minimize(
self.fn,
x0,
method="L-BFGS-B",
jac=self.grad_fn,
callback=history_manager,
options={
"maxcor": self.maxcor,
"maxiter": self.maxiter,
"ftol": self.ftol,
"gtol": self.gtol,
"maxls": self.maxls,
},
)

if result.status == 1:
logger.info("LBFGS maximum number of iterations reached. Consider increasing maxiter.")
elif (result.status == 2) or (history_manager.count <= 1):
if result.nit <= 1:
logger.info(
"LBFGS failed to initialise. The model might be degenerate or the jitter might be too large."
)
raise LBFGSInitFailed
elif result.fun == np.inf:
            logger.info(
                "LBFGS diverged to infinity. The model might be degenerate or may require "
                "reparameterisation."
            )

outputs[0][0] = history_manager.get_history().x
outputs[1][0] = history_manager.get_history().g
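

# --- Illustrative usage sketch (editor addition; not part of the committed file) ---
# Runs LBFGSOp on a simple quadratic fn(x) = 0.5 * ||x||^2, whose gradient is x,
# and inspects the recorded optimisation path. Names below are hypothetical.
def _lbfgs_op_example() -> None:
    def fn(x):
        return 0.5 * np.dot(x, x)

    def grad_fn(x):
        return np.asarray(x, dtype=np.float64)

    op = LBFGSOp(fn=fn, grad_fn=grad_fn, maxcor=5, maxiter=100)
    x_history, g_history = op(np.full(4, 3.0))
    # each .eval() compiles and runs the graph, i.e. one L-BFGS-B run per call
    print(x_history.eval().shape, g_history.eval()[-1])  # final gradient ~ 0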
