|
 import logging
+from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -287,6 +288,101 @@ def scatter_add_decomposition(
     return scatter_add_tensor


+# Enum class for the reduce operations of scatter_reduce
+class ReduceOperation(Enum):
+    SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
+    PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
+    MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
+    AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
+    AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))
+
+    def __new__(cls, description, func):
+        obj = object.__new__(cls)
+        obj._value_ = auto()
+        obj.description = description
+        obj.func = func
+        return obj
+
+    def reduce_operation_with_scatter(
+        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
+    ):
+        scatter_tensor = None
+        if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
+            scatter_tensor = torch.zeros_like(initial_tensor)
+        elif self == ReduceOperation.PROD:
+            scatter_tensor = torch.ones_like(initial_tensor)
+        elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
+            scatter_tensor = initial_tensor
+        else:
+            # This case would not be encountered from torch itself
+            print("Invalid Operation for Reduce op!!")
+
+        operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
+        device = to_torch_device(scatter_tensor.device)
+        operation_lhs = operation_lhs.to(device)
+        operation_rhs = operation_rhs.to(device)
+        return self.func(operation_lhs, operation_rhs)
+
+
+@register_torch_trt_decomposition(
+    torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scatter_reduce_decomposition(
+    input_tensor: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    src_tensor: torch.Tensor,
+    reduce: str,
+    include_self: bool = True,
+) -> torch.Tensor:
+    scatter_loop_tensor = input_tensor
+    device_input_tensor = input_tensor.device
+    # required for mean reduce operation
+    scatter_count_tensor = torch.zeros_like(input_tensor)
+    src_shape = list(src_tensor.shape)
+    src_dim = src_shape[dim]
+    if not include_self:
+        raise AssertionError("include_self False for scatter reduce not yet supported")
+    for i in range(0, src_dim):
+        src_slice = torch.select(src_tensor, dim, i)
+        index_slice = torch.select(index, dim, i)
+        # unsqueeze src and index in dim
+        src_slice = torch.unsqueeze(src_slice, dim)
+        index_slice = torch.unsqueeze(index_slice, dim)
+
+        # move the loop tensors to the input tensor's device
+        scatter_loop_tensor = scatter_loop_tensor.to(device_input_tensor)
+        index_slice = index_slice.to(device_input_tensor)
+        src_slice = src_slice.to(device_input_tensor)
+        if reduce == "sum":
+            reduceOp = ReduceOperation.SUM
+        elif reduce == "prod":
+            reduceOp = ReduceOperation.PROD
+        elif reduce == "mean":
+            reduceOp = ReduceOperation.MEAN
+            # scatter ones to count how many src elements land on each output
+            # position; the count turns the running sum into a mean at the end
+            scatter_count_tensor = reduceOp.reduce_operation_with_scatter(
+                scatter_count_tensor,
+                input_tensor,
+                dim,
+                index_slice,
+                torch.ones_like(src_slice),
+            )
+        elif reduce == "amax":
+            reduceOp = ReduceOperation.AMAX
+        elif reduce == "amin":
+            reduceOp = ReduceOperation.AMIN
+        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
+            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
+        )
+    if reduce == "mean":
+        # divide the accumulated sum by (count + 1); the +1 accounts for the
+        # original input element, since include_self=True
+        scatter_loop_tensor = torch.div(
+            scatter_loop_tensor,
+            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
+            rounding_mode="trunc",
+        )
+    return scatter_loop_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
|
|
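For reference, here is a minimal eager-mode sketch of the idea behind scatter_reduce_decomposition for reduce="sum": every slice of src along dim is scattered into a zeros tensor and folded into a running result with an elementwise add, which reproduces torch.Tensor.scatter_reduce with include_self=True. The helper name scatter_reduce_sum_sketch and the sample tensors are illustrative only and not part of this PR.

import torch


def scatter_reduce_sum_sketch(
    input_tensor: torch.Tensor, dim: int, index: torch.Tensor, src: torch.Tensor
) -> torch.Tensor:
    # Accumulator starts at the input itself (include_self=True behavior).
    out = input_tensor.clone()
    for i in range(src.shape[dim]):
        # One slice of src/index along `dim`, unsqueezed back so torch.scatter
        # sees tensors of the same rank as the output.
        src_slice = torch.unsqueeze(torch.select(src, dim, i), dim)
        index_slice = torch.unsqueeze(torch.select(index, dim, i), dim)
        # Scatter the slice into zeros, then fold it in with an elementwise add.
        out = torch.add(
            out,
            torch.scatter(torch.zeros_like(input_tensor), dim, index_slice, src_slice),
        )
    return out


inp = torch.zeros(3, 5)
src = torch.arange(1.0, 11.0).reshape(2, 5)
index = torch.tensor([[0, 1, 2, 0, 1], [2, 0, 1, 1, 2]])
expected = inp.scatter_reduce(0, index, src, reduce="sum", include_self=True)
assert torch.allclose(scatter_reduce_sum_sketch(inp, 0, index, src), expected)

The amax and amin paths in the decomposition follow the same loop but scatter each slice into the input tensor itself instead of zeros, so positions that receive no src element keep their original values; prod scatters into ones, and mean reuses the sum path and divides by the per-position count plus one after the loop.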