
scatter reduce decomposition #3008

Merged
merged 5 commits into main from scatter_reduce_decomposition on Sep 11, 2024
Conversation

@apbose apbose (Collaborator) commented Jul 15, 2024

#2740 should be using this. Will change it once this PR is finalized.

@apbose apbose requested a review from peri044 July 15, 2024 21:10
@github-actions github-actions bot added component: tests, component: lowering, component: api [Python], and component: dynamo labels Jul 15, 2024
@apbose apbose force-pushed the scatter_reduce_decomposition branch from ef97199 to 8e5151f Compare July 15, 2024 21:13
@apbose apbose marked this pull request as draft July 15, 2024 21:14
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-07-15 21:13:39.692683+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-07-15 21:15:36.206907+00:00
@@ -1163,11 +1163,13 @@
            (
                "scatter_reduce_amax_zero_dim_indexOne_constant",
                0,
                torch.tensor([[0, 1, 2, 0]]).cuda(),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
-                {torch.ops.aten.amax.default,},
+                {
+                    torch.ops.aten.amax.default,
+                },
                torch.zeros(3, 5, dtype=torch.int32).cuda(),
                "amax",
            ),
            (
                "scatter_reduce_amax_zero_dim_indexTwo_constant",

@apbose apbose force-pushed the scatter_reduce_decomposition branch 3 times, most recently from 5dbc520 to 6c77c44 Compare July 17, 2024 17:11
@github-actions github-actions bot added component: conversion and component: converters labels Jul 17, 2024
@apbose apbose force-pushed the scatter_reduce_decomposition branch from 6c77c44 to 33e76dc Compare July 17, 2024 17:13
@apbose apbose marked this pull request as ready for review July 17, 2024 17:14
@apbose apbose force-pushed the scatter_reduce_decomposition branch from 33e76dc to d137714 Compare July 17, 2024 17:23
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=2,

Collaborator:
Why is min_block_size = 2 here? What does lower_graph_testing do? Does it use our partitioning?

@apbose apbose (Collaborator, Author) Jul 24, 2024:

Yes, lower_graph_testing uses our partitioning. It is used to return the expected ops that were not seen and the unexpected ops that were seen. The default min_block_size is 3, but sometimes the graph in the test case is too small, so the block size falls below 3 and it errors out. That's why I set it to 1 or 2, since the block size is not something we are testing explicitly here.
Let me try with the default of 3. If it passes, I will remove min_block_size=2.
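
For context, a rough sketch of how such a test case drives this check (the kwargs mirror the snippet quoted above; the helper's positional arguments and return values are assumptions, not its exact signature):

# Hypothetical test-side call; fx_graph, inputs, and the returned pair are assumed
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
    fx_graph,
    inputs,
    expected_ops=expected_ops,      # ops the decomposition should produce
    unexpected_ops=unexpected_ops,  # ops that should have been decomposed away
    min_block_size=2,  # small test graphs can fall below the default block size of 3
)
self.assertEqual(len(unexpected_ops_seen), 0)
self.assertEqual(len(expected_ops_unseen), 0)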

@@ -243,6 +244,99 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
)


# enum class for reduce operation of scatter_reduce
class reduceOperation(Enum):

Collaborator:

minor - consider renaming it to ReduceOperation
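
For illustration, a minimal sketch of the renamed enum (the member names, description strings, and paired reduce functions are hypothetical, not the exact code in this PR):

from enum import Enum

import torch


class ReduceOperation(Enum):
    SUM = ("Sum reduce operation", lambda lhs, rhs: torch.add(lhs, rhs))
    PROD = ("Prod reduce operation", lambda lhs, rhs: torch.mul(lhs, rhs))
    MEAN = ("Mean reduce operation", lambda lhs, rhs: torch.add(lhs, rhs))  # summed here, divided by counts afterwards
    AMAX = ("Amax reduce operation", lambda lhs, rhs: torch.maximum(lhs, rhs))
    AMIN = ("Amin reduce operation", lambda lhs, rhs: torch.minimum(lhs, rhs))

    def __init__(self, description, reduce_fn):
        self.description = description
        self.reduce_fn = reduce_fn

A caller would then pick the member from the reduce string and apply, e.g., ReduceOperation.AMAX.reduce_fn(running, src_slice).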

@apbose apbose force-pushed the scatter_reduce_decomposition branch 3 times, most recently from 9aea7dd to f0ccb92 Compare July 30, 2024 02:58
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-07-30 02:56:59.084675+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-07-30 02:58:55.960957+00:00
@@ -1020,11 +1020,10 @@
            0,
            DECIMALS_OF_AGREEMENT,
            f"Scatter_add TRT outputs don't match with the original model.",
        )

-
    @parameterized.expand(
        [
            ############################sum###########################
            (
                "scatter_reduce_add_zero_dim_indexOne_constant",

@apbose apbose force-pushed the scatter_reduce_decomposition branch 2 times, most recently from 4bf82d5 to b6aa19d Compare August 6, 2024 00:00
py/torch_tensorrt/dynamo/lowering/_decompositions.py (outdated review comment, resolved)
# unsqueeze src and index in dim
src_slice = torch.unsqueeze(src_slice, dim)
index_slice = torch.unsqueeze(index_slice, dim)
device = to_torch_device(default_device())

Collaborator:

let's use the device where the input_tensor exists
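
i.e., roughly (input_tensor is the name used in the comment above; the surrounding decomposition code is assumed):

# take the device from the tensor being scattered into, rather than the global default
device = input_tensor.device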

@apbose apbose force-pushed the scatter_reduce_decomposition branch 2 times, most recently from 64442d6 to 2ce4933 Compare August 22, 2024 23:49
@apbose apbose requested a review from peri044 August 27, 2024 15:44
@apbose apbose force-pushed the scatter_reduce_decomposition branch 2 times, most recently from b689e76 to 020d32c Compare August 30, 2024 07:54
@peri044 peri044 (Collaborator) commented Sep 3, 2024

@apbose CI is failing on scatter tests

print("Invalid Operation for Reduce op!!")

operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
device = to_torch_device(default_device())

Collaborator:

use the device of initial_tensor here instead of default
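
i.e., roughly (initial_tensor is the name used in the comment above; it is assumed to be defined earlier in the decomposition):

# prefer the device of initial_tensor over the global default
device = initial_tensor.device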

@apbose apbose force-pushed the scatter_reduce_decomposition branch 3 times, most recently from 648fa95 to f124297 Compare September 10, 2024 16:33
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py	2024-09-10 16:33:48.731288+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py	2024-09-10 16:34:09.607993+00:00
@@ -186,11 +186,11 @@
    """
    device = None
    for parameter in list(module.parameters()):
        if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
            return parameter.device
-    
+
    for buffer in list(module.buffers()):
        if isinstance(buffer, (torch.Tensor)):
            return buffer.device

    if device is None:

index: torch.Tensor,
src_tensor: torch.Tensor,
reduce: str,
) -> torch.Tensor:

Contributor:

There is a kwarg include_self in https://github.com/pytorch/pytorch/blob/bc1b8f094d24de27432f4c29f0729e85a6b5ba63/aten/src/ATen/native/native_functions.yaml#L8237. Is it intentionally not handled in our decomposition?

@apbose apbose (Collaborator, Author) Sep 10, 2024:

Thanks for the review! Most of the cases I have seen use include_self = True, so here we have the implementation for the default case. No particular reason; I could add cases with include_self = False.


Collaborator:

Add include_self=True to the function arguments, and raise an error saying we don't support the case where the user sets it to False.
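
A minimal sketch of that suggestion (the trailing parameters match the signature quoted above; the function name, the leading parameters, and the error type are assumptions):

import torch

def scatter_reduce_decomposition(
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src_tensor: torch.Tensor,
    reduce: str,
    include_self: bool = True,
) -> torch.Tensor:
    if not include_self:
        raise AssertionError(
            "include_self=False is not supported in the scatter_reduce decomposition"
        )
    ...  # the existing decomposition logic for the include_self=True default follows here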

Comment on lines +190 to +192
return parameter.device

for buffer in list(module.buffers()):

Collaborator:

The buffer device overrides the parameter device here, which shouldn't be the case. Check the device of the parameters first; if not found, use the buffers.
Also consider adding a break once the device is found.


Collaborator:

nvm

index: torch.Tensor,
src_tensor: torch.Tensor,
reduce: str,
) -> torch.Tensor:

Collaborator:

Add include_self=True to the function arguments, and raise an error saying we don't support the case where the user sets it to False.

@apbose apbose merged commit 501a1e1 into main Sep 11, 2024
13 checks passed
Labels
cla signed, component: api [Python], component: conversion, component: converters, component: dynamo, component: lowering, component: tests