Skip to content

Commit

Permalink
Allow creating MarginalModel from existing Model
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 30, 2024
1 parent 6d12203 commit dfe3fe0
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ methods in the current release of PyMC experimental.

as_model
MarginalModel
marginalize
model_builder.ModelBuilder

Inference
Expand Down
92 changes: 69 additions & 23 deletions pymc_experimental/model/marginal_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Sequence
from typing import Sequence, Union

import numpy as np
import pymc
Expand All @@ -25,10 +25,12 @@
from pytensor.tensor.shape import Shape
from pytensor.tensor.special import log_softmax

__all__ = ["MarginalModel"]
__all__ = ["MarginalModel", "marginalize"]

from pymc_experimental.distributions import DiscreteMarkovChain

ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]


class MarginalModel(Model):
"""Subclass of PyMC Model that implements functionality for automatic
Expand Down Expand Up @@ -207,35 +209,50 @@ def logp(self, vars=None, **kwargs):
vars = [m[var.name] for var in vars]
return m._logp(vars=vars, **kwargs)

def clone(self):
m = MarginalModel(coords=self.coords)
model_vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
data_vars = [var for name, var in self.named_vars.items() if var not in model_vars]
@staticmethod
def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel":
new_model = MarginalModel(coords=model.coords)
if isinstance(model, MarginalModel):
marginalized_rvs = model.marginalized_rvs
marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims
else:
marginalized_rvs = []
marginalized_named_vars_to_dims = {}

model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs
data_vars = [var for name, var in model.named_vars.items() if var not in model_vars]
vars = model_vars + data_vars
cloned_vars = clone_replace(vars)
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
m.vars_to_clone = vars_to_clone

m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
m.named_vars_to_dims = self.named_vars_to_dims
m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
m.rvs_to_transforms = {vars_to_clone[rv]: i for rv, i in self.rvs_to_transforms.items()}
m.rvs_to_initial_values = {
vars_to_clone[rv]: i for rv, i in self.rvs_to_initial_values.items()
new_model.vars_to_clone = vars_to_clone

new_model.named_vars = treedict(
{name: vars_to_clone[var] for name, var in model.named_vars.items()}
)
new_model.named_vars_to_dims = model.named_vars_to_dims
new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()}
new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()}
new_model.rvs_to_transforms = {
vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items()
}
new_model.rvs_to_initial_values = {
vars_to_clone[rv]: iv for rv, iv in model.rvs_to_initial_values.items()
}
m.free_RVs = [vars_to_clone[rv] for rv in self.free_RVs]
m.observed_RVs = [vars_to_clone[rv] for rv in self.observed_RVs]
m.potentials = [vars_to_clone[pot] for pot in self.potentials]
m.deterministics = [vars_to_clone[det] for det in self.deterministics]
new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs]
new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs]
new_model.potentials = [vars_to_clone[pot] for pot in model.potentials]
new_model.deterministics = [vars_to_clone[det] for det in model.deterministics]

m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
return m
new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs]
new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims
return new_model

def clone(self):
return self.from_model(self)

def marginalize(
self,
rvs_to_marginalize: TensorVariable | Sequence[TensorVariable] | str | Sequence[str],
rvs_to_marginalize: ModelRVs,
):
if not isinstance(rvs_to_marginalize, Sequence):
rvs_to_marginalize = (rvs_to_marginalize,)
Expand Down Expand Up @@ -491,6 +508,35 @@ def transform_input(inputs):
return rv_dataset


def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
    """Marginalize a subset of variables in a PyMC model.

    This creates a class of `MarginalModel` from an existing `Model`, with the
    specified variables marginalized.

    See documentation for `MarginalModel` for more information.

    Parameters
    ----------
    model : Model
        PyMC model to marginalize. Original variables will be cloned.
    rvs_to_marginalize : TensorVariable, str, or sequence of these
        Variables (or variable names) to marginalize in the returned model.

    Returns
    -------
    marginal_model : MarginalModel
        Marginal model with the specified variables marginalized.
    """
    # Wrap a lone variable (or name) in a tuple. Note: a plain `str` is itself a
    # Sequence, so we must test for tuple/list explicitly, not for Sequence.
    if not isinstance(rvs_to_marginalize, tuple | list):
        rvs_to_marginalize = (rvs_to_marginalize,)
    # Normalize everything to names so the lookup resolves against the cloned
    # variables inside the new model rather than the originals.
    rvs_to_marginalize = [rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize]

    marginal_model = MarginalModel.from_model(model)
    marginal_model.marginalize(rvs_to_marginalize)
    return marginal_model


class MarginalRV(SymbolicRandomVariable):
"""Base class for Marginalized RVs"""

Expand Down
33 changes: 33 additions & 0 deletions pymc_experimental/tests/model/test_marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pymc import ImputationWarning, inputvars
from pymc.distributions import transforms
from pymc.logprob.abstract import _logprob
from pymc.model.fgraph import fgraph_from_model
from pymc.util import UNSET
from scipy.special import log_softmax, logsumexp
from scipy.stats import halfnorm, norm
Expand All @@ -19,7 +20,9 @@
FiniteDiscreteMarginalRV,
MarginalModel,
is_conditional_dependent,
marginalize,
)
from pymc_experimental.tests.utils import equal_computations_up_to_root


@pytest.fixture
Expand Down Expand Up @@ -776,3 +779,33 @@ def test_mutable_indexing_jax_backend():
pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
model.marginalize(["is_outlier"])
get_jaxified_logp(model)


def test_marginal_model_func():
    """`marginalize(Model, ...)` must match building a `MarginalModel` directly."""

    def build(model_cls):
        # Same generative graph regardless of which model class hosts it.
        with model_cls(coords={"trial": range(10)}) as m:
            idx = pm.Bernoulli("idx", p=0.5, dims="trial")
            mu = pt.where(idx, 1, -1)
            sigma = pm.HalfNormal("sigma")
            pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10)
            return m

    marginal_m = marginalize(build(pm.Model), ["idx"])
    assert isinstance(marginal_m, MarginalModel)

    reference_m = build(MarginalModel)
    reference_m.marginalize(["idx"])

    # The forward (generative) graphs should be structurally identical.
    marginal_fgraph, _ = fgraph_from_model(marginal_m)
    reference_fgraph, _ = fgraph_from_model(reference_m)
    assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)

    # The logp graphs should also agree. Comparing them symbolically fails
    # because OpFromGraph comparison is broken, so compare evaluated values:
    # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
    ip = marginal_m.initial_point()
    np.testing.assert_allclose(
        marginal_m.compile_logp()(ip),
        reference_m.compile_logp()(ip),
    )
31 changes: 31 additions & 0 deletions pymc_experimental/tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Sequence

from pytensor.compile import SharedVariable
from pytensor.graph import Constant, graph_inputs
from pytensor.graph.basic import Variable, equal_computations
from pytensor.tensor.random.type import RandomType


def equal_computations_up_to_root(
    xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
) -> bool:
    """Return True if the graphs of `xs` and `ys` compute the same thing,
    even when their root (input) variables are distinct objects.

    Roots are matched pairwise by position; constants are excluded since
    `equal_computations` compares them by value. When `ignore_rng_values` is
    True, shared RNG variables are matched by type/name only.
    """
    xs_roots = [inp for inp in graph_inputs(xs) if not isinstance(inp, Constant)]
    ys_roots = [inp for inp in graph_inputs(ys) if not isinstance(inp, Constant)]
    if len(xs_roots) != len(ys_roots):
        return False

    for x_root, y_root in zip(xs_roots, ys_roots):
        # Paired roots must agree on type and name to be interchangeable.
        if x_root.type != y_root.type or x_root.name != y_root.name:
            return False
        if isinstance(x_root, SharedVariable):
            if not isinstance(y_root, SharedVariable):
                return False
            # RNG states may legitimately differ; skip the value comparison.
            if ignore_rng_values and isinstance(x_root.type, RandomType):
                continue
            if not x_root.type.values_eq(x_root.get_value(), y_root.get_value()):
                return False

    return equal_computations(xs, ys, in_xs=xs_roots, in_ys=ys_roots)

0 comments on commit dfe3fe0

Please sign in to comment.