Skip to content

Commit

Permalink
scatter reduce decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Jul 17, 2024
1 parent ca4b263 commit 6c77c44
Show file tree
Hide file tree
Showing 3 changed files with 510 additions and 1 deletion.
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional, Union

import numpy as np
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
Expand All @@ -19,6 +18,7 @@
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.converter_utils import broadcast
Expand Down Expand Up @@ -67,6 +67,11 @@ def trunc_div(
prod_output,
)

# cast the sign_output back to int32 for trunc div
# This is required for scatter_reduce_.two with reduce='mean', where trunc_div casts the sign to float32 while TRTInterpreter expects int32.
if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32):
sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)

# Convert constant input into ITensor for UnaryOperation
if not isinstance(input, trt.tensorrt.ITensor):
input = get_trt_tensor(ctx, input, f"{name}_input")
Expand Down
94 changes: 94 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional

import torch
Expand Down Expand Up @@ -243,6 +244,99 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
)


# enum class for reduce operation of scatter_reduce
class reduceOperation(Enum):
    """Reduction modes used by the ``scatter_reduce`` decomposition.

    Every member is declared as a ``(description, func)`` pair.
    ``description`` is a human-readable label and ``func`` is the elementwise
    binary operation that folds a freshly scattered slice into the running
    result. Member values are consecutive integers starting at 1.
    """

    SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
    PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
    # MEAN only accumulates a running sum here; the final division is done
    # by the code that drives the reduction.
    MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
    # torch.amax / torch.amin take (input, dim) and reject a tensor as the
    # second argument; the elementwise two-tensor forms are maximum / minimum.
    AMAX = ("Amax reduce operation", lambda x, y: torch.maximum(x, y))
    AMIN = ("Amin reduce operation", lambda x, y: torch.minimum(x, y))

    def __new__(
        cls, description: str, func: Callable[..., torch.Tensor]
    ) -> "reduceOperation":
        """Create a member carrying ``description`` and ``func`` attributes."""
        obj = object.__new__(cls)
        # Unique value based on the number of members defined so far.
        obj._value_ = len(cls.__members__) + 1
        obj.description = description
        obj.func = func
        return obj

    def reduce_operation_with_scatter(
        self,
        operation_lhs: torch.Tensor,
        initial_tensor: torch.Tensor,
        dim: int,
        index_tensor: torch.Tensor,
        src_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Scatter ``src_tensor`` into a base tensor and fold it into ``operation_lhs``.

        Args:
            operation_lhs: Running result accumulated so far.
            initial_tensor: Tensor that determines the base the slice is
                scattered into (zeros, ones, or the tensor itself).
            dim: Axis along which ``torch.scatter`` writes.
            index_tensor: Indices passed to ``torch.scatter``.
            src_tensor: Values passed to ``torch.scatter``.

        Returns:
            ``self.func(operation_lhs, scattered)`` with both operands moved
            to the default device.

        Raises:
            ValueError: If ``self`` is not one of the known reduce operations.
        """
        # The base must be the identity element of the reduction so that
        # positions not addressed by ``index_tensor`` leave ``operation_lhs``
        # unchanged. For amax/amin the initial tensor itself plays that role.
        if self in (reduceOperation.SUM, reduceOperation.MEAN):
            scatter_tensor = torch.zeros_like(initial_tensor)
        elif self is reduceOperation.PROD:
            scatter_tensor = torch.ones_like(initial_tensor)
        elif self in (reduceOperation.AMIN, reduceOperation.AMAX):
            scatter_tensor = initial_tensor
        else:
            # This case would not be encountered from torch itself.
            raise ValueError(f"Invalid operation for reduce op: {self!r}")

        operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
        device = to_torch_device(default_device())
        operation_lhs = operation_lhs.to(device)
        operation_rhs = operation_rhs.to(device)
        return self.func(operation_lhs, operation_rhs)


@register_torch_trt_decomposition(
    torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
)
def scatter_reduce_decomposition(
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src_tensor: torch.Tensor,
    reduce: str,
) -> torch.Tensor:
    """Decompose ``aten.scatter_reduce.two`` into per-slice ``torch.scatter`` calls.

    The index tensor is walked one position at a time along ``dim``. Each
    slice is scattered on its own and folded into a running result that
    starts from ``input_tensor``, so the original input values take part in
    the reduction.

    Args:
        input_tensor: Tensor the reduction starts from.
        dim: Axis along which to scatter.
        index: Indices of the elements to scatter.
        src_tensor: Source values to scatter and reduce.
        reduce: One of ``"sum"``, ``"prod"``, ``"mean"``, ``"amax"``,
            ``"amin"``.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If ``reduce`` is not one of the supported strings.
    """
    reduce_ops = {
        "sum": reduceOperation.SUM,
        "prod": reduceOperation.PROD,
        "mean": reduceOperation.MEAN,
        "amax": reduceOperation.AMAX,
        "amin": reduceOperation.AMIN,
    }
    try:
        reduce_op = reduce_ops[reduce]
    except KeyError:
        raise ValueError(
            f"Unsupported reduce operation for scatter_reduce: {reduce!r}"
        ) from None
    is_mean = reduce_op is reduceOperation.MEAN

    device = to_torch_device(default_device())
    scatter_loop_tensor = input_tensor
    # Per-position count of scattered elements; required for mean only.
    scatter_count_tensor = torch.zeros_like(input_tensor)

    # Iterate over the extent of ``index`` (not ``src_tensor``) along ``dim``:
    # scatter only consumes as many source entries as there are indices, and
    # ``index`` may be shorter than ``src_tensor`` along that axis.
    for i in range(index.shape[dim]):
        # select() drops ``dim``; unsqueeze() restores it with size 1.
        src_slice = torch.unsqueeze(torch.select(src_tensor, dim, i), dim)
        index_slice = torch.unsqueeze(torch.select(index, dim, i), dim)

        # Move the tensors to the default device.
        scatter_loop_tensor = scatter_loop_tensor.to(device)
        index_slice = index_slice.to(device)
        src_slice = src_slice.to(device)

        if is_mean:
            scatter_count_tensor = reduce_op.reduce_operation_with_scatter(
                scatter_count_tensor,
                input_tensor,
                dim,
                index_slice,
                torch.ones_like(src_slice),
            )
        scatter_loop_tensor = reduce_op.reduce_operation_with_scatter(
            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
        )

    if is_mean:
        # The divisor is count + 1 because the input value itself is part of
        # the accumulated sum.
        # NOTE(review): eager scatter_reduce appears to use true division for
        # floating dtypes and floor division for integer dtypes; "trunc" is
        # kept because the trunc_div converter is tuned for this path. Confirm.
        scatter_loop_tensor = torch.div(
            scatter_loop_tensor,
            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
            rounding_mode="trunc",
        )
    return scatter_loop_tensor


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
Loading

0 comments on commit 6c77c44

Please sign in to comment.