Skip to content

Commit

Permalink
Derive logprob of SetSubtensor operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 29, 2024
1 parent c7840ae commit 9f0f6d5
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ jobs:
tests/logprob/test_order.py
tests/logprob/test_rewriting.py
tests/logprob/test_scan.py
tests/logprob/test_set_subtensor.py
tests/logprob/test_tensor.py
tests/logprob/test_transform_value.py
tests/logprob/test_transforms.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import pymc.logprob.mixture
import pymc.logprob.order
import pymc.logprob.scan
import pymc.logprob.set_subtensor
import pymc.logprob.tensor
import pymc.logprob.transforms

Expand Down
215 changes: 215 additions & 0 deletions pymc/logprob/set_subtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytensor.graph.basic import Variable
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import eq
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
IncSubtensor,
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT

from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.checks import MeasurableCheckAndRaise
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import (
check_potential_measurability,
dirac_delta,
filter_measurable_variables,
)


class MeasurableSetSubtensor(IncSubtensor, MeasurableOp):
"""Measurable SetSubtensor Op."""

def __str__(self):
return f"Measurable{super().__str__()}"


class MeasurableAdvancedSetSubtensor(AdvancedIncSubtensor, MeasurableOp):
"""Measurable AdvancedSetSubtensor Op."""

def __str__(self):
return f"Measurable{super().__str__()}"


set_subtensor_does_not_broadcast = MeasurableCheckAndRaise(
exc_type=NotImplementedError,
msg="Measurable SetSubtensor not supported when set value is broadcasted.",
)


@node_rewriter(tracks=[IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor])
def find_measurable_set_subtensor(fgraph, node) -> list | None:
"""Finds `SetSubtensor` for which a `logprob` can be computed."""

if isinstance(node.op, MeasurableOp):
return None

if not node.op.set_instead_of_inc:
return None

x, y, *idx_elements = node.inputs

measurable_inputs = filter_measurable_variables([x, y])

if y not in measurable_inputs:
return None

if x not in measurable_inputs:
# x is potentially measurable, wait for it's logprob IR to be inferred
if check_potential_measurability([x]):
return None
# x has no link to measurable variables, so it's value should be constant
else:
x = dirac_delta(x, rtol=0, atol=0)

if check_potential_measurability(idx_elements):
return None

if isinstance(node.op, IncSubtensor):
measurable_class = MeasurableSetSubtensor
idx = indices_from_subtensor(idx_elements, node.op.idx_list)
else:
measurable_class = MeasurableAdvancedSetSubtensor
idx = tuple(idx_elements)

# Check that y is not certainly broadcasted.
indexed_block = x[idx]
missing_y_dims = indexed_block.type.ndim - y.type.ndim
y_bcast = [True] * missing_y_dims + list(y.type.broadcastable)
if any(
y_dim_bcast and indexed_block_dim_len not in (None, 1)
for y_dim_bcast, indexed_block_dim_len in zip(
y_bcast, indexed_block.type.shape, strict=True
)
):
return None

measurable_set_subtensor = measurable_class(**node.op._props_dict())(x, y, *idx_elements)

# Often with indexing we don't know the static shape of the indexed block.
# And, what's more, the indexing operations actually support runtime broadcasting.
# As the logp is not valid under broadcasting, we have to add a runtime check.
# This will hopefully be removed during shape inference when not violated.
potential_broadcasted_dims = [
i
for i, (y_bcast_dim, indexed_block_dim_len) in enumerate(
zip(y_bcast, indexed_block.type.shape)
)
if y_bcast_dim and indexed_block_dim_len is None
]
if potential_broadcasted_dims:
indexed_block_shape = tuple(indexed_block.shape)
measurable_set_subtensor = set_subtensor_does_not_broadcast(
measurable_set_subtensor,
*(eq(indexed_block_shape[i], 1) for i in potential_broadcasted_dims),
)

return [measurable_set_subtensor]


measurable_ir_rewrites_db.register(
find_measurable_set_subtensor.__name__,
find_measurable_set_subtensor,
"basic",
"set_subtensor",
)


def indexed_dims(idx) -> list[int | None]:
"""Return the indices of the dimensions of the indexed tensor that are being indexed."""
dims = []
idx_counter = 0
for idx_elem in idx:
if isinstance(idx_elem, Variable) and isinstance(idx_elem.type, NoneTypeT):
# None in indexes correspond to newaxis, and don't map to any existing dimension
dims.append(None)

elif (
isinstance(idx_elem, Variable)
and isinstance(idx_elem.type, TensorType)
and idx_elem.type.dtype == "bool"
):
# Boolean indexes map to as many dimensions as the mask has
for i in range(idx_elem.type.ndim):
dims.append(idx_counter)
idx_counter += 1
else:
dims.append(idx_counter)
idx_counter += 1

return dims


@_logprob.register(MeasurableSetSubtensor)
@_logprob.register(MeasurableAdvancedSetSubtensor)
def logprob_setsubtensor(op, values, x, y, *idx_elements, **kwargs):
"""Compute the log-likelihood graph for a `SetSubtensor`.
For a generative graph like:
o = zeros(2)
x = o[0].set(X)
y = x[1].set(Y)
The log-likelihood graph is:
logp(y, value) = (
logp(x, value)
[1].set(logp(y, value[1]))
)
Unrolling the logp(x, value) gives:
logp(y, value) = (
DiracDelta(zeros(2), value) # Irrelevant if all entries are set
[0].set(logp(x, value[0]))
[1].set(logp(y, value[1]))
)
"""
[value] = values
if isinstance(op, MeasurableSetSubtensor):
# For basic indexing we have to recreate the index from the input list
idx = indices_from_subtensor(idx_elements, op.idx_list)
else:
# For advanced indexing we can use the idx_elements directly
idx = tuple(idx_elements)

x_logp = _logprob_helper(x, value)
y_logp = _logprob_helper(y, value[idx])

y_ndim_supp = x[idx].type.ndim - y_logp.type.ndim
x_ndim_supp = x.type.ndim - x_logp.type.ndim
ndim_supp = max(y_ndim_supp, x_ndim_supp)
if ndim_supp > 0:
# Multivariate logp only valid if we are not doing indexing along the reduced dimensions
# Otherwise we don't know if successive writings are overlapping or not
core_dims = set(range(x.type.ndim)[-ndim_supp:])
if set(indexed_dims(idx)) & core_dims:
# When we have IR meta-info about support_ndim, we can fail at the rewriting stage
raise NotImplementedError(
"Indexing along core dimensions of multivariate SetSubtensor not supported"
)

ndim_supp_diff = y_ndim_supp - x_ndim_supp
if ndim_supp_diff > 0:
# In this case y_logp will have fewer dimensions than x_logp after indexing, so we need to reduce x before indexing.
x_logp = x_logp.sum(axis=tuple(range(-ndim_supp_diff, 0)))
elif ndim_supp_diff < 0:
# In this case x_logp will have fewer dimensions than y_logp after indexing, so we need to reduce y before indexing.
y_logp = y_logp.sum(axis=tuple(range(ndim_supp_diff, 0)))

out_logp = x_logp[idx].set(y_logp)
return out_logp
158 changes: 158 additions & 0 deletions tests/logprob/test_set_subtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest

from pymc.distributions import Beta, Dirichlet, MvNormal, MvStudentT, Normal, StudentT
from pymc.logprob.basic import logp


@pytest.mark.parametrize("univariate", [True, False])
def test_complete_set_subtensor(univariate):
if univariate:
rv0 = Normal.dist(mu=-10)
rv1 = StudentT.dist(nu=3, mu=0)
rv2 = Normal.dist(mu=10, sigma=3)
rv34 = Beta.dist(alpha=[np.pi, np.e], beta=[1, 1])
base = pt.empty((5,))
test_val = [2, 0, -2, 0.25, 0.5]
else:
rv0 = MvNormal.dist(mu=[-11, -9], cov=pt.eye(2))
rv1 = MvStudentT.dist(nu=3, mu=[-1, 1], cov=pt.eye(2))
rv2 = MvNormal.dist(mu=[9, 11], cov=pt.eye(2) * 3)
rv34 = Dirichlet.dist(a=[[np.pi, 1], [np.e, 1]])
base = pt.empty((3, 2))
test_val = [[2, 0], [0, -2], [-2, 2], [0.25, 0.75], [0.5, 0.5]]

# fmt: off
rv = (
# Boolean indexing
base[np.array([True, False, False, False, False])].set(rv0)
# Slice indexing
[1:2].set(rv1)
# Integer indexing
[2].set(rv2)
# Vector indexing
[[3, 4]].set(rv34)
)
# fmt: on
ref_rv = pt.join(0, [rv0], [rv1], [rv2], rv34)

np.testing.assert_allclose(
logp(rv, test_val).eval(),
logp(ref_rv, test_val).eval(),
)


def test_partial_set_subtensor():
rv123 = Normal.dist(mu=[-10, 0, 10])

# When base is empty, it doesn't matter what the missing values are
base = pt.empty((5,))
rv = base[:3].set(rv123)

np.testing.assert_allclose(
logp(rv, [0, 0, 0, 1, np.pi]).eval(),
[*logp(rv123, [0, 0, 0]).eval(), 0, 0],
)

# Otherwise they should match
base = pt.ones((5,))
rv = base[:3].set(rv123)

np.testing.assert_allclose(
logp(rv, [0, 0, 0, 1, np.pi]).eval(),
[*logp(rv123, [0, 0, 0]).eval(), 0, -np.inf],
)


def test_overwrite_set_subtensor():
"""Test that order of overwriting in the generative graph is respected."""
x = Normal.dist(mu=[0, 1, 2])
y = x[1:].set(Normal.dist([10, 20]))
z = y[2:].set(Normal.dist([300]))

np.testing.assert_allclose(
logp(z, [0, 0, 0]).eval(),
logp(Normal.dist([0, 10, 300]), [0, 0, 0]).eval(),
)


def test_mixed_dimensionality_set_subtensor():
x = Normal.dist(mu=0, size=(3, 2))
y = x[1].set(MvNormal.dist(mu=[1, 1], cov=np.eye(2)))
z = y[2].set(Normal.dist(mu=2, size=(2,)))

# Because `y` is multivariate the last dimension of `z` must be summed over
test_val = np.zeros((3, 2))
logp_eval = logp(z, test_val).eval()
assert logp_eval.shape == (3,)
np.testing.assert_allclose(
logp_eval,
logp(Normal.dist(mu=[[0, 0], [1, 1], [2, 2]]), test_val).sum(-1).eval(),
)


def test_invalid_indexing_core_dims():
x = pt.empty((2, 2))
rv = MvNormal.dist(cov=np.eye(2))
vv = x.type()

match_msg = "Indexing along core dimensions of multivariate SetSubtensor not supported"

y = x[[0, 1], [1, 0]].set(rv)
with pytest.raises(NotImplementedError, match=match_msg):
logp(y, vv)

y = x[np.array([[False, True], [True, False]])].set(rv)
with pytest.raises(NotImplementedError, match=match_msg):
logp(y, vv)

# Univariate indexing above multivariate core dims also not supported
z = y[0].set(rv)[0, 1].set(Normal.dist())
with pytest.raises(NotImplementedError, match=match_msg):
logp(z, vv)


def test_invalid_broadcasted_set_subtensor():
rv_bcast = Normal.dist(mu=0)
base = pt.empty((5,))

rv = base[:3].set(rv_bcast)
vv = rv.type()

# Broadcasting is known at write time, and PyMC does not attempt to make SetSubtensor measurable
with pytest.raises(NotImplementedError):
logp(rv, vv)

mask = pt.tensor(shape=(5,), dtype=bool)
rv = base[mask].set(rv_bcast)

# Broadcasting is only known at runtime, and PyMC raises an error when it happens
logp_rv = logp(rv, vv)
fn = pytensor.function([mask, vv], logp_rv)
test_vv = np.zeros(5)

np.testing.assert_allclose(
fn([False, False, True, False, False], test_vv),
[0, 0, -0.91893853, 0, 0],
)

with pytest.raises(
NotImplementedError,
match="Measurable SetSubtensor not supported when set value is broadcasted.",
):
fn([False, False, True, False, True], test_vv)

0 comments on commit 9f0f6d5

Please sign in to comment.