Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torch implementation of IfElse #974

Merged
merged 8 commits into from
Oct 3, 2024

Conversation

Ch0ronomato
Copy link
Contributor

@Ch0ronomato Ch0ronomato commented Aug 15, 2024

Description

Add the IfElse op support in torch (reopened cause i git screwup my old branch)

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Copy link

codecov bot commented Aug 15, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.74%. Comparing base (8a6e407) to head (eda2dbc).
Report is 115 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #974   +/-   ##
=======================================
  Coverage   81.74%   81.74%           
=======================================
  Files         183      183           
  Lines       47733    47742    +9     
  Branches    11616    11617    +1     
=======================================
+ Hits        39020    39029    +9     
- Misses       6518     6520    +2     
+ Partials     2195     2193    -2     
Files with missing lines Coverage Δ
pytensor/link/pytorch/dispatch/basic.py 94.44% <100.00%> (+0.50%) ⬆️

... and 2 files with indirect coverage changes


def ifelse(cond, *true_and_false, n_outs=n_outs):
if cond:
return torch.stack(true_and_false[:n_outs])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to stack, and shouldn't, because outputs can have different dimensions / sizes

Comment on lines +317 to +319
a = scalar("a")
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
Copy link
Member

@ricardoV94 ricardoV94 Sep 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to test twice, to cover the false case, do a case where the multiple outputs are not something that can be stacked (say (pt.zeros((3, 5), pt.ones(2,)).

I would have the other test case return a single output, to test the single output case as well

@Ch0ronomato Ch0ronomato requested a review from ricardoV94 October 2, 2024 19:04
@ricardoV94 ricardoV94 added torch PyTorch backend enhancement New feature or request labels Oct 3, 2024
@ricardoV94 ricardoV94 changed the title Add torch ifelse Add torch implementation of IfElse Oct 3, 2024
@ricardoV94 ricardoV94 merged commit 46fdc58 into pymc-devs:main Oct 3, 2024
60 checks passed
@ricardoV94
Copy link
Member

Thanks @Ch0ronomato

Ch0ronomato added a commit to Ch0ronomato/pytensor that referenced this pull request Nov 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request torch PyTorch backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants