Expand and simplify local_dimshuffle_rv_lift
* The rewrite no longer bails out when the DimShuffle affects both unique parameter dimensions and parameter dimensions that are merely repeated via the `size` argument. This requires:
  1) Adding broadcastable dimensions to the parameters, which should be "cost-free" and would need to be done in the `perform` method anyway.
  2) Extending `size` to incorporate the implicit batch dimensions coming from the parameters, which requires computing the shape that results from broadcasting the parameters together. It is unclear whether this hurts performance, because the `perform` method can now simply broadcast each parameter to `size` instead of having to broadcast the parameters against one another. (A small sketch of the resulting equivalence is included below.)
* The rewrite now works with multivariate RVs.
* The rewrite bails out when dimensions are dropped by the DimShuffle; this case was not handled correctly by the previous rewrite.
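
A minimal sketch of the kind of equivalence the lift relies on (the ``normal(mu, std).T == normal(mu.T, std.T)`` identity from the rewrite's docstring). The shapes and the hand-written second graph are illustrative assumptions only; they mirror steps 1) and 2) above and are not necessarily the exact graph the rewrite emits:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.basic import normal
from pytensor.tensor.shape import shape_padleft

mu = pt.matrix("mu")    # assumed batch shape (2, 3)
std = pt.vector("std")  # assumed batch shape (3,), broadcasts against mu

# DimShuffle applied to the RV output ...
x = normal(mu, std).T
# ... versus drawing directly in the transposed layout: std first gets a
# leading broadcastable dimension (step 1) so the same transpose can be
# applied to every parameter.
y = normal(mu.T, shape_padleft(std).T)

fx = pytensor.function([mu, std], x)
fy = pytensor.function([mu, std], y)

mu_val = np.arange(6, dtype=pytensor.config.floatX).reshape(2, 3)
std_val = np.ones(3, dtype=pytensor.config.floatX)
# Each graph carries its own RNG, so only the output shapes are compared here.
assert fx(mu_val, std_val).shape == fy(mu_val, std_val).shape == (3, 2)

Padding every parameter to the full batch rank is also what allows the implicit batch dimensions to be folded into ``size`` (step 2).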
ricardoV94 committed Dec 4, 2022
1 parent 8e61224 commit 2ebfbf1
Showing 2 changed files with 112 additions and 129 deletions.
178 changes: 52 additions & 126 deletions pytensor/tensor/random/rewriting.py
@@ -8,7 +8,7 @@
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
AdvancedSubtensor1,
@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
For example, ``normal(mu, std).T == normal(mu.T, std.T)``.
The basic idea behind this rewrite is that we need to separate the
``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two
distinct sub-spaces: the (set of independent) parameters and ``size``
(i.e. replications) sub-spaces.
If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
don't do anything.
Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
those sub-spaces, we can break it apart and apply the parameter-space
``DimShuffle`` to the distribution parameters, and then apply the
replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
particularly simple rearranging of a tuple, but the former requires a
little more work.
TODO: Currently, multivariate support for this rewrite is disabled.
This rewrite is only applicable when the Dimshuffle operation does
not affect support dimensions.
TODO: Support dimension dropping
"""

ds_op = node.op
@@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node):
base_rv = node.inputs[0]
rv_node = base_rv.owner

if not (
rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
):
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False

# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
# Dimshuffle which drop dimensions not supported yet
if ds_op.drop:
return False

rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs
rv = rv_node.default_output()

# We need to know the dimensions that were *not* added by the `size`
# parameter (i.e. the dimensions corresponding to independent variates with
# different parameter values)
num_ind_dims = None
if len(dist_params) == 1:
num_ind_dims = dist_params[0].ndim
else:
# When there is more than one distribution parameter, assume that all
# of them will broadcast to the maximum number of dimensions
num_ind_dims = max(d.ndim for d in dist_params)

# If the indices in `ds_new_order` are entirely within the replication
# indices group or the independent variates indices group, then we can apply
# this rewrite.

ds_new_order = ds_op.new_order
# Create a map from old index order to new/`DimShuffled` index order
dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]

# Find the index at which the replications/independents split occurs
reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)

ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
ds_in_ind_space = ds_ind_new_dims and all(
d >= reps_ind_split_idx for n, d in ds_ind_new_dims
)
# Check that Dimshuffle does not affect support dims
supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
if (shuffled_dims | augmented_dims) & supp_dims:
return False

if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
# If no one else is using the underlying RandomVariable, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
return False

# Update the `size` array to reflect the `DimShuffle`d dimensions,
# since the trailing dimensions in `size` represent the independent
# variates dimensions (for univariate distributions, at least)
has_size = get_vector_length(size) > 0
new_size = (
[constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
if has_size
else size
batched_dims = rv.ndim - rv_op.ndim_supp
batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)

# Make size explicit
missing_size_dims = batched_dims - get_vector_length(size)
if missing_size_dims > 0:
full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
size = full_size[:missing_size_dims] + tuple(size)

# Update the size to reflect the DimShuffled dimensions
new_size = [
constant(1, dtype="int64") if o == "x" else size[o]
for o in batched_dims_ds_order
]

# Updates the params to reflect the Dimshuffled dimensions
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Add broadcastable dimensions to the parameters that would have been expanded by the size
padleft = batched_dims - (param.ndim - param_ndim_supp)
if padleft > 0:
param = shape_padleft(param, padleft)

# Add the parameter support dimension indexes to the batched dimensions Dimshuffle
param_new_order = batched_dims_ds_order + tuple(
range(batched_dims, batched_dims + param_ndim_supp)
)
new_dist_params.append(param.dimshuffle(param_new_order))

# Compute the new axes parameter(s) for the `DimShuffle` that will be
# applied to the `RandomVariable` parameters (they need to be offset)
if ds_ind_new_dims:
rv_params_new_order = [
d - reps_ind_split_idx if isinstance(d, int) else d
for d in ds_new_order[ds_ind_new_dims[0][0] :]
]

if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
# Additional broadcast dimensions need to be added to the
# independent dimensions (i.e. parameters), since there's no
# `size` to which they can be added
rv_params_new_order = (
list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
)
else:
# This case is reached when, for example, `ds_new_order` only
# consists of new broadcastable dimensions (i.e. `"x"`s)
rv_params_new_order = ds_new_order

# Lift the `DimShuffle`s into the parameters
# NOTE: The parameters might not be broadcasted against each other, so
# we can only apply the parts of the `DimShuffle` that are relevant.
new_dist_params = []
for d in dist_params:
if d.ndim < len(ds_ind_new_dims):
_rv_params_new_order = [
o
for o in rv_params_new_order
if (isinstance(o, int) and o < d.ndim) or o == "x"
]
else:
_rv_params_new_order = rv_params_new_order

new_dist_params.append(
type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
)
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)

if config.compute_test_value != "off":
compute_test_value(new_node)

out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)

ds_in_reps_space = ds_reps_new_dims and all(
d < reps_ind_split_idx for n, d in ds_reps_new_dims
)

if ds_in_reps_space:
# Update the `size` array to reflect the `DimShuffle`d dimensions.
# There should be no need to `DimShuffle` now.
new_size = [
constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
]

new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)

if config.compute_test_value != "off":
compute_test_value(new_node)

out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
if config.compute_test_value != "off":
compute_test_value(new_node)

return False
out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]


@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
63 changes: 60 additions & 3 deletions tests/tensor/random/test_rewriting.py
@@ -9,6 +9,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
dirichlet,
@@ -42,7 +43,11 @@ def apply_local_rewrite_to_rv(

size_at = []
for s in size:
s_at = iscalar()
# To test DimShuffle with dropping dims we need that size dimension to be constant
if s == 1:
s_at = constant(np.array(1, dtype="int32"))
else:
s_at = iscalar()
s_at.tag.test_value = s
size_at.append(s_at)

@@ -314,7 +319,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
),
(
("x", 1, 0, 2, "x"),
False,
True,
normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
@@ -332,7 +337,30 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
(3, 2, 2),
1,
),
# A multi-dimensional case
# Supported multi-dimensional cases
(
(1, 0, 2),
True,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
(
(1, 0, "x", 2),
True,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
# Unsupported multi-dimensional cases, where the dimshuffle affects the support dimensionality
(
(0, 2, 1),
False,
@@ -344,6 +372,35 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
(3, 2),
1e-3,
),
(
(0, 1, 2, "x"),
False,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
pytest.param(
(1,),
True,
normal,
(0, 1),
(1, 2),
1e-3,
marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
),
pytest.param(
(1,),
True,
normal,
([[0, 0]], 1),
(1, 2),
1e-3,
marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
),
],
)
@config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
