scatter reduce decomposition #3008
Conversation
Force-pushed from ef97199 to 8e5151f.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-07-15 21:13:39.692683+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-07-15 21:15:36.206907+00:00
@@ -1163,11 +1163,13 @@
            (
                "scatter_reduce_amax_zero_dim_indexOne_constant",
                0,
                torch.tensor([[0, 1, 2, 0]]).cuda(),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
-               {torch.ops.aten.amax.default,},
+               {
+                   torch.ops.aten.amax.default,
+               },
                torch.zeros(3, 5, dtype=torch.int32).cuda(),
                "amax",
            ),
            (
                "scatter_reduce_amax_zero_dim_indexTwo_constant",
Force-pushed from 5dbc520 to 6c77c44.
Force-pushed from 6c77c44 to 33e76dc.
Force-pushed from 33e76dc to d137714.
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=2,
Why is min_block_size = 2 here? What does lower_graph_testing do? Does it use our partitioning?
Yes, lower_graph_testing uses our partitioning. It is used for returning the expected ops that were not seen and the unexpected ops that were seen. The default is 3, but sometimes the graph in the test case is too small, the block size ends up less than 3, and it errors out. That's why I set it to 1 or 2, since the block size is not something we are explicitly testing here.
Let me try with the default of 3; if it passes, I will remove min_block_size=2.
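For context, a minimal sketch of how lower_graph_testing is typically driven in these decomposition tests; the import path, the return order, and the placeholder names are assumptions based on this discussion, not part of this diff:

import torch

# Assumed import (as in the decomposition test module):
from ..testing_utilities import lower_graph_testing

# `self`, `module`, `inputs`, `expected_ops`, and `unexpected_ops` come from
# the surrounding test case and are placeholders here.
fx_graph = torch.fx.symbolic_trace(module)
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
    fx_graph,
    inputs,
    expected_ops=expected_ops,
    unexpected_ops=unexpected_ops,
    min_block_size=2,  # default is 3; small test graphs may not reach that block size
)
self.assertEqual(len(unexpected_ops_seen), 0)
self.assertEqual(len(expected_ops_unseen), 0)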
@@ -243,6 +244,99 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
    )


# enum class for reduce operation of scatter_reduce
class reduceOperation(Enum):
minor - consider renaming it to ReduceOperation
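A minimal sketch of the suggested rename; the members listed are illustrative, based on the reduce modes scatter_reduce accepts, and are not the exact definition in this PR:

from enum import Enum


# CapWords class name per PEP 8; one member per scatter_reduce mode.
class ReduceOperation(Enum):
    SUM = "sum"
    PROD = "prod"
    MEAN = "mean"
    AMAX = "amax"
    AMIN = "amin"


# Example: map the `reduce` string argument to an enum member.
op = ReduceOperation("amax")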
Force-pushed from 9aea7dd to f0ccb92.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-07-30 02:56:59.084675+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-07-30 02:58:55.960957+00:00
@@ -1020,11 +1020,10 @@
            0,
            DECIMALS_OF_AGREEMENT,
            f"Scatter_add TRT outputs don't match with the original model.",
        )
-
    @parameterized.expand(
        [
            ############################sum###########################
            (
                "scatter_reduce_add_zero_dim_indexOne_constant",
Force-pushed from 4bf82d5 to b6aa19d.
# unsqueeze src and index in dim
src_slice = torch.unsqueeze(src_slice, dim)
index_slice = torch.unsqueeze(index_slice, dim)
device = to_torch_device(default_device())
Let's use the device where the input_tensor lives.
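A sketch of what that change could look like in the snippet above; input_tensor is assumed to be the tensor being scattered into in this decomposition:

# unsqueeze src and index in dim
src_slice = torch.unsqueeze(src_slice, dim)
index_slice = torch.unsqueeze(index_slice, dim)
# Follow the device of the tensor being scattered into rather than the
# framework default device.
device = input_tensor.device  # instead of to_torch_device(default_device())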
Force-pushed from 64442d6 to 2ce4933.
Force-pushed from b689e76 to 020d32c.
@apbose CI is failing on scatter tests
print("Invalid Operation for Reduce op!!") | ||
|
||
operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor) | ||
device = to_torch_device(default_device()) |
Use the device of initial_tensor here instead of the default.
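Similarly, a sketch for this spot, assuming initial_tensor is the tensor the decomposition starts from:

operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
# Derive the device from the existing tensor, not from the default device.
device = initial_tensor.device  # instead of to_torch_device(default_device())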
Force-pushed from 648fa95 to f124297.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2024-09-10 16:33:48.731288+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2024-09-10 16:34:09.607993+00:00
@@ -186,11 +186,11 @@
    """
    device = None
    for parameter in list(module.parameters()):
        if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
            return parameter.device
-
+
    for buffer in list(module.buffers()):
        if isinstance(buffer, (torch.Tensor)):
            return buffer.device
    if device is None:
Force-pushed from f124297 to 35f2b00.
index: torch.Tensor,
src_tensor: torch.Tensor,
reduce: str,
) -> torch.Tensor:
There is a kwarg include_self in https://github.com/pytorch/pytorch/blob/bc1b8f094d24de27432f4c29f0729e85a6b5ba63/aten/src/ATen/native/native_functions.yaml#L8237. Is it intentionally not handled in our decomposition?
Thanks for the review! Most of the cases I have seen use include_self=True. Here we have the implementation for the default case. No particular reason; I could add cases with include_self=False.
Add include_self=True in the function arguments, and raise an error saying we don't support the case when the user sets it to False.
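A minimal sketch of that suggestion; the function name and the leading parameters are assumed, since the snippet above only shows the tail of the signature:

def scatter_reduce_decomposition(  # name assumed for illustration
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src_tensor: torch.Tensor,
    reduce: str,
    include_self: bool = True,  # expose the aten kwarg with its default value
) -> torch.Tensor:
    if not include_self:
        raise AssertionError(
            "include_self=False is not supported in the scatter_reduce decomposition"
        )
    ...  # existing decomposition body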
return parameter.device

for buffer in list(module.buffers()):
The buffer device overrides the parameter device here, which shouldn't be the case. Check the device of parameters first; if not found, use buffers.
Also consider adding a break once the device is found.
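For illustration, a sketch of a parameters-first, buffers-as-fallback device lookup along the lines of the utils.py diff above; the function name and return behavior are assumptions, not the repository's exact implementation:

from typing import Optional

import torch


def get_model_device(module: torch.nn.Module) -> Optional[torch.device]:
    # Prefer the device of a parameter; the early return means buffers are
    # only consulted when no parameter reported a device.
    for parameter in module.parameters():
        if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
            return parameter.device

    # Fall back to buffers only if no parameter device was found.
    for buffer in module.buffers():
        if isinstance(buffer, torch.Tensor):
            return buffer.device

    return None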
nvm
#2740 should be using this. Will change it once this PR is finalized