Skip to content

Commit

Permalink
Add IfElse
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Schweer committed Jul 17, 2024
1 parent 426931b commit 3c4f73d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
13 changes: 13 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
Expand Down Expand Up @@ -124,6 +125,7 @@ def eye(N, M, k):
return eye



@pytorch_funcify.register(MakeVector)
def pytorch_funcify_MakeVector(op, **kwargs):
torch_dtype = getattr(torch, op.dtype)
Expand All @@ -132,3 +134,14 @@ def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector


@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
assert n_outs == 1

def ifelse(cond, *args, n_outs=n_outs):
return torch.where(cond, *args)

return ifelse
20 changes: 19 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.op import Op, get_test_value
from pytensor.ifelse import ifelse
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector
Expand Down Expand Up @@ -301,3 +302,20 @@ def test_pytorch_MakeVector():
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])


def test_pytorch_ifelse():
true_vals = np.r_[1, 2, 3]
false_vals = np.r_[-1, -2, -3]

x = ifelse(np.array(True), true_vals, false_vals)
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])

a = scalar("a")
a.tag.test_value = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals)
x_fg = FunctionGraph([a], [x]) # I.e. False

compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])

0 comments on commit 3c4f73d

Please sign in to comment.