Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyMC/PyTensor Implementation of Pathfinder VI #387

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4540b84
renamed samples argument name and pathfinder variables to avoid confu…
aphc14 Oct 3, 2024
0c880d2
Minor changes made to the `fit_pathfinder` function and added test
aphc14 Oct 19, 2024
8835cd5
extract additional pathfinder objects from high level API for debugging
aphc14 Sep 17, 2024
663a60a
changed pathfinder samples argument to num_draws
aphc14 Oct 26, 2024
05aeeaf
Merge branch 'replicate_pathfinder_w_pytensor' into scipy_lbfgs
aphc14 Oct 26, 2024
0db91fe
feat(pathfinder): add PyMC-based Pathfinder VI implementation
aphc14 Oct 31, 2024
cb4436c
Multipath Pathfinder VI implementation in pymc-experimental
aphc14 Nov 4, 2024
2efb511
Added type hints and epsilon parameter to fit_pathfinder
aphc14 Nov 7, 2024
fdc3f38
Removed initial point values (l=0) to reduce iterations. Simplified …
aphc14 Nov 7, 2024
1fd7a11
Added placeholder/reminder to remove jax dependency when converting t…
aphc14 Nov 7, 2024
ef2956f
Sync updates with draft PR #386. \n- Added pytensor.function for bfgs…
aphc14 Nov 7, 2024
8b134b7
Reduced size of compute graph with pathfinder_body_fn
aphc14 Nov 11, 2024
6484b3d
- Added TODO comments for implementing Taylor approximation methods: …
aphc14 Nov 14, 2024
aa765fb
fix: correct posterior approximations in Pathfinder VI
aphc14 Nov 21, 2024
4299a58
feat: Add dense BFGS sampling for Pathfinder VI
aphc14 Nov 21, 2024
f1a54c6
feat: improve Pathfinder performance and compatibility
aphc14 Nov 24, 2024
ea802fc
minor: improve error handling in Pathfinder VI
aphc14 Nov 25, 2024
a77f2c8
Progress bar and other minor changes
aphc14 Nov 27, 2024
9faaa72
set maxcor to max(5, floor(N / 1.9)). max=1 will cause error
aphc14 Nov 27, 2024
2815c4f
Merge branch 'main' into pathfinder_w_pytensor_symbolic
aphc14 Dec 7, 2024
e4b8996
Refactor Pathfinder VI: Default to PSIS, Add Concurrency, and Improve…
aphc14 Dec 7, 2024
885afaa
Improvements to Importance Sampling and InferenceData shape
aphc14 Dec 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pymc_experimental/inference/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.util import find_spec


def fit(method, **kwargs):
Expand All @@ -31,11 +30,9 @@ def fit(method, **kwargs):
arviz.InferenceData
"""
if method == "pathfinder":
if find_spec("blackjax") is None:
raise RuntimeError("Need BlackJAX to use `pathfinder`")

from pymc_experimental.inference.pathfinder import fit_pathfinder

# TODO: edit **kwargs to be more consistent with fit_pathfinder with blackjax and pymc backends.
return fit_pathfinder(**kwargs)

if method == "laplace":
Expand Down
134 changes: 0 additions & 134 deletions pymc_experimental/inference/pathfinder.py

This file was deleted.

3 changes: 3 additions & 0 deletions pymc_experimental/inference/pathfinder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pymc_experimental.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit_pathfinder"]
142 changes: 142 additions & 0 deletions pymc_experimental/inference/pathfinder/importance_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import logging
import warnings

from typing import Literal

import arviz as az
import numpy as np
import pytensor.tensor as pt

from pytensor.graph import Apply, Op

logger = logging.getLogger(__name__)


class PSIS(Op):
__props__ = ()

def make_node(self, inputs):
logweights = pt.as_tensor(inputs)
psislw = pt.dvector()
pareto_k = pt.dscalar()
return Apply(self, [logweights], [psislw, pareto_k])

def perform(self, node: Apply, inputs, outputs) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
)
logweights = inputs[0]
psislw, pareto_k = az.psislw(logweights)
outputs[0][0] = psislw
outputs[1][0] = pareto_k


def importance_sampling(
samples: np.ndarray,
logP: np.ndarray,
logQ: np.ndarray,
num_draws: int,
method: Literal["psis", "psir", "identity", "none"],
logiw: np.ndarray | None = None,
random_seed: int | None = None,
) -> np.ndarray:
"""Pareto Smoothed Importance Resampling (PSIR)
This implements the Pareto Smooth Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). The PSIR follows a similar approach to Algorithm 1 PSIS diagnostic from Yao et al., (2018). However, before computing the the importance ratio r_s, the logP and logQ are adjusted to account for the number multiple estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the computed importance weights from PSIS.

Parameters
----------
samples : np.ndarray
samples from proposal distribution
logP : np.ndarray
log probability of target distribution
logQ : np.ndarray
log probability of proposal distribution
num_draws : int
number of draws to return where num_draws <= samples.shape[0]
method : str, optional
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
random_seed : int | None

Returns
-------
np.ndarray
importance sampled draws

Future work!
----------
- Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
- Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
- Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.

References
----------
Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668

Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538

Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

num_paths, num_pdraws, N = samples.shape

if method == "none":
logger.warning(
"importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
)
return samples
else:
samples = samples.reshape(-1, N)
logP = logP.ravel()
logQ = logQ.ravel()

# adjust log densities
log_I = np.log(num_paths)
logP -= log_I
logQ -= log_I
logiw = logP - logQ

if method == "psis":
replace = False
logiw, pareto_k = PSIS()(logiw)
elif method == "psir":
replace = True
logiw, pareto_k = PSIS()(logiw)
elif method == "identity":
replace = False
logiw = logiw
pareto_k = None
else:
raise ValueError(f"Invalid importance sampling method: {method}")

# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
# Pareto k may not be a good diagnostic for Pathfinder.
if pareto_k is not None:
pareto_k = pareto_k.eval()
if pareto_k < 0.5:
pass
elif 0.5 <= pareto_k < 0.70:
logger.info(
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
)
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
elif pareto_k >= 0.7:
logger.info(
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
)
logger.info(
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
)
else:
logger.info(
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
)
logger.info(
"Consider reparametrising the model all together or ensure the input data are correct."
)

logger.warning(f"Pareto k value: {pareto_k:.2f}")

p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
aphc14 marked this conversation as resolved.
Show resolved Hide resolved
rng = np.random.default_rng(random_seed)
return rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
Loading