feat: support adaptive_avg_pool1d dynamo converter (#2614)
zewenli98 authored Mar 21, 2024
1 parent 766c270 commit 3a5c39f
Showing 3 changed files with 141 additions and 83 deletions.
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2184,6 +2184,24 @@ def aten_ops_avg_pool(
    )


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
def aten_ops_adaptive_avg_pool(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.pool.adaptive_avg_pool1d(
        ctx,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=args[0],
        output_size=args[1],
    )


def max_pool_param_validator(pool_node: Node) -> bool:
    dilation = args_bounds_check(pool_node.args, 4, 1)
    ceil_mode = args_bounds_check(pool_node.args, 5, False)
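The registration above only wires torch.ops.aten.adaptive_avg_pool1d.default to the new impl in pool.py below. A minimal end-to-end sketch of how a model would reach this converter through the dynamo path (the module, shapes, and compile settings here are illustrative, not part of the commit):

# Hypothetical smoke test; assumes torch_tensorrt is installed and a CUDA device is available.
import torch
import torch_tensorrt


class Pool1d(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.adaptive_avg_pool1d.default under dynamo
        return torch.nn.functional.adaptive_avg_pool1d(x, 4)


model = Pool1d().eval().cuda()
inputs = [torch.randn(2, 3, 32, device="cuda")]
# min_block_size=1 keeps the lone pooling op from being left in PyTorch
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print(trt_model(*inputs).shape)  # torch.Size([2, 3, 4])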
67 changes: 66 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -1,6 +1,8 @@
-from typing import Optional, Sequence, Union
+import math
+from typing import Dict, Optional, Sequence, Union
 
 import tensorrt as trt
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
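The three new imports all serve the function added below: math provides the floor/ceil used in the window-index arithmetic, Dict annotates the slice cache, and the impl module supplies the select, reshape, cat, and mean helpers the pooling is composed from.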
@@ -104,3 +106,66 @@ def max_poolNd(

    set_layer_name(pool_layer, target, name, source_ir)
    return pool_layer.get_output(0)


def adaptive_avg_pool1d(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    output_size: Union[int, Sequence[int]],
) -> TRTTensor:
    def start_index(idx: int, out_dim: int, in_dim: int) -> int:
        """Calculate the start index of each pooling window."""
        return math.floor((float(idx) * float(in_dim)) / out_dim)

    def end_index(idx: int, out_dim: int, in_dim: int) -> int:
        """Calculate the end index of each pooling window."""
        return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)

    in_dim = input.shape[-1]
    out_dim = output_size if isinstance(output_size, int) else output_size[0]
    output_list = []

    # Cache {input index: slice} so repeated indices reuse the same select/reshape layers
    idx_slice_map: Dict[int, TRTTensor] = {}
    # Iterate over each element of the output dimension
    for i in range(out_dim):
        # Calculate the start and end index of the pooling window for output element i
        start = start_index(i, out_dim, in_dim)
        end = end_index(i, out_dim, in_dim)

        # Gather the input elements from start to end index; together they form the window to be pooled
        slices = []
        for j in range(start, end):
            if j in idx_slice_map:
                slice = idx_slice_map[j]
            else:
                slice = impl.select.select(
                    ctx, target, source_ir, f"{name}_select_{j}", input, -1, j
                )
                slice = impl.shuffle.reshape(
                    ctx,
                    target,
                    source_ir,
                    f"{name}_reshape_{i}_{j}",
                    slice,
                    (*slice.shape, 1),
                )
                idx_slice_map[j] = slice

            slices.append(slice)

        slices = impl.cat.cat(
            ctx, target, source_ir, f"{name}_slices_cat_{i}", slices, dim=-1
        )
        # Take the mean of the window (the average-pooling output) and append it to the output list
        output_list.append(
            impl.reduce.mean(
                ctx, target, source_ir, f"{name}_sum_{i}", slices, dim=-1, keepdim=True
            )
        )

    output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
    return output
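start_index and end_index reproduce PyTorch's adaptive-pooling window arithmetic: output element i averages input positions [floor(i * in_dim / out_dim), ceil((i + 1) * in_dim / out_dim)). For in_dim=5 and out_dim=3 this yields the overlapping windows [0, 2), [1, 4), [3, 5), which is why caching slices by input index pays off. A small pure-Python cross-check against PyTorch itself (the helper name is ours, not part of the commit):

import math

import torch


def adaptive_avg_pool1d_ref(x: torch.Tensor, out_dim: int) -> torch.Tensor:
    # Same window arithmetic as the converter, in eager PyTorch
    in_dim = x.shape[-1]
    cols = []
    for i in range(out_dim):
        start = math.floor(i * in_dim / out_dim)
        end = math.ceil((i + 1) * in_dim / out_dim)
        cols.append(x[..., start:end].mean(dim=-1, keepdim=True))
    return torch.cat(cols, dim=-1)


x = torch.randn(2, 3, 5)
assert torch.allclose(
    adaptive_avg_pool1d_ref(x, 3), torch.nn.functional.adaptive_avg_pool1d(x, 3)
)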
139 changes: 57 additions & 82 deletions tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
@@ -9,102 +9,77 @@
 class TestAdaptiveAvgPoolConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ((64, 64),),
-            ((128, 64),),
-            (64,),
-        ]
-    )
-    def test_adaptive_avgpool(
-        self,
-        output_size,
-    ):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool2d(output_size)
-
-            def forward(self, x):
-                return self.pool(x)
-
-        inputs = [torch.randn(1, 3, 256, 256)]
-        self.run_test(
-            TestModule(),
-            inputs,
-            use_dynamo_tracer=True,
-        )
-
-    def test_adaptive_avgpool_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool2d((64, 64))
-
-            def forward(self, x):
-                return self.pool(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, -1, 256, 256),
-                dtype=torch.float32,
-                shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))],
-            ),
-        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, use_dynamo_tracer=True
-        )
-
-    @parameterized.expand(
-        [
-            ((16, 16, 16),),
-            ((32, 16, 4),),
-            (32,),
-        ]
-    )
-    def test_adaptive_avgpool3d(
-        self,
-        output_size,
-    ):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool3d(output_size)
-
-            def forward(self, x):
-                return self.pool(x)
-
-        inputs = [torch.randn(1, 3, 32, 64, 64)]
-        self.run_test(
-            TestModule(),
-            inputs,
-            use_dynamo_tracer=True,
-        )
-
-    def test_adaptive_avgpool3d_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16))
-
-            def forward(self, x):
-                return self.pool(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, -1, 32, 64, 64),
-                dtype=torch.float32,
-                shape_ranges=[
-                    ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64))
-                ],
-            ),
-        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(),
-            input_specs,
-            use_dynamo_tracer=True,
-        )
-
-    # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
+            # (64,), This case has been there in previous code but it isn't a valid pytorch code.
+            ((2, 3), 2),
+            ((2, 8), 8),
+            ((1, 2, 3), 2),
+            ((2, 2, 8), 16),
+            ((2, 3), (1,)),
+            ((2, 3), (2,)),
+            ((2, 8), (4,)),
+            ((2, 8), (16,)),
+            ((2, 3, 1), (1,)),
+            ((2, 3, 2), (2,)),
+            ((2, 3, 4), (4,)),
+            ((2, 2, 32), (31,)),
+            ((2, 2, 32), (64,)),
+        ]
+    )
+    def test_adaptive_avg_pool1d(
+        self,
+        input_shape,
+        output_size,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size)
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            # use_dynamo_tracer=True,
+            enable_passes=True,
+        )
 
 
 if __name__ == "__main__":
     run_tests()
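The new cases cover 2D (C, L) and 3D (N, C, L) inputs, int and tuple output sizes, and output sizes larger than the input (the upsampling case). Assuming a development install, they run with python -m pytest tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py. What a couple of the parameterized rows assert in eager PyTorch (shapes taken from the list above):

import torch

x = torch.randn(2, 3)  # unbatched (C, L) input, the ((2, 3), 2) case
print(torch.ops.aten.adaptive_avg_pool1d.default(x, [2]).shape)  # torch.Size([2, 2])

x = torch.randn(2, 2, 32)  # batched (N, C, L) input, the ((2, 2, 32), (64,)) case
print(torch.ops.aten.adaptive_avg_pool1d.default(x, [64]).shape)  # torch.Size([2, 2, 64])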
