Use vectorize in finite_discrete_marginal_rv_logp
ricardoV94 committed Apr 21, 2024
1 parent 63571f0 commit 8158627
Showing 1 changed file with 15 additions and 27 deletions.
42 changes: 15 additions & 27 deletions pymc_experimental/model/marginal_model.py
@@ -13,13 +13,12 @@
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import compile_pymc, constant_fold, inputvars
+from pymc.pytensorf import compile_pymc, constant_fold
 from pymc.util import _get_seeds_per_chain, treedict
 from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
-from pytensor.graph.replace import vectorize_graph
+from pytensor.graph.replace import graph_replace, vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
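The import changes are the whole surface of the refactor: inputvars is dropped (it was only needed for the OpFromGraph wrapper removed below), and graph_replace joins vectorize_graph. As a rough illustration of what these two PyTensor helpers do, here is a minimal sketch with made-up variables (x, y, z, xs are not part of this commit): graph_replace swaps a variable for another of the same shape, while vectorize_graph rebuilds the graph so that a batched replacement broadcasts through every downstream op.

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace, vectorize_graph

x = pt.scalar("x")
y = pt.exp(x) + 1

# graph_replace: evaluate the same graph at a different variable of the same shape
z = pt.scalar("z")
y_at_z = graph_replace(y, replace={x: z})
print(y_at_z.eval({z: 0.0}))  # 2.0

# vectorize_graph: swap in a batched variable; elemwise ops broadcast over the new axis
xs = pt.vector("xs")
ys = vectorize_graph(y, replace={x: xs})
print(ys.eval({xs: np.zeros(3, dtype=xs.dtype)}))  # [2. 2. 2.]

In the rewritten logp below, vectorize_graph takes over the role the OpFromGraph calls used to play, and graph_replace is only needed inside the Scan fallback.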
@@ -686,31 +685,23 @@ def _add_reduce_batch_dependent_logps(
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
     marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
+    marginalized_rv, *inner_rvs = clone_replace(
         op.inner_outputs,
         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
     )
-    marginalized_rv = inner_rvs[0]

     # Obtain the joint_logp graph of the inner RV graph
-    inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
-    logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
+    inner_rv_values = dict(zip(inner_rvs, values))
+    marginalized_vv = marginalized_rv.clone()
+    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

     # Reduce logp dimensions corresponding to broadcasted variables
-    marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
+    marginalized_logp = logps_dict.pop(marginalized_vv)
     joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
         marginalized_rv.type, logps_dict.values()
     )

-    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
-    # values of the marginalized RV
-    # Some inputs are not root inputs (such as transformed projections of value variables)
-    # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
-    inputs = list(inputvars(inputs))
-    joint_logp_op = OpFromGraph(
-        list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True
-    )
-
     # Compute the joint_logp for all possible n values of the marginalized RV. We assume
     # each original dimension is independent so that it suffices to evaluate the graph
     # n times, once with each possible value of the marginalized RV replicated across
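The rewritten block above hands conditional_logp the caller-provided value variables for the dependent RVs plus a fresh clone, marginalized_vv, as the value variable of the marginalized RV; the returned dict is keyed by those value variables, which is what lets logps_dict.pop(marginalized_vv) pull the marginalized term back out. A hedged, self-contained sketch of that API on stand-in distributions (z, y and their value variables are illustrative, not taken from this file):

import pymc as pm
from pymc.logprob.basic import conditional_logp

z = pm.Bernoulli.dist(p=0.3)          # stand-in for the marginalized RV
y = pm.Normal.dist(mu=z, sigma=1.0)   # stand-in for a dependent RV

z_vv = z.clone()                      # explicit value variable, like marginalized_vv above
y_vv = y.clone()
logps = conditional_logp({z: z_vv, y: y_vv})

# The result is keyed by value variable, so the marginalized term can be
# popped out by z_vv, mirroring logps_dict.pop(marginalized_vv)
joint_logp = logps.pop(z_vv) + logps.pop(y_vv)
print(joint_logp.eval({z_vv: 1, y_vv: 0.5}))

Because the marginalized RV gets its own locally created value variable, it no longer has to travel inside inner_rvs, which is why the unpacking changes to marginalized_rv, *inner_rvs and the value mapping becomes a plain dict union.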
@@ -729,17 +720,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
         0,
     )

-    # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
-    # TODO: Try vectorize here
-    if len(marginalized_rv_domain) <= 10:
-        joint_logps = [
-            joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
-            for i in range(len(marginalized_rv_domain))
-        ]
-    else:
-
+    try:
+        joint_logps = vectorize_graph(
+            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
+        )
+    except Exception:
+        # Fallback to Scan
         def logp_fn(marginalized_rv_const, *non_sequences):
-            return joint_logp_op(marginalized_rv_const, *non_sequences)
+            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

         joint_logps, _ = scan_map(
             fn=logp_fn,
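Putting the pieces together, the new logp path evaluates a single joint-logp graph at every point of the marginalized RV's finite domain by vectorizing over a stacked domain tensor, and only falls back to a Scan of graph_replace calls when vectorization raises. A hedged toy sketch of that pattern with a hand-written Bernoulli/Normal logp (p, z_vv, y_vv and domain are illustrative names, not from this file):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

p = pt.scalar("p")                       # P(z = 1)
z_vv = pt.scalar("z_vv", dtype="int64")  # value variable of the marginalized RV
y_vv = pt.scalar("y_vv")                 # value variable of a dependent Normal(z, 1)

# Joint logp written out by hand as a function of the value variables
logp_z = pt.switch(pt.eq(z_vv, 1), pt.log(p), pt.log1p(-p))
logp_y = -0.5 * (y_vv - z_vv) ** 2 - 0.5 * np.log(2 * np.pi)
joint_logp = logp_z + logp_y

# Evaluate the same graph at every value of z's domain {0, 1} in one shot,
# then marginalize with a logsumexp over the new leading axis
domain = pt.constant(np.array([0, 1], dtype="int64"))
joint_logps = vectorize_graph(joint_logp, replace={z_vv: domain})
marginal_logp = pt.logsumexp(joint_logps, axis=0)

print(marginal_logp.eval({p: 0.3, y_vv: 1.5}))  # log p(y = 1.5) with z summed out

Compared with the old branch, the graph is expanded symbolically once instead of calling an OpFromGraph per domain value, and the arbitrary size-10 cutoff disappears; Scan is kept only as a safety net for graphs that vectorize_graph cannot handle.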
