
Pytorch support for Join and Careduce Ops #869

Merged: 11 commits merged into pymc-devs:main on Jul 4, 2024

Conversation

HarshvirSandhu (Contributor)

Description

Add PyTorch support for common CAReduce Ops (sum, prod, all, any, ...)
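A minimal usage sketch of the kind of graphs these dispatches cover (illustrative only, assuming the standard pytensor.tensor API):

import pytensor.tensor as pt

# Hypothetical graphs the new PyTorch dispatches are meant to handle;
# each CAReduce/Join node is lowered to the matching torch call.
x = pt.matrix("x")
reductions = [pt.sum(x, axis=0), pt.prod(x, axis=(0, 1)), pt.all(x), pt.any(x)]
joined = pt.join(0, x, x)  # Join along axis 0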

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

def join(axis, *tensors):
    # tensors could also be tuples, and in this case they don't have an ndim
    tensors = [torch.tensor(tensor) for tensor in tensors]

view = op.view
ricardoV94 (Member) commented Jun 28, 2024:

We want to remove this kwarg, so you can just raise NotImplementedError if the Op has it set in the outer dispatch function. #753
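For illustration, a minimal sketch of that check in the outer dispatch function (the register/function names follow the existing dispatch module; the guard on op.view and the int conversion of axis are assumptions, not the merged code):

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.basic import Join


@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    # The view kwarg is slated for removal (#753); reject it here rather than
    # trying to reproduce its semantics in the PyTorch backend.
    if op.view != -1:  # assumes -1 means "no view", as in the current Join Op
        raise NotImplementedError("Join with view is not supported in the PyTorch backend")

    def join(axis, *tensors):
        # tensors could also be tuples, in which case they have no ndim
        tensors = [torch.tensor(tensor) for tensor in tensors]
        return torch.cat(tensors, dim=int(axis))

    return join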

codecov bot commented Jun 28, 2024

Codecov Report

Attention: Patch coverage is 62.00000% with 19 lines in your changes missing coverage. Please review.

Project coverage is 81.01%. Comparing base (17fa8b1) to head (35153f7).
Report is 147 commits behind head on main.

Files with missing lines                      Patch %   Lines
pytensor/link/pytorch/dispatch/elemwise.py    58.13%    18 Missing ⚠️
pytensor/link/pytorch/dispatch/basic.py       85.71%    1 Missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #869      +/-   ##
==========================================
+ Coverage   80.98%   81.01%   +0.02%     
==========================================
  Files         169      169              
  Lines       47025    46894     -131     
  Branches    11501    11492       -9     
==========================================
- Hits        38084    37989      -95     
+ Misses       6724     6698      -26     
+ Partials     2217     2207      -10     
Files with missing lines                      Coverage Δ
pytensor/link/pytorch/dispatch/basic.py       85.71% <85.71%> (-0.29%) ⬇️
pytensor/link/pytorch/dispatch/elemwise.py    64.58% <58.13%> (-5.23%) ⬇️

... and 36 files with indirect coverage changes

ricardoV94 added the torch (PyTorch backend) and enhancement (New feature or request) labels on Jun 28, 2024
ricardoV94 changed the title from "Pytorch support for Careduce Ops" to "Pytorch support for Join and Careduce Ops" on Jun 28, 2024

@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    dim = op.axis[0]
ricardoV94 (Member):

Why axis[0]?

HarshvirSandhu (Contributor Author):

op.axis is a tuple; PyTorch expects an integer for dim.

ricardoV94 (Member):

Then we need to change the logic, because it is possible for them to be tuples with more than one entry

HarshvirSandhu (Contributor Author) commented Jul 3, 2024:

We could do something like this:

for d in op.axis:
    x = torch.prod(x, dim=d, keepdim=True)  # keepdim keeps the shape constant; we can reshape at the end

ricardoV94 (Member):

If you reduce in reversed order you don't have to worry about keepdims. Sounds good; it's a bit surprising that they don't support multiple axes.
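For example, a quick sketch of the reversed-order reduction (the helper name here is made up for illustration):

import torch

def prod_over_axes(x, axes):
    # Reducing the last axis first keeps the remaining (smaller) axis indices
    # valid, so keepdim=True and a final reshape are unnecessary.
    for d in axes[::-1]:  # assumes axes is sorted in ascending order
        x = torch.prod(x, dim=d)
    return x

e.g. prod_over_axes(torch.ones(2, 3, 4), (1, 2)) yields a tensor of shape (2,).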

pytensor/link/pytorch/dispatch/elemwise.py
@@ -57,6 +58,72 @@ def test_pytorch_elemwise():
compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("axis", [None, 0, 1])
ricardoV94 (Member):

Can we parametrize these tests with the reduce function? Since they all look the same, we can reduce a bunch of lines. Or at least separate only those that need numerical inputs from those that need boolean (all and any).

ricardoV94 (Member):

Also I would like to test axis = (1, 2), and have a_pt be a tensor3, so that we cover the case with more than 1 axis, but not all of them.
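A sketch combining both suggestions (illustrative only; it assumes the test module's usual imports and the existing compare_pytorch_and_py helper from tests/link/pytorch/test_basic.py):

import numpy as np
import pytest

import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.type import tensor3

# helper already used by these tests; import path is an assumption
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [None, 0, 1, (1, 2)])
def test_pytorch_careduce(fn, axis):
    # tensor3 input so multi-axis (but not full) reductions are covered
    a_pt = tensor3("a")
    test_value = np.arange(6, dtype=config.floatX).reshape((1, 2, 3))

    x = fn(a_pt, axis=axis)
    fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(fg, [test_value])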

@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
- def test_pytorch_any(axis):
+ def test_pytorch_any_all(fn, axis):
ricardoV94 (Member):

Suggested change
- def test_pytorch_any_all(fn, axis):
+ def test_pytorch_careduce_bool(fn, axis):

@@ -58,8 +58,9 @@ def test_pytorch_elemwise():
compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)])
ricardoV94 (Member):

I think this is sufficient

Suggested change
- @pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)])
+ @pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])

- for d in op.axis:
-     x = torch.prod(x, dim=d, keepdim=True)
- return x.squeeze()
+ for d in op.axis[::-1]:
ricardoV94 (Member) commented Jul 4, 2024:

More readable?

Suggested change
- for d in op.axis[::-1]:
+ for d in sorted(op.axis, reverse=True):

Same for the others
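Putting the thread together, a sketch of the Prod dispatch with the sorted, reversed reduction loop (illustrative; the merged code may differ in details):

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.math import Prod


@pytorch_funcify.register(Prod)
def pytorch_funcify_Prod(op, **kwargs):
    def prod(x):
        if op.axis is None:
            return torch.prod(x)
        # Reduce the largest axis first so the remaining axis indices stay
        # valid, avoiding keepdim bookkeeping.
        for d in sorted(op.axis, reverse=True):
            x = torch.prod(x, dim=d)
        return x

    return prod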

ricardoV94 merged commit e57e25b into pymc-devs:main on Jul 4, 2024
56 of 57 checks passed
Labels: enhancement (New feature or request), torch (PyTorch backend)
2 participants