Skip to content

Commit

Permalink
Update tensor.where to allow for case with only condition (#844)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanish1729 authored Jun 24, 2024
1 parent d3bd1f1 commit 7159215
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
26 changes: 25 additions & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,31 @@ def switch(cond, ift, iff):
"""if cond then ift else iff"""


where = switch
def where(cond, ift=None, iff=None, **kwargs):
"""
where(condition, [ift, iff])
Return elements chosen from `ift` or `iff` depending on `condition`.
Note: When only condition is provided, this function is a shorthand for `as_tensor(condition).nonzero()`.
Parameters
----------
condition : tensor_like, bool
Where True, yield `ift`, otherwise yield `iff`.
x, y : tensor_like
Values from which to choose.
Returns
-------
out : TensorVariable
A tensor with elements from `ift` where `condition` is True, and elements from `iff` elsewhere.
"""
if ift is not None and iff is not None:
return switch(cond, ift, iff, **kwargs)
elif ift is None and iff is None:
return as_tensor(cond).nonzero(**kwargs)
else:
raise ValueError("either both or neither of ift and iff should be given")


@scalar_elemwise
Expand Down
18 changes: 18 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
triu_indices,
triu_indices_from,
vertical_stack,
where,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
Expand Down Expand Up @@ -4608,3 +4609,20 @@ def core_np(x, y):
vectorize_pt(x_test, y_test),
vectorize_np(x_test, y_test),
)


def test_where():
a = np.arange(10)
cond = a < 5
ift = np.pi
iff = np.e
# Test for all 3 inputs
np.testing.assert_allclose(np.where(cond, ift, iff), where(cond, ift, iff).eval())

# Test for only condition input
for np_output, pt_output in zip(np.where(cond), where(cond)):
np.testing.assert_allclose(np_output, pt_output.eval())

# Test for error
with pytest.raises(ValueError, match="either both"):
where(cond, ift)

0 comments on commit 7159215

Please sign in to comment.