Replace MarginalModel by model transforms
ricardoV94 committed Jan 3, 2025
1 parent 304a9e6 commit 1b10334
Showing 12 changed files with 930 additions and 681 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -26,10 +26,9 @@ import pymc as pm
 import pymc_extras as pmx
 
 with pm.Model():
-    alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
-
-    ...
+    alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
+    ...
 
 ```

3 changes: 2 additions & 1 deletion docs/api_reference.rst
@@ -12,8 +12,8 @@ methods in the current release of PyMC experimental.
    :toctree: generated/
 
    as_model
-   MarginalModel
    marginalize
+   recover_marginals
    model_builder.ModelBuilder

Inference
@@ -53,6 +53,7 @@ Utils

    spline.bspline_interpolation
    prior.prior_from_idata
+   model_equivalence.equivalent_models


Statespace Models
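Taken together, the docs changes reflect the new user-facing surface: marginalization as standalone model transforms rather than a dedicated `MarginalModel` subclass. A minimal sketch of the transform-based workflow follows; the exact signatures of `marginalize` and `recover_marginals` are assumptions inferred from the names indexed above.

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model() as model:
    idx = pm.Bernoulli("idx", p=0.7)
    mu = pm.math.switch(idx, 1.0, -1.0)
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.randn(100))

# Marginalize the finite discrete variable out of the model
marginal_model = pmx.marginalize(model, [idx])

with marginal_model:
    idata = pm.sample()

# Recover posterior draws for the marginalized variable afterwards
idata = pmx.recover_marginals(marginal_model, idata)
```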
6 changes: 5 additions & 1 deletion pymc_extras/__init__.py
@@ -16,7 +16,11 @@
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
 from pymc_extras.inference.fit import fit
-from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
+from pymc_extras.model.marginal.marginal_model import (
+    MarginalModel,
+    marginalize,
+    recover_marginals,
+)
 from pymc_extras.model.model_api import as_model
 from pymc_extras.version import __version__

2 changes: 1 addition & 1 deletion pymc_extras/distributions/timeseries.py
@@ -214,8 +214,8 @@ def transition(*args):
     discrete_mc_op = DiscreteMarkovChainRV(
         inputs=[P_, steps_, init_dist_, state_rng],
         outputs=[state_next_rng, discrete_mc_],
-        ndim_supp=1,
         n_lags=n_lags,
+        extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
     )
 
     discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
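The swap above replaces the bare `ndim_supp=1` with an `extended_signature` that spells out core dimensions for each input and output: `(p,p)` for the transition matrix, `()` for the scalar steps, `(p)` for the initial distribution, `[rng]` for the RNG in and out, and `(t)` for the emitted chain. For context, a hedged sketch of building the user-facing distribution this Op backs:

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    # Each row of P is a distribution over the next state
    P = pm.Dirichlet("P", a=np.ones((3, 3)))
    init_dist = pm.Categorical.dist(p=np.ones(3) / 3)
    chain = pmx.DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=20)
```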
103 changes: 100 additions & 3 deletions pymc_extras/model/marginal/distributions.py
@@ -1,29 +1,41 @@
+import warnings
+
 from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
+from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_extras.distributions import DiscreteMarkovChain


 class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
-    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+    def __init__(
+        self,
+        *args,
+        dims_connections: tuple[tuple[int | None], ...],
+        dims: tuple[Variable, ...],
+        **kwargs,
+    ) -> None:
         self.dims_connections = dims_connections
+        self.dims = dims
         super().__init__(*args, **kwargs)

@property
@@ -43,6 +55,74 @@ def support_axes(self) -> tuple[tuple[int]]:
)
return tuple(support_axes_vars)

def __eq__(self, other):
        # Just to allow easy testing of equivalent models.
        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.
if type(self) is not type(other):
return False

return equal_computations(
self.inner_outputs,
other.inner_outputs,
self.inner_inputs,
other.inner_inputs,
)

def __hash__(self):
        # Just to allow easy testing of equivalent models.
        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.
return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
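These two methods delegate structural equality to PyTensor's `equal_computations`; in isolation the helper behaves like this:

```python
import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations

# Two structurally identical graphs built from distinct input variables
x1, x2 = pt.scalar("x1"), pt.scalar("x2")
out1, out2 = pt.exp(x1) + 1, pt.exp(x2) + 1

# True: the computations match once x1 is mapped onto x2
print(equal_computations([out1], [out2], [x1], [x2]))
```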


@_support_point.register
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
"""Support point for a marginalized RV.
The support point of a marginalized RV is the support point of the inner RV,
conditioned on the marginalized RV taking its support point.
"""
outputs = rv.owner.outputs

inner_rv = op.inner_outputs[outputs.index(rv)]
marginalized_inner_rv, *other_dependent_inner_rvs = (
out
for out in op.inner_outputs
if out is not inner_rv and not isinstance(out.type, RandomType)
)

    # Replace references to inner RVs with dummy variables (including the marginalized RV).
    # This is necessary because the inner RVs may depend on each other.
marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
other_dependent_inner_rv_to_dummies = {
inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
}
inner_rv = clone_replace(
inner_rv,
replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
| other_dependent_inner_rv_to_dummies,
)

# Get support point of inner RV and marginalized RV
inner_rv_support_point = support_point(inner_rv)
marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)

replacements = [
# Replace the marginalized RV dummy by its support point
(marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
        # Replace the other dependent RVs' dummies with the respective outer outputs.
        # PyMC will replace those with their support points later.
*(
(v, outputs[op.inner_outputs.index(k)])
for k, v in other_dependent_inner_rv_to_dummies.items()
),
# Replace outer input RVs
*zip(op.inner_inputs, inputs),
]
fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
fgraph.replace_all(replacements, import_missing=True)
[rv_support_point] = fgraph.outputs
return rv_support_point
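The dummy-clone-then-replace dance above is easier to see with plain tensors standing in for the inner RVs; a minimal sketch of the same mechanic:

```python
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.graph.replace import clone_replace

# Isolate a sub-expression by swapping it for a fresh "dummy" clone ...
x = pt.scalar("x")
inner = pt.exp(x)
out = inner + 1

dummy = inner.clone()  # same type, no owner
out_with_dummy = clone_replace(out, replace={inner: dummy})

# ... then decide later what the dummy stands for, exactly as the
# support point replacements do above
fgraph = FunctionGraph(outputs=[out_with_dummy], clone=False)
fgraph.replace_all([(dummy, pt.constant(0.0))], import_missing=True)
[final] = fgraph.outputs  # graph now computes 0.0 + 1
```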


class MarginalFiniteDiscreteRV(MarginalRV):
"""Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +212,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
-    return clone_replace(
+    return graph_replace(
         op.inner_outputs,
         replace=tuple(zip(op.inner_inputs, inputs)),
+        strict=False,
     )
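A standalone sketch of this unwrapping with a toy `OpFromGraph`:

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.replace import graph_replace

# Wrap a small graph inside a single Op ...
inner_x = pt.scalar("inner_x")
ofg = OpFromGraph([inner_x], [pt.exp(inner_x) + 1])

outer_x = pt.scalar("outer_x")
wrapped = ofg(outer_x)

# ... and "unwrap" it again by substituting the outer inputs into the
# inner graph, as inline_ofg_outputs does
[inlined] = graph_replace(
    ofg.inner_outputs,
    replace=dict(zip(ofg.inner_inputs, [outer_x])),
)
```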


class NonSeparableLogpWarning(UserWarning):
pass


def warn_non_separable_logp(values):
if len(values) > 1:
warnings.warn(
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
f"Their joint logp terms will be assigned to the first value: {values[0]}.",
NonSeparableLogpWarning,
stacklevel=2,
)
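Code that relies on per-variable logps can opt into treating this condition as an error; a small sketch, using the module path from this diff:

```python
import warnings

from pymc_extras.model.marginal.distributions import NonSeparableLogpWarning

# Promote the warning to an error during development so accidental reliance
# on separable logps of a non-separable joint is caught early
with warnings.catch_warnings():
    warnings.simplefilter("error", category=NonSeparableLogpWarning)
    ...  # build and evaluate the logp of a marginalized model here
```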


DUMMY_ZERO = pt.constant(0, name="dummy_zero")


@@ -199,6 +294,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
     # Align logp with non-collapsed batch dimensions of first RV
     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
+    warn_non_separable_logp(values)
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
@@ -272,5 +368,6 @@ def step_alpha(logp_emission, log_alpha, log_P):

     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
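The body of `step_alpha` is collapsed in this view, but the signature matches the textbook forward-algorithm step for a hidden Markov chain. A NumPy sketch of what a step with this signature computes (an assumption, since the actual implementation is hidden above):

```python
import numpy as np
from scipy.special import logsumexp

def step_alpha(logp_emission, log_alpha, log_P):
    # alpha_t(j) = emission_t(j) + logsumexp_i(alpha_{t-1}(i) + log P[i, j])
    return logp_emission + logsumexp(log_alpha[:, None] + log_P, axis=0)

log_P = np.log([[0.9, 0.1], [0.2, 0.8]])  # transition probabilities
log_alpha = np.log([0.5, 0.5])            # previous forward message
logp_emission = np.log([0.7, 0.3])        # current emission logps
print(step_alpha(logp_emission, log_alpha, log_P))
```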
17 changes: 8 additions & 9 deletions pymc_extras/model/marginal/graph_analysis.py
@@ -4,8 +4,8 @@
from itertools import zip_longest

from pymc import SymbolicRandomVariable
-from pytensor.compile import SharedVariable
-from pytensor.graph import Constant, Variable, ancestors
+from pymc.model.fgraph import ModelVar
+from pytensor.graph import Variable, ancestors
from pytensor.graph.basic import io_toposort
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.blockwise import Blockwise
@@ -35,13 +35,9 @@ def static_shape_ancestors(vars):

 def find_conditional_input_rvs(output_rvs, all_rvs):
     """Find conditionally independent input RVs."""
-    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
-    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
-    return [
-        var
-        for var in ancestors(output_rvs, blockers=blockers)
-        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
-    ]
+    other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
+    blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
+    return [var for var in ancestors(output_rvs, blockers=blockers) if var in other_rvs]
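The rewrite leans on the blocking semantics of `ancestors`: blockers are themselves reachable, but traversal stops at them. In isolation:

```python
import pymc as pm
from pytensor.graph import ancestors

with pm.Model():
    a = pm.Normal("a")
    b = pm.Normal("b", mu=a)
    c = pm.Normal("c", mu=b)

anc = set(ancestors([c], blockers=[b]))
print(b in anc)  # True: the blocker itself is reached ...
print(a in anc)  # False: ... but traversal stops there, hiding a
```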


def is_conditional_dependent(
@@ -141,6 +137,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
             # None of the inputs are related to the batch_axes of the input_vars
             continue
 
+        elif isinstance(node.op, ModelVar):
+            var_dims[node.outputs[0]] = inputs_dims[0]
+
         elif isinstance(node.op, DimShuffle):
             [input_dims] = inputs_dims
             output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
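For reference, `new_order` on a DimShuffle interleaves input axis positions with `"x"` markers for inserted broadcast axes, which is exactly the mapping the branch above mirrors:

```python
import pytensor.tensor as pt

x = pt.tensor3("x")
# new_order = (2, "x", 0, 1): axis 1 of the output is a fresh length-1 axis
y = x.dimshuffle(2, "x", 0, 1)
print(y.type.shape)  # (None, 1, None, None)
```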
