Skip to content

Commit

Permalink
Refactor Pathfinder VI: Default to PSIS, Add Concurrency, and Improve…
Browse files Browse the repository at this point in the history
…d Computational Performance

- Significantly computational efficiency by combining 3 computational graphs into 1 larger compile. Removed non-shared inputs and used  with  for significant performance gains.
- Set default importance sampling method to 'psis' for more stable posterior results, avoiding local peaks seen with 'psir'.
- Introduce concurrency options ('thread' and 'process') for multithreading and multiprocessing. Defaults to No concurrency as there haven't been any/or much reduction to the compute time.
- Adjusted default  from 8 to 4 and  from 1.0 to 2.0 and maxcor to max(3*log(N), 5). This default setting lessens computational time and and the degree by which the posterior variance is being underestimated.
  • Loading branch information
aphc14 committed Dec 7, 2024
1 parent 2815c4f commit e4b8996
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 173 deletions.
75 changes: 50 additions & 25 deletions pymc_experimental/inference/pathfinder/importance_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import warnings

from typing import Literal

import arviz as az
import numpy as np
import pytensor.tensor as pt
Expand Down Expand Up @@ -31,12 +33,13 @@ def perform(self, node: Apply, inputs, outputs) -> None:
outputs[1][0] = pareto_k


def psir(
def importance_sampling(
samples: TensorVariable,
# logP: TensorVariable,
# logQ: TensorVariable,
logiw: TensorVariable,
num_draws: int = 1000,
num_draws: int,
method: Literal["psis", "psir", "identity", "none"],
random_seed: int | None = None,
) -> np.ndarray:
"""Pareto Smoothed Importance Resampling (PSIR)
Expand All @@ -52,6 +55,8 @@ def psir(
log probability of proposal distribution
num_draws : int
number of draws to return where num_draws <= samples.shape[0]
method : str, optional
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
random_seed : int | None
Returns
Expand All @@ -74,30 +79,50 @@ def psir(
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

psislw, pareto_k = PSIS()(logiw)
pareto_k = pareto_k.eval()
if pareto_k < 0.5:
pass
elif 0.5 <= pareto_k < 0.70:
logger.warning(
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
)
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
elif pareto_k >= 0.7:
if method == "psis":
replace = False
logiw, pareto_k = PSIS()(logiw)
elif method == "psir":
replace = True
logiw, pareto_k = PSIS()(logiw)
elif method == "identity":
replace = False
logiw = logiw
pareto_k = None
elif method == "none":
logger.warning(
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
)
logger.info(
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
)
else:
logger.warning(
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
)
logger.info(
"Consider reparametrising the model all together or ensure the input data are correct."
"importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
)
return samples

# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
# Pareto k may not be a good diagnostic for Pathfinder.
if pareto_k is not None:
pareto_k = pareto_k.eval()
if pareto_k < 0.5:
pass
elif 0.5 <= pareto_k < 0.70:
logger.info(
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
)
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
elif pareto_k >= 0.7:
logger.info(
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
)
logger.info(
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
)
else:
logger.info(
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
)
logger.info(
"Consider reparametrising the model all together or ensure the input data are correct."
)

logger.warning(f"Pareto k value: {pareto_k:.2f}")

p = pt.exp(psislw - pt.logsumexp(psislw)).eval()
p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
rng = np.random.default_rng(random_seed)
return rng.choice(samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0)
return rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
2 changes: 2 additions & 0 deletions pymc_experimental/inference/pathfinder/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(self, message=None):


class LBFGSOp(Op):
__props__ = ("fn", "grad_fn", "maxcor", "maxiter", "ftol", "gtol", "maxls")

def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
self.fn = fn
self.grad_fn = grad_fn
Expand Down
Loading

0 comments on commit e4b8996

Please sign in to comment.