Skip to content

Commit

Permalink
Add cond operation for conditional selection (#141)
Browse files Browse the repository at this point in the history
* Refactor and simplify counterfactual handlers

* lint

* Add cond operation

* Add broadcasting test cases

* add broadcasting test cases

* docstring for cond

* Update chirho/indexed/ops.py

Co-authored-by: rfl-urbaniak <[email protected]>

---------

Co-authored-by: rfl-urbaniak <[email protected]>
  • Loading branch information
eb8680 and rfl-urbaniak authored Jul 14, 2023
1 parent 198ccc7 commit 035a84b
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 0 deletions.
27 changes: 27 additions & 0 deletions chirho/indexed/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from chirho.indexed.ops import (
IndexSet,
cond,
gather,
get_index_plates,
indices_of,
Expand Down Expand Up @@ -179,6 +180,32 @@ def _indices_of_distribution(
return indices_of(value.batch_shape, event_dim=0, **kwargs)


@cond.register(int)
@cond.register(float)
@cond.register(bool)
def _cond_number(
fst: Union[bool, numbers.Number],
snd: Union[bool, numbers.Number, torch.Tensor],
case: Union[bool, torch.Tensor],
**kwargs,
) -> torch.Tensor:
return cond(
torch.as_tensor(fst), torch.as_tensor(snd), torch.as_tensor(case), **kwargs
)


@cond.register
def _cond_tensor(
fst: torch.Tensor,
snd: torch.Tensor,
case: torch.Tensor,
*,
event_dim: int = 0,
**kwargs,
) -> torch.Tensor:
return torch.where(case[(...,) + (None,) * event_dim], snd, fst)


class _LazyPlateMessenger(IndepMessenger):
prefix: str = "__index_plate__"

Expand Down
46 changes: 46 additions & 0 deletions chirho/indexed/ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import operator
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple, TypeVar, Union

import pyro
Expand Down Expand Up @@ -264,6 +265,51 @@ def _scatter_n(values: Dict[IndexSet, T], *, result: Optional[T] = None, **kwarg
return result


@functools.singledispatch
def cond(fst, snd: T, case, **kwargs):
"""
Selection operation that is the sum-type analogue of :func:`scatter`
in the sense that where :func:`scatter` propagates both of its arguments,
:func:`cond` propagates only one, depending on the value of a boolean ``case`` .
For a given ``fst`` , ``snd`` , and ``case`` , :func:`cond` returns
``snd`` if the ``case`` is true, and ``fst`` otherwise,
analogous to a Python conditional expression ``snd if case else fst`` .
Unlike a Python conditional expression, however, the case may be a tensor,
and both branches are evaluated, as with :func:`torch.where` ::
>> fst, snd = torch.randn(2, 3), torch.randn(2, 3)
>> case = (fst < snd).all(-1)
>> x = cond(fst, snd, case, event_dim=1)
>> assert (x == torch.where(case[..., None], snd, fst)).all()
.. note::
:func:`cond` can be extended to new value types by registering
an implementation for the type using :func:`functools.singledispatch` .
:param fst: The value to return if ``case`` is ``False`` .
:param snd: The value to return if ``case`` is ``True`` .
:param case: A boolean value or tensor. If a tensor, should have event shape ``()`` .
:param kwargs: Additional keyword arguments used by specific implementations.
"""
raise NotImplementedError(f"cond not implemented for {type(fst)}")


@cond.register(dict)
@pyro.poutine.runtime.effectful(type="cond_n")
def _cond_n(values: Dict[IndexSet, T], case: Union[bool, torch.Tensor], **kwargs):
assert len(values) > 0
assert all(isinstance(k, IndexSet) for k in values.keys())
result: Optional[T] = None
for indices, value in values.items():
tst = functools.reduce(
operator.or_, [case == index for index in next(iter(indices.values()))]
)
result = cond(result if result is not None else value, value, tst, **kwargs)
return result


@pyro.poutine.runtime.effectful(type="get_index_plates")
def get_index_plates() -> (
Dict[Hashable, pyro.poutine.indep_messenger.CondIndepStackFrame]
Expand Down
57 changes: 57 additions & 0 deletions tests/indexed/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from chirho.indexed.internals import add_indices
from chirho.indexed.ops import (
IndexSet,
cond,
gather,
get_index_plates,
indexset_as_mask,
Expand Down Expand Up @@ -421,3 +422,59 @@ def test_index_plate_names():
assert len(index_plates) == 1
for name, frame in index_plates.items():
assert name != frame.name


@pytest.mark.parametrize(
"enum_shape,plate_shape,batch_shape,event_shape", SHAPE_CASES, ids=str
)
def test_cond_tensor_associate(enum_shape, batch_shape, plate_shape, event_shape):
cf_dim = -1 - len(plate_shape)
event_dim = len(event_shape)
ind1, ind2, ind3 = (
IndexSet(new_dim={0}),
IndexSet(new_dim={1}),
IndexSet(new_dim={2}),
)
name_to_dim = {f"dim_{i}": cf_dim - i for i in range(len(batch_shape))}

case = torch.randint(0, 3, enum_shape + batch_shape + plate_shape)
value1 = torch.randn(batch_shape + plate_shape + event_shape)
value2 = torch.randn(
enum_shape + batch_shape + (1,) * len(plate_shape) + event_shape
)
value3 = torch.randn(enum_shape + batch_shape + plate_shape + event_shape)

with IndexPlatesMessenger(cf_dim):
for name, dim in name_to_dim.items():
add_indices(
IndexSet(**{name: set(range(max(3, (batch_shape + plate_shape)[dim])))})
)

actual_full = cond(
{ind1: value1, ind2: value2, ind3: value3}, case, event_dim=event_dim
)

actual_left = cond(
cond(value1, value2, case == 1, event_dim=event_dim),
value3,
case >= 2,
event_dim=event_dim,
)

actual_right = cond(
value1,
cond(value2, value3, case == 2, event_dim=event_dim),
case >= 1,
event_dim=event_dim,
)

assert (
indices_of(actual_full, event_dim=event_dim)
== indices_of(actual_left, event_dim=event_dim)
== indices_of(actual_right, event_dim=event_dim)
)

assert actual_full.shape == enum_shape + batch_shape + plate_shape + event_shape
assert actual_full.shape == actual_left.shape == actual_right.shape
assert (actual_full == actual_left).all()
assert (actual_left == actual_right).all()

0 comments on commit 035a84b

Please sign in to comment.