Skip to content

Commit 501a1e1

Browse files
authored
scatter reduce decomposition (#3008)
1 parent 7c56d58 commit 501a1e1

File tree

4 files changed

+507
-4
lines changed

4 files changed

+507
-4
lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Union
22

33
import numpy as np
4+
import tensorrt as trt
45
import torch
56
import torch_tensorrt.dynamo.conversion.impl as impl
67
from torch.fx.node import Target
@@ -17,6 +18,7 @@
1718
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1819
convert_binary_elementwise,
1920
)
21+
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
2022
from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
2123
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
2224
from torch_tensorrt.fx.converters.converter_utils import broadcast
@@ -67,6 +69,11 @@ def trunc_div(
6769
prod_output,
6870
)
6971

72+
# cast the sign_output back to int32 for trunc div
73+
# This is required for scatter_reduce_.two(reduce='mean' where trunc_div casts it to float32 and TRTInterpreter expects int32)
74+
if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32):
75+
sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)
76+
7077
# Convert constant input into ITensor for UnaryOperation
7178
if not isinstance(input, trt.tensorrt.ITensor):
7279
input = get_trt_tensor(ctx, input, f"{name}_input")

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+96
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from enum import Enum, auto
23
from typing import Any, Callable, Dict, List, Optional
34

45
import torch
@@ -287,6 +288,101 @@ def scatter_add_decomposition(
287288
return scatter_add_tensor
288289

289290

291+
# Enum describing the reduce modes supported by the scatter_reduce
# decomposition. Each member carries a human-readable description and the
# elementwise torch op used to fold a scattered slice into the running result.
class ReduceOperation(Enum):
    SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
    PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
    # MEAN accumulates a running sum; the caller divides by counts afterwards.
    MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
    AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
    AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))

    def __new__(cls, description, func):
        # Members are (description, func) tuples; assign an auto value so the
        # tuple itself is not used as the enum value.
        obj = object.__new__(cls)
        obj._value_ = auto()
        obj.description = description
        obj.func = func
        return obj

    def reduce_operation_with_scatter(
        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
    ):
        """Scatter ``src_tensor`` into an identity tensor for this reduce mode
        and fold the result into ``operation_lhs`` with the member's op.

        Args:
            operation_lhs: running accumulator built up so far.
            initial_tensor: tensor providing the scatter target's shape (and,
                for amin/amax, its initial values).
            dim: dimension along which to scatter.
            index_tensor: indices passed to ``torch.scatter``.
            src_tensor: values to scatter.

        Returns:
            ``self.func(operation_lhs, scattered)`` with both operands moved
            to the scatter tensor's device.

        Raises:
            ValueError: if the member has no defined scatter identity
                (unreachable for the members defined above).
        """
        # Choose the identity element of the reduction so positions not hit
        # by the scatter do not perturb the fold: 0 for sum/mean, 1 for prod;
        # amin/amax start from the initial tensor itself.
        if self in (ReduceOperation.SUM, ReduceOperation.MEAN):
            scatter_tensor = torch.zeros_like(initial_tensor)
        elif self == ReduceOperation.PROD:
            scatter_tensor = torch.ones_like(initial_tensor)
        elif self in (ReduceOperation.AMIN, ReduceOperation.AMAX):
            scatter_tensor = initial_tensor
        else:
            # Previously this branch only printed a message, leaving
            # scatter_tensor as None and failing later inside torch.scatter
            # with an opaque error. Fail loudly instead.
            raise ValueError(f"Invalid Operation for Reduce op: {self}")

        operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
        device = to_torch_device(scatter_tensor.device)
        # Keep both operands on the same device before folding.
        operation_lhs = operation_lhs.to(device)
        operation_rhs = operation_rhs.to(device)
        return self.func(operation_lhs, operation_rhs)
325+
326+
327+
@register_torch_trt_decomposition(
    torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
)
def scatter_reduce_decomposition(
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src_tensor: torch.Tensor,
    reduce: str,
    include_self: bool = True,
) -> torch.Tensor:
    """Decompose ``aten.scatter_reduce.two`` into per-slice ``torch.scatter``
    calls folded together by :class:`ReduceOperation`.

    Each slice of ``src``/``index`` along ``dim`` is scattered separately and
    combined into a running result, because a plain ``torch.scatter`` keeps
    only the last write when indices collide.

    Args:
        input_tensor: tensor being reduced into.
        dim: dimension along which to scatter.
        index: index tensor, same rank as ``src_tensor``.
        src_tensor: source values to reduce into ``input_tensor``.
        reduce: one of ``"sum"``, ``"prod"``, ``"mean"``, ``"amax"``, ``"amin"``.
        include_self: must be True; False is not yet supported.

    Raises:
        AssertionError: if ``include_self`` is False.
        ValueError: if ``reduce`` is not a recognized mode.
    """
    if not include_self:
        raise AssertionError("include_self False for scatter reduce not yet supported")

    # Resolve the reduce mode once, before the loop — it is loop-invariant.
    # Previously it was re-derived on every slice, and an unknown string
    # surfaced as a NameError on the first iteration instead of a clear error.
    try:
        reduce_op = {
            "sum": ReduceOperation.SUM,
            "prod": ReduceOperation.PROD,
            "mean": ReduceOperation.MEAN,
            "amax": ReduceOperation.AMAX,
            "amin": ReduceOperation.AMIN,
        }[reduce]
    except KeyError:
        raise ValueError(
            f"Unsupported reduce mode for scatter_reduce: {reduce!r}"
        ) from None

    scatter_loop_tensor = input_tensor
    device_input_tensor = input_tensor.device
    # Per-element hit counts; only consumed by the mean reduction below.
    scatter_count_tensor = torch.zeros_like(input_tensor)
    src_dim = src_tensor.shape[dim]

    for i in range(src_dim):
        # Take the i-th slice of src/index along `dim`, then restore the dim
        # so shapes line up with the scatter target.
        src_slice = torch.unsqueeze(torch.select(src_tensor, dim, i), dim)
        index_slice = torch.unsqueeze(torch.select(index, dim, i), dim)

        # Keep all operands on the input tensor's device.
        scatter_loop_tensor = scatter_loop_tensor.to(device_input_tensor)
        index_slice = index_slice.to(device_input_tensor)
        src_slice = src_slice.to(device_input_tensor)

        if reduce_op is ReduceOperation.MEAN:
            # Count how many times each output element is hit so the running
            # sum can be divided into a mean afterwards.
            scatter_count_tensor = reduce_op.reduce_operation_with_scatter(
                scatter_count_tensor,
                input_tensor,
                dim,
                index_slice,
                torch.ones_like(src_slice),
            )
        scatter_loop_tensor = reduce_op.reduce_operation_with_scatter(
            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
        )

    if reduce == "mean":
        # include_self=True counts the original element once, hence the +1 on
        # the hit counts. Truncating division is used so the TRT trunc_div
        # converter path (which handles the int32 cast) applies.
        scatter_loop_tensor = torch.div(
            scatter_loop_tensor,
            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
            rounding_mode="trunc",
        )
    return scatter_loop_tensor
384+
385+
290386
def get_decompositions(
291387
enable_experimental_decompositions: bool = False,
292388
) -> Dict[OpOverload, Callable[[Any], Any]]:

py/torch_tensorrt/dynamo/utils.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
77

88
import numpy as np
9+
import tensorrt as trt
910
import torch
1011
from torch._subclasses.fake_tensor import FakeTensor
1112
from torch_tensorrt._Device import Device
1213
from torch_tensorrt._enums import dtype
1314
from torch_tensorrt._Input import Input
1415
from torch_tensorrt.dynamo import _defaults
16+
from torch_tensorrt.dynamo._defaults import default_device
1517
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
1618
from torch_tensorrt.dynamo._settings import CompilationSettings
1719

18-
import tensorrt as trt
1920
from packaging import version
2021

2122
from .types import TRTDataType
@@ -186,11 +187,14 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
186187
device = None
187188
for parameter in list(module.parameters()):
188189
if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
189-
device = parameter.device
190-
break
190+
return parameter.device
191+
192+
for buffer in list(module.buffers()):
193+
if isinstance(buffer, (torch.Tensor)):
194+
return buffer.device
191195

192196
if device is None:
193-
device = torch.device("cpu")
197+
device = to_torch_device(default_device())
194198
logger.warning(
195199
"Could not detect the device on which the model exists. Assuming the model is on CPU"
196200
)

0 commit comments

Comments
 (0)