Skip to content

Commit

Permalink
Improvements to Importance Sampling and InferenceData shape
Browse files Browse the repository at this point in the history
- Handle different importance sampling methods for reshaping and adjusting log densities.
- Modified  to return InferenceData with chain dim of size num_paths when
  • Loading branch information
aphc14 committed Dec 8, 2024
1 parent e4b8996 commit 0db1733
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 34 deletions.
48 changes: 31 additions & 17 deletions pymc_experimental/inference/pathfinder/importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
from pytensor.tensor.variable import TensorVariable

logger = logging.getLogger(__name__)

Expand All @@ -34,12 +33,12 @@ def perform(self, node: Apply, inputs, outputs) -> None:


def importance_sampling(
samples: TensorVariable,
# logP: TensorVariable,
# logQ: TensorVariable,
logiw: TensorVariable,
samples: np.ndarray,
logP: np.ndarray,
logQ: np.ndarray,
num_draws: int,
method: Literal["psis", "psir", "identity", "none"],
logiw: np.ndarray | None = None,
random_seed: int | None = None,
) -> np.ndarray:
"""Pareto Smoothed Importance Resampling (PSIR)
Expand Down Expand Up @@ -79,21 +78,36 @@ def importance_sampling(
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

if method == "psis":
replace = False
logiw, pareto_k = PSIS()(logiw)
elif method == "psir":
replace = True
logiw, pareto_k = PSIS()(logiw)
elif method == "identity":
replace = False
logiw = logiw
pareto_k = None
elif method == "none":
num_paths, num_pdraws, N = samples.shape

if method == "none":
logger.warning(
"importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
)
return samples
else:
samples = samples.reshape(-1, N)
logP = logP.ravel()
logQ = logQ.ravel()

# adjust log densities
log_I = np.log(num_paths)
logP -= log_I
logQ -= log_I
logiw = logP - logQ

if method == "psis":
replace = False
logiw, pareto_k = PSIS()(logiw)
elif method == "psir":
replace = True
logiw, pareto_k = PSIS()(logiw)
elif method == "identity":
replace = False
logiw = logiw
pareto_k = None
else:
raise ValueError(f"Invalid importance sampling method: {method}")

# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
# Pareto k may not be a good diagnostic for Pathfinder.
Expand Down Expand Up @@ -121,7 +135,7 @@ def importance_sampling(
"Consider reparametrising the model all together or ensure the input data are correct."
)

logger.warning(f"Pareto k value: {pareto_k:.2f}")
logger.warning(f"Pareto k value: {pareto_k:.2f}")

p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
rng = np.random.default_rng(random_seed)
Expand Down
31 changes: 18 additions & 13 deletions pymc_experimental/inference/pathfinder/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,13 @@ def convert_flat_trace_to_idata(
postprocessing_backend="cpu",
inference_backend="pymc",
model=None,
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
):
if importance_sampling == "none":
# samples.ndim == 3 in this case, otherwise ndim == 2
num_paths, num_pdraws, N = samples.shape
samples = samples.reshape(-1, N)

model = modelcontext(model)
ip = model.initial_point()
ip_point_map_info = DictToArrayBijection.map(ip).point_map_info
Expand Down Expand Up @@ -152,6 +158,13 @@ def convert_flat_trace_to_idata(
)
fn.trust_input = True
result = fn(*list(trace.values()))

if importance_sampling == "none":
result = [
res.reshape(num_paths, num_pdraws, *var.type.shape)
for res, var in zip(result, vars_to_sample)
]

elif inference_backend == "blackjax":
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = jax.vmap(jax.vmap(jax_fn))(
Expand Down Expand Up @@ -731,7 +744,6 @@ def multipath_pathfinder(
**pathfinder_kwargs,
):
*path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
N = DictToArrayBijection.map(model.initial_point()).data.shape[0]

single_pathfinder_fn = make_single_pathfinder_fn(
model,
Expand Down Expand Up @@ -808,19 +820,11 @@ def multipath_pathfinder(
logP = np.concatenate(logP)
logQ = np.concatenate(logQ)

samples = samples.reshape(-1, N)
logP = logP.ravel()
logQ = logQ.ravel()

# adjust log densities
log_I = np.log(num_paths)
logP -= log_I
logQ -= log_I
logiw = logP - logQ

return _importance_sampling(
samples=samples,
logiw=logiw,
logP=logP,
logQ=logQ,
# logiw=logiw,
num_draws=num_draws,
method=importance_sampling,
random_seed=choice_seed,
Expand Down Expand Up @@ -881,7 +885,7 @@ def fit_pathfinder(
epsilon: float
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
importance_sampling : str, optional
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
progressbar : bool, optional
Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
random_seed : RandomSeed, optional
Expand Down Expand Up @@ -974,5 +978,6 @@ def fit_pathfinder(
postprocessing_backend=postprocessing_backend,
inference_backend=inference_backend,
model=model,
importance_sampling=importance_sampling,
)
return idata
11 changes: 7 additions & 4 deletions tests/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,22 @@ def test_pathfinder(inference_backend):
with model:
idata = pmx.fit(
method="pathfinder",
num_paths=20,
num_paths=50,
jitter=10.0,
random_seed=41,
inference_backend=inference_backend,
)

assert idata.posterior["mu"].shape == (1, 1000)
assert idata.posterior["tau"].shape == (1, 1000)
assert idata.posterior["theta"].shape == (1, 1000, 8)
# NOTE: Pathfinder tends to return means around 7 and tau around 0.58. So need to increase atol by a large amount.
if inference_backend == "pymc":
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=2.5)
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=3.8)
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)


def test_bfgs_sample():
import pytensor
import pytensor.tensor as pt

from pymc_experimental.inference.pathfinder.pathfinder import (
Expand All @@ -73,6 +74,7 @@ def test_bfgs_sample():
L = Lp1 - 1
J = 6
num_samples = 1000
rng = pytensor.shared(np.random.default_rng(42), name="rng")

# mock data
x_data = np.random.randn(Lp1, N)
Expand All @@ -90,6 +92,7 @@ def test_bfgs_sample():

# sample
phi, logq = bfgs_sample(
rng=rng,
num_samples=num_samples,
x=x,
g=g,
Expand Down

0 comments on commit 0db1733

Please sign in to comment.