Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive probability of set_subtensor operations #7553

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ jobs:
tests/logprob/test_order.py
tests/logprob/test_rewriting.py
tests/logprob/test_scan.py
tests/logprob/test_set_subtensor.py
tests/logprob/test_tensor.py
tests/logprob/test_transform_value.py
tests/logprob/test_transforms.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import pymc.logprob.mixture
import pymc.logprob.order
import pymc.logprob.scan
import pymc.logprob.set_subtensor
import pymc.logprob.tensor
import pymc.logprob.transforms

Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
(value,) = values
# transfer assertion from rv to value
assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value})
value = op(value, *assertions)
value = CheckAndRaise(**op._props_dict())(value, *assertions)
return _logprob_helper(inner_rv, value)


Expand Down
215 changes: 215 additions & 0 deletions pymc/logprob/set_subtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytensor.graph.basic import Variable
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import eq
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
IncSubtensor,
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT

from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.checks import MeasurableCheckAndRaise
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import (
check_potential_measurability,
dirac_delta,
filter_measurable_variables,
)


class MeasurableSetSubtensor(IncSubtensor, MeasurableOp):
"""Measurable SetSubtensor Op."""

def __str__(self):
return f"Measurable{super().__str__()}"

Check warning on line 40 in pymc/logprob/set_subtensor.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/set_subtensor.py#L40

Added line #L40 was not covered by tests


class MeasurableAdvancedSetSubtensor(AdvancedIncSubtensor, MeasurableOp):
"""Measurable AdvancedSetSubtensor Op."""

def __str__(self):
return f"Measurable{super().__str__()}"

Check warning on line 47 in pymc/logprob/set_subtensor.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/set_subtensor.py#L47

Added line #L47 was not covered by tests


set_subtensor_does_not_broadcast = MeasurableCheckAndRaise(
exc_type=NotImplementedError,
msg="Measurable SetSubtensor not supported when set value is broadcasted.",
)


@node_rewriter(tracks=[IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor])
def find_measurable_set_subtensor(fgraph, node) -> list | None:
"""Find `SetSubtensor` for which a `logprob` can be computed."""
if isinstance(node.op, MeasurableOp):
return None

if not node.op.set_instead_of_inc:
return None

Check warning on line 63 in pymc/logprob/set_subtensor.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/set_subtensor.py#L63

Added line #L63 was not covered by tests

x, y, *idx_elements = node.inputs

measurable_inputs = filter_measurable_variables([x, y])

if y not in measurable_inputs:
return None

if x not in measurable_inputs:
# x is potentially measurable, wait for it's logprob IR to be inferred
if check_potential_measurability([x]):
return None
# x has no link to measurable variables, so it's value should be constant
else:
x = dirac_delta(x, rtol=0, atol=0)

if check_potential_measurability(idx_elements):
return None

Check warning on line 81 in pymc/logprob/set_subtensor.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/set_subtensor.py#L81

Added line #L81 was not covered by tests

measurable_class: type[MeasurableSetSubtensor | MeasurableAdvancedSetSubtensor]
if isinstance(node.op, IncSubtensor):
measurable_class = MeasurableSetSubtensor
idx = indices_from_subtensor(idx_elements, node.op.idx_list)
else:
measurable_class = MeasurableAdvancedSetSubtensor
idx = tuple(idx_elements)

# Check that y is not certainly broadcasted.
indexed_block = x[idx]
missing_y_dims = indexed_block.type.ndim - y.type.ndim
y_bcast = [True] * missing_y_dims + list(y.type.broadcastable)
if any(
y_dim_bcast and indexed_block_dim_len not in (None, 1)
for y_dim_bcast, indexed_block_dim_len in zip(
y_bcast, indexed_block.type.shape, strict=True
)
):
return None

measurable_set_subtensor = measurable_class(**node.op._props_dict())(x, y, *idx_elements)

# Often with indexing we don't know the static shape of the indexed block.
# And, what's more, the indexing operations actually support runtime broadcasting.
# As the logp is not valid under broadcasting, we have to add a runtime check.
# This will hopefully be removed during shape inference when not violated.
potential_broadcasted_dims = [
i
for i, (y_bcast_dim, indexed_block_dim_len) in enumerate(
zip(y_bcast, indexed_block.type.shape)
)
if y_bcast_dim and indexed_block_dim_len is None
]
if potential_broadcasted_dims:
indexed_block_shape = tuple(indexed_block.shape)
measurable_set_subtensor = set_subtensor_does_not_broadcast(
measurable_set_subtensor,
*(eq(indexed_block_shape[i], 1) for i in potential_broadcasted_dims),
)

return [measurable_set_subtensor]


measurable_ir_rewrites_db.register(
find_measurable_set_subtensor.__name__,
find_measurable_set_subtensor,
"basic",
"set_subtensor",
)


def indexed_dims(idx) -> list[int | None]:
"""Return the indices of the dimensions of the indexed tensor that are being indexed."""
dims: list[int | None] = []
idx_counter = 0
for idx_elem in idx:
if isinstance(idx_elem, Variable) and isinstance(idx_elem.type, NoneTypeT):
# None in indexes correspond to newaxis, and don't map to any existing dimension
dims.append(None)

Check warning on line 141 in pymc/logprob/set_subtensor.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/set_subtensor.py#L141

Added line #L141 was not covered by tests

elif (
isinstance(idx_elem, Variable)
and isinstance(idx_elem.type, TensorType)
and idx_elem.type.dtype == "bool"
):
# Boolean indexes map to as many dimensions as the mask has
for i in range(idx_elem.type.ndim):
dims.append(idx_counter)
idx_counter += 1
else:
dims.append(idx_counter)
idx_counter += 1

return dims


@_logprob.register(MeasurableSetSubtensor)
@_logprob.register(MeasurableAdvancedSetSubtensor)
def logprob_setsubtensor(op, values, x, y, *idx_elements, **kwargs):
"""Compute the log-likelihood graph for a `SetSubtensor`.
For a generative graph like:
o = zeros(2)
x = o[0].set(X)
y = x[1].set(Y)
The log-likelihood graph is:
logp(y, value) = (
logp(x, value)
[1].set(logp(y, value[1]))
)
Unrolling the logp(x, value) gives:
logp(y, value) = (
DiracDelta(zeros(2), value) # Irrelevant if all entries are set
[0].set(logp(x, value[0]))
[1].set(logp(y, value[1]))
)
"""
[value] = values
if isinstance(op, MeasurableSetSubtensor):
# For basic indexing we have to recreate the index from the input list
idx = indices_from_subtensor(idx_elements, op.idx_list)
else:
# For advanced indexing we can use the idx_elements directly
idx = tuple(idx_elements)

x_logp = _logprob_helper(x, value)
y_logp = _logprob_helper(y, value[idx])

y_ndim_supp = x[idx].type.ndim - y_logp.type.ndim
x_ndim_supp = x.type.ndim - x_logp.type.ndim
ndim_supp = max(y_ndim_supp, x_ndim_supp)
if ndim_supp > 0:
# Multivariate logp only valid if we are not doing indexing along the reduced dimensions
# Otherwise we don't know if successive writings are overlapping or not
core_dims = set(range(x.type.ndim)[-ndim_supp:])
if set(indexed_dims(idx)) & core_dims:
# When we have IR meta-info about support_ndim, we can fail at the rewriting stage
raise NotImplementedError(
"Indexing along core dimensions of multivariate SetSubtensor not supported"
)

ndim_supp_diff = y_ndim_supp - x_ndim_supp
if ndim_supp_diff > 0:
# In this case y_logp will have fewer dimensions than x_logp after indexing, so we need to reduce x before indexing.
x_logp = x_logp.sum(axis=tuple(range(-ndim_supp_diff, 0)))
elif ndim_supp_diff < 0:
# In this case x_logp will have fewer dimensions than y_logp after indexing, so we need to reduce y before indexing.
y_logp = y_logp.sum(axis=tuple(range(ndim_supp_diff, 0)))

out_logp = x_logp[idx].set(y_logp)
return out_logp
48 changes: 18 additions & 30 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,9 @@
from pytensor.graph.basic import Constant, Variable, clone_get_equiv, graph_inputs, walk
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.link.c.type import CType
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.basic import AllocEmpty, get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -150,27 +149,6 @@ def expand(r):
}


def convert_indices(indices, entry):
if indices and isinstance(entry, CType):
rval = indices.pop(0)
return rval
elif isinstance(entry, slice):
return slice(
convert_indices(indices, entry.start),
convert_indices(indices, entry.stop),
convert_indices(indices, entry.step),
)
else:
return entry


def indices_from_subtensor(idx_list, indices):
"""Compute a useable index tuple from the inputs of a ``*Subtensor**`` ``Op``."""
return tuple(
tuple(convert_indices(list(indices), idx) for idx in idx_list) if idx_list else indices
)


def filter_measurable_variables(inputs):
return [
inp for inp in inputs if (inp.owner is not None and isinstance(inp.owner.op, MeasurableOp))
Expand Down Expand Up @@ -266,7 +244,7 @@ class DiracDelta(MeasurableOp, Op):

__props__ = ("rtol", "atol")

def __init__(self, rtol=1e-5, atol=1e-8):
def __init__(self, rtol, atol):
self.rtol = rtol
self.atol = atol

Expand All @@ -289,15 +267,25 @@ def infer_shape(self, fgraph, node, input_shapes):
return input_shapes


dirac_delta = DiracDelta()
def dirac_delta(x, rtol=1e-5, atol=1e-8):
return DiracDelta(rtol, atol)(x)


@_logprob.register(DiracDelta)
def diracdelta_logprob(op, values, *inputs, **kwargs):
(values,) = values
(const_value,) = inputs
values, const_value = pt.broadcast_arrays(values, const_value)
return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)
def diracdelta_logprob(op, values, const_value, **kwargs):
[value] = values

if const_value.owner and isinstance(const_value.owner.op, AllocEmpty):
# Any value is considered valid for an AllocEmpty array
return pt.zeros_like(value)

if op.rtol == 0 and op.atol == 0:
# Strict equality, cheaper logp
match = pt.eq(value, const_value)
else:
# Loose equality, more expensive logp
match = pt.isclose(value, const_value, rtol=op.rtol, atol=op.atol)
return pt.switch(match, np.array(0, dtype=value.dtype), -np.inf)


def find_negated_var(var):
Expand Down
Loading
Loading