Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torch conditionals: IfElse #940

Closed
wants to merge 73 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
3c4f73d
Add IfElse
Jul 17, 2024
bfb97ea
Remove space
Jul 17, 2024
6ad1c5c
Implement Dot and BatchedDot in PyTensor (#878)
HangenYuu Jul 18, 2024
cac9feb
Add `OpFromGraph` wrapper around `alloc_diag` (#915)
jessegrabowski Jul 18, 2024
ad27dc7
Bump actions/upload-artifact from 3 to 4 (#560)
dependabot[bot] Jul 18, 2024
f489cf4
Added rewrite for matrix inv(inv(x)) -> x (#893)
tanish1729 Jul 19, 2024
981688c
Implement `pad` (#748)
jessegrabowski Jul 19, 2024
a601a27
Update away from torch.where
Jul 21, 2024
aab9fae
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 22, 2024
739d97d
Removed unused config options
Armavica Jul 19, 2024
b9f2dde
Remove add_experimental_configvars
Armavica Jul 19, 2024
f9f5c5b
Remove default in_c_key and change for cast_policy
Armavica Jul 19, 2024
158a7d0
Fix typo in docstring
Armavica Jul 19, 2024
7a0175a
Simplify _ChangeFlagDecorator
Armavica Jul 19, 2024
d9ed1e2
Fix typo amblibm -> amdlibm
Armavica Jul 19, 2024
9f4b89d
Remove unused ContextsParam
Armavica Jul 19, 2024
d455460
Simplify config.add(linker)
Armavica Jul 19, 2024
367351f
Fixed dead wiki links (#950)
HangenYuu Jul 25, 2024
58fec45
Implement nlinalg Ops in PyTorch (#920)
twaclaw Jul 26, 2024
7fd8cbd
Update for m1
Jul 17, 2024
a5587a7
Add new env file
Jul 21, 2024
a09fa75
Update comment
Jul 21, 2024
d6254af
Update environment-osx-arm64.yml
twiecki Jul 22, 2024
23427a0
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 29, 2024
f25a624
Implement Einsum
jessegrabowski Apr 19, 2024
b65d08c
Skip tri test in latest version of JAX
ricardoV94 Aug 4, 2024
da91dc7
Corrected the reference from 'an PyTensor' to 'a PyTensor' in the con…
abhishekshah5486 Aug 5, 2024
0ae3cfe
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Aug 5, 2024
48450b0
Fix test to allow for n_outs>1
Aug 9, 2024
fd27b6a
Remove test value
Aug 9, 2024
7fffec6
Pickle error message changed (#966)
twiecki Aug 10, 2024
29183c7
Add building of pyodide universal wheels (#918)
twiecki Aug 10, 2024
4d0103b
Removed types examples and introduced tensor (#968)
Krupakar-Reddy-S Aug 12, 2024
f62401a
maintanance: unpin scipy
ferrine Aug 13, 2024
dd8895d
mypy: fix graph.py
ferrine Aug 14, 2024
a3f0a4e
mypy: fix graph/basic.py
ferrine Aug 14, 2024
79232b2
Implement Dot and BatchedDot in PyTensor (#878)
HangenYuu Jul 18, 2024
143ded6
Add `OpFromGraph` wrapper around `alloc_diag` (#915)
jessegrabowski Jul 18, 2024
8c30780
Bump actions/upload-artifact from 3 to 4 (#560)
dependabot[bot] Jul 18, 2024
297bdd4
Added rewrite for matrix inv(inv(x)) -> x (#893)
tanish1729 Jul 19, 2024
a4e014e
Implement `pad` (#748)
jessegrabowski Jul 19, 2024
8d25c14
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 22, 2024
6fcc37c
Removed unused config options
Armavica Jul 19, 2024
39612d1
Remove add_experimental_configvars
Armavica Jul 19, 2024
9571d4f
Remove default in_c_key and change for cast_policy
Armavica Jul 19, 2024
ab4f150
Fix typo in docstring
Armavica Jul 19, 2024
153d209
Simplify _ChangeFlagDecorator
Armavica Jul 19, 2024
3aaf756
Fix typo amblibm -> amdlibm
Armavica Jul 19, 2024
1b2802e
Remove unused ContextsParam
Armavica Jul 19, 2024
9c6748f
Simplify config.add(linker)
Armavica Jul 19, 2024
9973e03
Fixed dead wiki links (#950)
HangenYuu Jul 25, 2024
286c8fc
Implement nlinalg Ops in PyTorch (#920)
twaclaw Jul 26, 2024
70c902b
Update for m1
Jul 17, 2024
bd607f3
Add new env file
Jul 21, 2024
3249ae2
Update comment
Jul 21, 2024
d2ad1ed
Update environment-osx-arm64.yml
twiecki Jul 22, 2024
f11df4a
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 29, 2024
a7c099c
Implement Einsum
jessegrabowski Apr 19, 2024
6112f82
Skip tri test in latest version of JAX
ricardoV94 Aug 4, 2024
cd8585d
Corrected the reference from 'an PyTensor' to 'a PyTensor' in the con…
abhishekshah5486 Aug 5, 2024
bd38216
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Aug 5, 2024
521b8ca
Pickle error message changed (#966)
twiecki Aug 10, 2024
917cc55
Add building of pyodide universal wheels (#918)
twiecki Aug 10, 2024
e879b0c
Removed types examples and introduced tensor (#968)
Krupakar-Reddy-S Aug 12, 2024
3523d79
maintanance: unpin scipy
ferrine Aug 13, 2024
400323f
mypy: fix graph.py
ferrine Aug 14, 2024
f0214a1
mypy: fix graph/basic.py
ferrine Aug 14, 2024
9f3a938
Add IfElse
Jul 17, 2024
d36d4ce
Remove space
Jul 17, 2024
9adbbe2
Update away from torch.where
Jul 21, 2024
2766457
Fix test to allow for n_outs>1
Aug 9, 2024
ef9277b
Remove test value
Aug 9, 2024
d4aaeaf
Merge branch 'branches' of github.com:Ch0ronomato/pytensor into branches
Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
Expand Down Expand Up @@ -132,3 +133,14 @@ def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector


@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
assert n_outs == 1
Ch0ronomato marked this conversation as resolved.
Show resolved Hide resolved

def ifelse(cond, *args, n_outs=n_outs):
return torch.where(cond, *args)

return ifelse
20 changes: 19 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.op import Op, get_test_value
from pytensor.ifelse import ifelse
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector
Expand Down Expand Up @@ -301,3 +302,20 @@ def test_pytorch_MakeVector():
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])


def test_pytorch_ifelse():
true_vals = np.r_[1, 2, 3]
false_vals = np.r_[-1, -2, -3]

x = ifelse(np.array(True), true_vals, false_vals)
Ch0ronomato marked this conversation as resolved.
Show resolved Hide resolved
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])

a = scalar("a")
a.tag.test_value = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals)
x_fg = FunctionGraph([a], [x]) # I.e. False

compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])