Commit f0ccb92

scatter reduce decomposition
1 parent 3215712 commit f0ccb92

3 files changed: +470 -17 lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py (+6 -1)

@@ -1,7 +1,6 @@
 from typing import Optional, Union

 import numpy as np
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -19,6 +18,7 @@
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.converter_utils import broadcast
@@ -67,6 +67,11 @@ def trunc_div(
         prod_output,
     )

+    # cast the sign_output back to int32 for trunc div
+    # This is required for scatter_reduce_.two (reduce='mean'), where trunc_div casts it to float32 and the TRTInterpreter expects int32
+    if isinstance(sign_output, TRTTensor) and sign_output.dtype == trt.float32:
+        sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)
+
     # Convert constant input into ITensor for UnaryOperation
     if not isinstance(input, trt.tensorrt.ITensor):
         input = get_trt_tensor(ctx, input, f"{name}_input")
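For context, the reason for the int32 cast can be seen in eager PyTorch (an illustrative snippet, not part of this commit): true division promotes integer tensors to float32, while trunc division keeps the integer dtype, which matches what the TRTInterpreter expects for the integer mean path of scatter_reduce.

    import torch

    a = torch.tensor([7, -7], dtype=torch.int32)
    b = torch.tensor([2, 2], dtype=torch.int32)

    print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000]) -> promoted to float32
    print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3, -3], dtype=torch.int32)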

py/torch_tensorrt/dynamo/lowering/_decompositions.py (+94)

@@ -1,4 +1,5 @@
 import logging
+from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -281,6 +282,99 @@ def scatter_add_decomposition(
     return scatter_add_tensor


+# enum class for reduce operation of scatter_reduce
+class ReduceOperation(Enum):
+    SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
+    PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
+    MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
+    AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
+    AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))
+
+    def __new__(cls, description, func):
+        obj = object.__new__(cls)
+        obj._value_ = auto()
+        obj.description = description
+        obj.func = func
+        return obj
+
+    def reduce_operation_with_scatter(
+        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
+    ):
+        scatter_tensor = None
+        if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
+            scatter_tensor = torch.zeros_like(initial_tensor)
+        elif self == ReduceOperation.PROD:
+            scatter_tensor = torch.ones_like(initial_tensor)
+        elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
+            scatter_tensor = initial_tensor
+        else:
+            # This case would not be encountered from torch itself
+            print("Invalid Operation for Reduce op!!")
+
+        operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
+        device = to_torch_device(default_device())
+        operation_lhs = operation_lhs.to(device)
+        operation_rhs = operation_rhs.to(device)
+        return self.func(operation_lhs, operation_rhs)
+
+
+@register_torch_trt_decomposition(
+    torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scatter_reduce_decomposition(
+    input_tensor: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    src_tensor: torch.Tensor,
+    reduce: str,
+) -> torch.Tensor:
+    scatter_loop_tensor = input_tensor
+    # required for mean reduce operation
+    scatter_count_tensor = torch.zeros_like(input_tensor)
+    src_shape = list(src_tensor.shape)
+    src_dim = src_shape[dim]
+
+    for i in range(0, src_dim):
+        src_slice = torch.select(src_tensor, dim, i)
+        index_slice = torch.select(index, dim, i)
+        # unsqueeze src and index in dim
+        src_slice = torch.unsqueeze(src_slice, dim)
+        index_slice = torch.unsqueeze(index_slice, dim)
+        device = to_torch_device(default_device())
+
+        # moving tensor to default device
+        scatter_loop_tensor = scatter_loop_tensor.to(device)
+        index_slice = index_slice.to(device)
+        src_slice = src_slice.to(device)
+        if reduce == "sum":
+            reduceOp = ReduceOperation.SUM
+        elif reduce == "prod":
+            reduceOp = ReduceOperation.PROD
+        elif reduce == "mean":
+            reduceOp = ReduceOperation.MEAN
+            scatter_count_tensor = reduceOp.reduce_operation_with_scatter(
+                scatter_count_tensor,
+                input_tensor,
+                dim,
+                index_slice,
+                torch.ones_like(src_slice),
+            )
+        elif reduce == "amax":
+            reduceOp = ReduceOperation.AMAX
+        elif reduce == "amin":
+            reduceOp = ReduceOperation.AMIN
+        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
+            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
+        )
+    if reduce == "mean":
+        scatter_loop_tensor = torch.div(
+            scatter_loop_tensor,
+            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
+            rounding_mode="trunc",
+        )
+    return scatter_loop_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
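As a quick illustration of the enum above (not part of the commit), each member's func is a plain elementwise torch op, and reduce_operation_with_scatter seeds the scatter with an identity-like tensor: zeros for SUM/MEAN, ones for PROD, and the running tensor itself for AMIN/AMAX.

    import torch

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([3, 1, 2])

    print(ReduceOperation.SUM.func(a, b))   # tensor([4, 3, 5])
    print(ReduceOperation.PROD.func(a, b))  # tensor([3, 2, 6])
    print(ReduceOperation.AMAX.func(a, b))  # tensor([3, 2, 3])
    print(ReduceOperation.AMIN.func(a, b))  # tensor([1, 1, 2])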

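A minimal sanity check of the decomposition, sketched here under the assumption that it can be called directly (it is normally invoked through the lowering registry): for reduce="sum" the slice-by-slice loop should agree with torch.scatter_reduce with include_self=True, which is the overload being decomposed. The result is copied back to CPU because the decomposition moves tensors to Torch-TensorRT's default device.

    import torch
    from torch_tensorrt.dynamo.lowering._decompositions import scatter_reduce_decomposition

    inp = torch.zeros(3, 5)
    src = torch.arange(10, dtype=torch.float32).reshape(2, 5)
    index = torch.tensor([[0, 1, 2, 0, 1], [2, 0, 1, 1, 2]])

    expected = torch.scatter_reduce(inp, 0, index, src, reduce="sum", include_self=True)
    result = scatter_reduce_decomposition(inp, 0, index, src, "sum")
    torch.testing.assert_close(result.cpu(), expected)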