Implements shape Ops and MakeVector in PyTorch #926
Conversation
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #926      +/-   ##
==========================================
+ Coverage   81.38%   81.40%   +0.01%
==========================================
  Files         172      173       +1
  Lines       46868    46914      +46
  Branches    11423    11426       +3
==========================================
+ Hits        38145    38188      +43
- Misses       6540     6544       +4
+ Partials    2183     2182       -1
```
```python
@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
    shape = node.inputs[1]

    def reshape(x, shape=shape):
        return torch.reshape(x, tuple(shape))

    return reshape
```
We have to use the runtime shape, since it is not always a constant.
Suggested change:

```diff
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    shape = node.inputs[1]
-
-    def reshape(x, shape=shape):
+    def reshape(x, shape):
         return torch.reshape(x, tuple(shape))

     return reshape
```
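To make the dynamic case concrete, here is a hedged sketch of a graph whose target shape is a graph input rather than a constant (the variable names are illustrative, and the `mode="PYTORCH"` string follows its use later in this thread):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor import function

x = pt.vector("x")
shape = pt.lvector("shape")  # symbolic shape: only known at call time
y = x.reshape(shape, ndim=2)

# The compiled function receives the shape as a runtime value, which is
# why the funcified `reshape` must accept it as an argument rather than
# closing over the symbolic variable.
f = function([x, shape], y, mode="PYTORCH")
print(f(np.arange(6.0), [2, 3]))  # 2x3 result; shape supplied at call time
```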
tests/link/pytorch/test_shape.py (Outdated)
```python
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


def test_pytorch_Reshape_shape_graph_input():
```
Suggested change:

```diff
-def test_pytorch_Reshape_shape_graph_input():
+def test_pytorch_Reshape_dynamic():
```
tests/link/pytorch/test_shape.py (Outdated)
```python
x = DeepCopyOp()(pt.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x])
```
We have to make sure DeepCopy does the expected thing. For instance here is how we know it is currently not doing the right thing in the Numba backend: #50
Is the `DeepCopyOp` actually in the scope of these changes? Not sure why `DeepCopy` and `View` were being tested together with `Unbroadcast` (I adapted the tests from the JAX backend implementation).
You can leave them out if you didn't mean to implement them.
FWIW, I tested the torch clone and it seems like it's fine:

```python
import pytensor
from pytensor import tensor as pt

x = pytensor.shared(0, name="x")

# Mutating the array returned by the compiled function must not
# affect the shared value.
f = pytensor.function([], x, mode=None)
f().itemset(2)
assert x.get_value() == 0

f = pytensor.function([], x, mode="PYTORCH")
f().apply_(lambda _: 2)
assert x.get_value() == 0
```
```diff
 @singledispatch
 def pytorch_typify(data, dtype=None, **kwargs):
     r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
-    return torch.as_tensor(data, dtype=dtype)
+    if data is not None:
+        return torch.as_tensor(data, dtype=dtype)
+    return None
```
We should dispatch on `NoneType`:

```python
@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
    return None
```
- Where should this go?
- In addition to the condition in `def pytorch_typify(data, dtype=None, **kwargs): ...`?
It should be fine in `dispatch.basic.py`. If you dispatch, you don't need the `if`, because the dispatch mechanism already chooses which function to call based on the type of `data`.
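As a self-contained illustration (generic names, not the PyTensor code), this is how `functools.singledispatch` routes a call by the runtime type of its first argument, which is why no `if data is None` guard is needed in the generic function:

```python
from functools import singledispatch

@singledispatch
def typify(data):
    # Generic fallback for any unregistered type.
    return f"tensor({data})"

@typify.register(type(None))
def _(data):
    # Selected automatically whenever `data` is None.
    return None

assert typify(3) == "tensor(3)"
assert typify(None) is None
```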
I tried something like:

```python
from pytensor.tensor.type_other import NoneTypeT

@pytorch_typify.register(NoneTypeT)
def pytorch_typify_None(data, **kwargs):
    return None
```

but the condition is still required. I checked the JAX backend for reference, and there is also a similar condition there for `pytorch_typify(data ...`
A different issue surfaced. I don't know whether it is related to these changes or not. Calling `repeat` with `axis=None` results in an error in `ElementWise` at this point, namely:

```
torch._dynamo.exc.InternalTorchDynamoError: 'int' object has no attribute 'shape'
```

The value of `inputs` is `(tensor(3), 3, 2)`. The corresponding graph is shown below.
```
Reshape{1} [id A] <Vector(float64, shape=(?,))> 8
 ├─ Alloc [id B] <Matrix(float64, shape=(?, 3))> 7
 │  ├─ ExpandDims{axis=1} [id C] <Matrix(float64, shape=(?, 1))> 6
 │  │  └─ Reshape{1} [id D] <Vector(float64, shape=(?,))> 5
 │  │     ├─ a [id E] <Matrix(float64, shape=(?, ?))>
 │  │     └─ [-1] [id F] <Vector(int64, shape=(1,))>
 │  ├─ Mul [id G] <Scalar(int64, shape=())> 4
 │  │  ├─ Shape_i{0} [id H] <Scalar(int64, shape=())> 1
 │  │  │  └─ a [id E] <Matrix(float64, shape=(?, ?))>
 │  │  └─ Shape_i{1} [id I] <Scalar(int64, shape=())> 0
 │  │     └─ a [id E] <Matrix(float64, shape=(?, ?))>
 │  └─ 3 [id J] <Scalar(int64, shape=())>
 └─ MakeVector{dtype='int64'} [id K] <Vector(int64, shape=(1,))> 3
    └─ Mul [id L] <Scalar(int64, shape=())> 2
       ├─ 3 [id J] <Scalar(int64, shape=())>
       ├─ Shape_i{0} [id H] <Scalar(int64, shape=())> 1
       │  └─ ···
       └─ Shape_i{1} [id I] <Scalar(int64, shape=())> 0
          └─ ···
```
Typify should be registered on `NoneType` (Python), not `NoneTypeT` (PyTensor).
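A minimal sketch of that fix (assuming the `pytorch_typify` generic from above): singledispatch looks at the class of the runtime value, and a constant of PyTensor type `NoneTypeT` carries the Python value `None`, so the registration has to target the Python class:

```python
NoneType = type(None)  # or `from types import NoneType` on Python >= 3.10

@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
    return None
```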
Regarding your other error: there shouldn't be integers `3` and `2` as inputs; it should be `tensor(2)` and `tensor(3)`. We need to track down where those come from and fix it.
Probably `Shape_i` is not correctly implemented and is returning integers?
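For illustration, a hedged sketch (not the PR's code) of a `Shape_i` conversion that returns a tensor rather than a Python `int`, so downstream Ops that access `.shape` keep working; it assumes `Shape_i` stores its axis in `op.i`, as in the other backends:

```python
import torch

@pytorch_funcify.register(Shape_i)
def pytorch_funcify_Shape_i(op, **kwargs):
    i = op.i  # the axis whose length this Op extracts

    def shape_i(x):
        # Wrap the dimension size in a tensor so consumers can treat it
        # like any other tensor input.
        return torch.tensor(x.shape[i])

    return shape_i
```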
Force-pushed from `bf50423` to `0e455fd`. Commits:

- Shape, Shape_i, Reshape, SpecifyShape, Unbroadcast, MakeVector
- Fixed Shape_i
- Typified Python NoneType
Awesome stuff @twaclaw
Description

Implements:

- `Shape`
- `Shape_i`
- `Reshape`
- `SpecifyShape`
- `Unbroadcast`
- `MakeVector`