From 0a97875b681b3ede489e411bd91d816f6215ff4d Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 27 Jun 2024 16:51:37 +0200
Subject: [PATCH] Add QMC marginalization

---
 pymc_experimental/model/marginal_model.py     | 107 +++++++++++++++++-
 .../tests/model/test_marginal_model.py        |  17 +++
 2 files changed, 119 insertions(+), 5 deletions(-)

diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index ead9a362b..241ce7d97 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -2,11 +2,12 @@
 from typing import Sequence, Union
 
 import numpy as np
-import pymc
 import pytensor.tensor as pt
+import scipy
 from arviz import InferenceData, dict_to_dataset
-from pymc import SymbolicRandomVariable
+from pymc import SymbolicRandomVariable, icdf
 from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
+from pymc.distributions.continuous import Continuous, Normal
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
@@ -159,7 +160,11 @@ def _marginalize(self, user_warnings=False):
                     f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
                 )
 
-            old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
+            if isinstance(rv_to_marginalize.owner.op, Continuous):
+                subgraph_builder_fn = replace_continuous_marginal_subgraph
+            else:
+                subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
+            old_rvs, new_rvs = subgraph_builder_fn(
                 fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
             )
 
@@ -267,7 +272,11 @@ def marginalize(
             )
 
             rv_op = rv_to_marginalize.owner.op
-            if isinstance(rv_op, DiscreteMarkovChain):
+
+            if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
+                pass
+
+            elif isinstance(rv_op, DiscreteMarkovChain):
                 if rv_op.n_lags > 1:
                     raise NotImplementedError(
                         "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +285,11 @@ def marginalize(
                     raise NotImplementedError(
                         "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
                     )
-            elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
+
+            elif isinstance(rv_op, Normal):
+                pass
+
+            else:
                 raise NotImplementedError(
                     f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
                 )
@@ -549,6 +562,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
     """Base class for Discrete Marginal Markov Chain RVs"""
 
 
+class QMCMarginalNormalRV(MarginalRV):
+    """Base class for QMC Marginalized RVs"""
+
+    __props__ = ("qmc_order",)
+
+    def __init__(self, *args, qmc_order: int, **kwargs):
+        self.qmc_order = qmc_order
+        super().__init__(*args, **kwargs)
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
@@ -707,6 +730,36 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     return rvs_to_marginalize, marginalized_rvs
 
 
+def replace_continuous_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
+    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
+    if not dependent_rvs:
+        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
+
+    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
+    dependent_rvs_input_rvs = [
+        rv
+        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
+        if rv is not rv_to_marginalize
+    ]
+
+    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
+    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
+
+    outputs = rvs_to_marginalize
+    # We are strict about shared variables in SymbolicRandomVariables
+    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
+
+    marginalized_rvs = QMCMarginalNormalRV(
+        inputs=inputs,
+        outputs=outputs,
+        ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
+        qmc_order=13,
+    )(*inputs)
+
+    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
+    return rvs_to_marginalize, marginalized_rvs
+
+
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     op = rv.owner.op
     dist_params = rv.owner.op.dist_params(rv.owner)
@@ -870,3 +923,47 @@ def step_alpha(logp_emission, log_alpha, log_P):
     # return is the joint probability of everything together, but PyMC still expects one logp for each one.
     dummy_logps = (pt.constant(0),) * (len(values) - 1)
     return joint_logp, *dummy_logps
+
+
+@_logprob.register(QMCMarginalNormalRV)
+def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
+    # Clone the inner RV graph of the Marginalized RV
+    marginalized_rvs_node = op.make_node(*inputs)
+    marginalized_rv, *inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    marginalized_rv_node = marginalized_rv.owner
+    marginalized_rv_op = marginalized_rv_node.op
+
+    # Get QMC draws from the marginalized RV
+    # TODO: Make this an Op
+    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
+    shape = constant_fold(tuple(marginalized_rv.shape))
+    size = np.prod(shape).astype(int)
+    n_draws = 2**op.qmc_order
+    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
+    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))
+    qmc_draws = icdf(marginalized_rv, uniform_draws)
+    qmc_draws.name = f"QMC_{op.name}_draws"
+
+    # Obtain the logp of the dependent variables.
+    # We need to include the marginalized RV for correctness; we remove it later.
+    inner_rv_values = dict(zip(inner_rvs, values))
+    marginalized_vv = marginalized_rv.clone()
+    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
+    # Pop the logp term corresponding to the marginalized RV
+    # (the QMC draws already come from its distribution, so its density is accounted for)
+    logps_dict.pop(marginalized_vv)
+
+    # Vectorize across QMC draws and take the mean on the log scale: logsumexp(logps, axis=0) - log(n_draws)
+    core_marginalized_logps = list(logps_dict.values())
+    batched_marginalized_logps = vectorize_graph(
+        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
+    )
+    return tuple(
+        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
+        for batched_marginalized_logp in batched_marginalized_logps
+    )
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index 74f955714..b50451468 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -6,6 +6,7 @@
 import pymc as pm
 import pytensor.tensor as pt
 import pytest
+import scipy
 from arviz import InferenceData, dict_to_dataset
 from pymc.distributions import transforms
 from pymc.logprob.abstract import _logprob
@@ -802,3 +803,19 @@ def create_model(model_class):
         marginal_m.compile_logp()(ip),
         reference_m.compile_logp()(ip),
     )
+
+
+def test_marginalize_normal_via_qmc():
+    with MarginalModel() as m:
+        SD = pm.HalfNormal("SD", default_transform=None)
+        X = pm.Normal("X", sigma=SD)
+        Y = pm.Normal("Y", mu=(X + 1) / 2, sigma=0.5, observed=[1, 2, 3])
+
+        m.marginalize([X])  # ideally method="qmc"
+
+    # P(Y=[1, 2, 3] | SD=1) = int_x P(Y=[1, 2, 3] | SD=1, X=x) P(X=x | SD=1) = Norm([1, 2, 3], 0.5, sqrt(2) / 2)
+    [logp_eval] = m.compile_logp(vars=[Y], sum=False)({"SD": 1})
+    np.testing.assert_allclose(
+        logp_eval,
+        scipy.stats.norm.logpdf([1, 2, 3], 0.5, np.sqrt(2) / 2),
+    )
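
A note on the estimator: qmc_marginal_rv_logp approximates logp(y) = log E_x[p(y | x)] by pushing scrambled Sobol uniforms through the marginalized RV's inverse CDF and averaging the dependent logps on the log scale, i.e. logsumexp_i(logp(y | x_i)) - log(n_draws). Below is a minimal standalone NumPy/SciPy sketch of that computation for the model in the new test (with SD fixed to 1, as in the test); the seed and tolerance are illustrative, not taken from the patch.

import numpy as np
from scipy import stats
from scipy.special import logsumexp

# Model from the test at SD=1: X ~ Normal(0, 1) and Y | X ~ Normal((X + 1) / 2, 0.5),
# so the exact marginal is Y ~ Normal(0.5, sqrt(2) / 2).
y_obs = np.array([1.0, 2.0, 3.0])
n_draws = 2**13  # matches the hard-coded qmc_order=13

# Low-discrepancy uniforms pushed through the prior's inverse CDF,
# mirroring what the patch does symbolically via pymc.icdf.
uniform_draws = stats.qmc.Sobol(d=1, seed=123).random(n_draws)  # shape (n_draws, 1)
x_draws = stats.norm.ppf(uniform_draws)  # QMC draws from p(x)

# logp(y) = log E_x[p(y | x)] ~= logsumexp_i(logp(y | x_i)) - log(n_draws)
conditional_logps = stats.norm.logpdf(y_obs, loc=(x_draws + 1) / 2, scale=0.5)
qmc_logp = logsumexp(conditional_logps, axis=0) - np.log(n_draws)

exact_logp = stats.norm.logpdf(y_obs, loc=0.5, scale=np.sqrt(2) / 2)
np.testing.assert_allclose(qmc_logp, exact_logp, rtol=1e-3)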
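
For comparison, plain Monte Carlo with the same budget (continuing the sketch above) suggests why a low-discrepancy sequence is used: scrambled-Sobol error decays close to 1/n on smooth integrands, versus roughly 1/sqrt(n) for i.i.d. draws. Exact numbers vary with the seed.

# Plain Monte Carlo with the same number of draws, for comparison.
rng = np.random.default_rng(123)
x_mc = rng.standard_normal((n_draws, 1))
mc_logps = stats.norm.logpdf(y_obs, loc=(x_mc + 1) / 2, scale=0.5)
mc_logp = logsumexp(mc_logps, axis=0) - np.log(n_draws)

print("MC  abs error:", np.abs(mc_logp - exact_logp).max())   # typically ~1e-2..1e-3
print("QMC abs error:", np.abs(qmc_logp - exact_logp).max())  # typically orders of magnitude smaller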