Skip to content

Commit

Permalink
Add torch implementation of IfElse (#974)
Browse files Browse the repository at this point in the history
Co-authored-by: Ian Schweer <[email protected]>
  • Loading branch information
Ch0ronomato and Ian Schweer authored Oct 3, 2024
1 parent 8a6e407 commit 46fdc58
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
14 changes: 14 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import (
Expand Down Expand Up @@ -153,6 +154,19 @@ def makevector(*x):
return makevector


@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs

def ifelse(cond, *true_and_false, n_outs=n_outs):
if cond:
return true_and_false[:n_outs]
else:
return true_and_false[n_outs:]

return ifelse


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)
Expand Down
18 changes: 18 additions & 0 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrices, matrix, scalar, vector
Expand Down Expand Up @@ -304,6 +305,23 @@ def test_pytorch_MakeVector():
compare_pytorch_and_py(x_fg, [])


def test_pytorch_ifelse():
p1_vals = np.r_[1, 2, 3]
p2_vals = np.r_[-1, -2, -3]

a = scalar("a")
x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)

compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX))

a = scalar("a")
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)

compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX))


def test_pytorch_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y])
Expand Down

0 comments on commit 46fdc58

Please sign in to comment.