
Implements shape Ops and MakeVector in PyTorch #926

Conversation

@twaclaw (Contributor) commented Jul 12, 2024

Description

Implements

  • Shape

  • Shape_i

  • Reshape

  • SpecifyShape

  • Unbroadcast

  • MakeVector
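
For orientation, a hedged usage sketch of these Ops on the new backend (the "PYTORCH" mode string is taken from this thread; the exact graph is illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt

# A small graph that exercises Shape_i and Reshape.
x = pt.matrix("x")
out = x.reshape((x.shape[0] * x.shape[1],))

# Compile with the PyTorch backend (mode string as used later in this thread).
f = pytensor.function([x], out, mode="PYTORCH")
xval = np.ones((2, 3), dtype=pytensor.config.floatX)
assert tuple(f(xval).shape) == (6,)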

Related Issue

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):


codecov bot commented Jul 12, 2024

Codecov Report

Attention: Patch coverage is 80.85106% with 9 lines in your changes missing coverage. Please review.

Project coverage is 81.40%. Comparing base (72c6a81) to head (0786f2c).
Report is 98 commits behind head on main.

Files with missing lines                   Patch %   Lines
pytensor/link/pytorch/dispatch/shape.py    74.28%    8 Missing and 1 partial ⚠️

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #926      +/-   ##
==========================================
+ Coverage   81.38%   81.40%   +0.01%     
==========================================
  Files         172      173       +1     
  Lines       46868    46914      +46     
  Branches    11423    11426       +3     
==========================================
+ Hits        38145    38188      +43     
- Misses       6540     6544       +4     
+ Partials     2183     2182       -1     
Files with missing lines                      Coverage Δ
pytensor/link/pytorch/dispatch/__init__.py    100.00% <100.00%> (ø)
pytensor/link/pytorch/dispatch/basic.py       89.74% <100.00%> (+4.44%) ⬆️
pytensor/link/pytorch/dispatch/shape.py       74.28% <74.28%> (ø)

... and 4 files with indirect coverage changes

Comment on lines 7 to 12
@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
    shape = node.inputs[1]

    def reshape(x, shape=shape):
        return torch.reshape(x, tuple(shape))

    return reshape
Member commented:

We have to use the runtime shape, since it is not always a constant.

Suggested change
- @pytorch_funcify.register(Reshape)
- def pytorch_funcify_Reshape(op, node, **kwargs):
-     shape = node.inputs[1]
-     def reshape(x, shape=shape):
-         return torch.reshape(x, tuple(shape))
-     return reshape
+ @pytorch_funcify.register(Reshape)
+ def pytorch_funcify_Reshape(op, node, **kwargs):
+     def reshape(x, shape):
+         return torch.reshape(x, tuple(shape))
+     return reshape
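
To illustrate the point (a minimal sketch; the variable names and the use of pt.lvector are illustrative): when the target shape is itself a graph input, node.inputs[1] has no value at dispatch time, so the inner function must receive the shape as a runtime argument rather than closing over it.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor import config

x = pt.vector("x")
shape = pt.lvector("shape")       # symbolic shape: no constant value exists
y = x.reshape(shape, ndim=2)      # ndim must be given since len(shape) is unknown
f = pytensor.function([x, shape], y)

# The shape is only supplied when the compiled function is called.
assert f(np.arange(4, dtype=config.floatX), [2, 2]).shape == (2, 2)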

compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


def test_pytorch_Reshape_shape_graph_input():
Member commented:

Suggested change
- def test_pytorch_Reshape_shape_graph_input():
+ def test_pytorch_Reshape_dynamic():

Comment on lines 58 to 59
x = DeepCopyOp()(pt.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x])
Member commented:

We have to make sure DeepCopy does the expected thing. For instance, here is how we know it is currently not doing the right thing in the Numba backend: #50

Contributor Author (@twaclaw) commented:

Is DeepCopyOp actually in the scope of these changes?
I'm not sure why DeepCopy and View were being tested together with Unbroadcast (I adapted the tests from the JAX backend implementation).

Member commented:

You can leave them out if you didn't mean to implement it.

@Ch0ronomato (Contributor) commented Jul 16, 2024:

FWIW, I tested the torch clone and it seems to be fine:

import pytensor
from pytensor import tensor as pt

# Mutating the value returned by the compiled function must not leak back
# into the shared variable; this is what DeepCopyOp guarantees.
x = pytensor.shared(0, name="x")
f = pytensor.function([], x, mode=None)
f().itemset(2)                   # in-place mutation of the returned ndarray
assert x.get_value() == 0

f = pytensor.function([], x, mode="PYTORCH")
f().apply_(lambda _: 2)          # in-place mutation of the returned torch tensor
assert x.get_value() == 0

Comment on lines 12 to 17
@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
    r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
-   return torch.as_tensor(data, dtype=dtype)
+   if data is not None:
+       return torch.as_tensor(data, dtype=dtype)
+   return None
Member commented:

We should dispatch on NoneType:

@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
    return None

Contributor Author (@twaclaw) commented:

  • Where should this go?
  • Should it be added in addition to the condition in pytorch_typify(data, dtype=None, **kwargs)?

Member commented:

It should be fine in dispatch.basic.py.

If you dispatch, you don't need the if, because the dispatch mechanism already chooses which function to call based on the type of data.
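
A self-contained sketch of that mechanism (the toy typify function and its return values are illustrative, not PyTensor's actual API):

from functools import singledispatch

@singledispatch
def typify(data, dtype=None, **kwargs):
    # generic fallback for all types not registered below
    return ("tensor", data, dtype)

# Registering on Python's NoneType routes None here automatically,
# so the generic function needs no `if data is not None` check.
@typify.register(type(None))
def typify_none(data, **kwargs):
    return None

assert typify(1.0, dtype="float64") == ("tensor", 1.0, "float64")
assert typify(None) is None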

Contributor Author (@twaclaw) commented:

I tried something like:

from pytensor.tensor.type_other import NoneTypeT

@pytorch_typify.register(NoneTypeT)
def pytorch_typify_None(data, **kwargs):
    return None

but the condition is still required. I checked the JAX backend for reference, and there is a similar condition there for pytorch_typify(data, ...) as well.

Contributor Author (@twaclaw) commented Jul 16, 2024:

A different issue surfaced; I don't know whether it is related to these changes or not.

Calling repeat with axis=None results in an error in Elemwise at this point, namely:

torch._dynamo.exc.InternalTorchDynamoError: 'int' object has no attribute 'shape'

The value of inputs is (tensor(3), 3, 2). The corresponding graph is shown below.

Reshape{1} [id A] <Vector(float64, shape=(?,))> 8
 ├─ Alloc [id B] <Matrix(float64, shape=(?, 3))> 7
 │  ├─ ExpandDims{axis=1} [id C] <Matrix(float64, shape=(?, 1))> 6
 │  │  └─ Reshape{1} [id D] <Vector(float64, shape=(?,))> 5
 │  │     ├─ a [id E] <Matrix(float64, shape=(?, ?))>
 │  │     └─ [-1] [id F] <Vector(int64, shape=(1,))>
 │  ├─ Mul [id G] <Scalar(int64, shape=())> 4
 │  │  ├─ Shape_i{0} [id H] <Scalar(int64, shape=())> 1
 │  │  │  └─ a [id E] <Matrix(float64, shape=(?, ?))>
 │  │  └─ Shape_i{1} [id I] <Scalar(int64, shape=())> 0
 │  │     └─ a [id E] <Matrix(float64, shape=(?, ?))>
 │  └─ 3 [id J] <Scalar(int64, shape=())>
 └─ MakeVector{dtype='int64'} [id K] <Vector(int64, shape=(1,))> 3
    └─ Mul [id L] <Scalar(int64, shape=())> 2
       ├─ 3 [id J] <Scalar(int64, shape=())>
       ├─ Shape_i{0} [id H] <Scalar(int64, shape=())> 1
       │  └─ ···
       └─ Shape_i{1} [id I] <Scalar(int64, shape=())> 0
          └─ ···
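
For reference, a minimal repro of this failure mode (the elemwise_mul helper is a stand-in for the Elemwise wrapper, not the actual implementation): inspecting .shape breaks as soon as one input arrives as a plain Python int instead of a torch scalar tensor.

import torch

def elemwise_mul(*inputs):
    # stand-in: an Elemwise wrapper that inspects input shapes for broadcasting
    shapes = [inp.shape for inp in inputs]  # fails if any input is a plain int
    return torch.mul(torch.mul(inputs[0], inputs[1]), inputs[2])

elemwise_mul(torch.tensor(3), torch.tensor(3), torch.tensor(2))  # works
try:
    elemwise_mul(torch.tensor(3), 3, 2)  # the inputs from the report above
except AttributeError as e:
    print(e)  # 'int' object has no attribute 'shape'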

Member commented:

Typify should be registered on NoneType (Python), not NoneTypeT (PyTensor).
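
Concretely, that registration would look something like this (a sketch; it assumes pytorch_typify lives in pytensor.link.pytorch.dispatch.basic, per the codecov report above):

from pytensor.link.pytorch.dispatch.basic import pytorch_typify

# type(None) is Python's NoneType (also importable as types.NoneType on 3.10+).
@pytorch_typify.register(type(None))
def pytorch_typify_None(data, **kwargs):
    return None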

@ricardoV94 (Member) commented Jul 17, 2024:

Regarding your other error: there shouldn't be integers 3 and 2 as inputs; they should be tensor(2) and tensor(3).

We need to track down where those come from and fix it.

Member commented:

Probably Shape_i is not correctly implemented and is returning integers?
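
A hedged sketch of the kind of fix being hinted at (illustrative, not necessarily the code that was merged): have the Shape_i dispatch return a torch scalar tensor instead of a Python int.

import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Shape_i

@pytorch_funcify.register(Shape_i)
def pytorch_funcify_Shape_i(op, **kwargs):
    i = op.i  # the axis whose length this Op extracts

    def shape_i(x):
        # tensor(n), not the plain int n, so downstream torch code
        # can rely on every input having a .shape attribute
        return torch.tensor(x.shape[i])

    return shape_i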

twaclaw added 2 commits July 15, 2024 19:45
- Shape
- Shape_i
- Reshape
- SpecifyShape
- Unbroadcast

- MakeVector
@twaclaw force-pushed the implement_shape_ops_and_makevector_in_pytorch branch from bf50423 to 0e455fd on July 16, 2024 18:21
- Fixed Shape_i
- Typified Python NoneType
@ricardoV94 added the enhancement (New feature or request) and torch (PyTorch backend) labels on Jul 17, 2024
@ricardoV94 merged commit 426931b into pymc-devs:main on Jul 17, 2024
58 of 59 checks passed
@ricardoV94 (Member) commented:

Awesome stuff @twaclaw

Labels
enhancement (New feature or request) · torch (PyTorch backend)

Successfully merging this pull request may close these issues:

Implement all Ops in PyTorch (help welcome!)

3 participants