Skip to content

Commit 4a70020

Browse files
authored
addresses the case when shape of upsample tensor contains ITensor (#3841)
1 parent dae1ead commit 4a70020

File tree

3 files changed

+167
-40
lines changed

3 files changed

+167
-40
lines changed
Lines changed: 107 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Sequence, Union
1+
from typing import List, Optional, Sequence, Union
22

33
import numpy as np
44
import tensorrt as trt
@@ -11,11 +11,102 @@
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
1212
cast_trt_tensor,
1313
get_positive_dim,
14-
get_trt_tensor,
1514
set_layer_name,
1615
)
1716

1817

18+
def unify_and_concat_trt_tensors(
19+
ctx: ConversionContext,
20+
target: Target,
21+
name: str,
22+
inputs: Sequence[Union[int, np.ndarray, torch.Tensor, TRTTensor]],
23+
concat_axis: int,
24+
cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
25+
force_trt_output: bool = False,
26+
) -> Union[TRTTensor, List[int]]:
27+
"""
28+
Normalize all inputs to TRT tensors if needed, optionally cast, and concat if any dynamic.
29+
30+
Args:
31+
ctx: TensorRT conversion context.
32+
target: Operation Target.
33+
name: Operation Name.
34+
inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
35+
concat_axis: Axis along which to concatenate tensors if dynamic.
36+
cast_dtype: Optional target dtype for casting TRT tensors.
37+
force_trt_output: If True, return TRT tensor even if all inputs are static ints. (True for concat operations)
38+
"""
39+
has_dynamic = any(not isinstance(x, int) for x in inputs)
40+
trt_tensors = []
41+
42+
for i, x in enumerate(inputs):
43+
# convert to TRTTensor
44+
if isinstance(x, TRTTensor):
45+
t = x
46+
elif isinstance(x, int) and not has_dynamic and not force_trt_output:
47+
t = x # pure static path
48+
else:
49+
const_arr = np.array([x], dtype=np.int32)
50+
shape = (1,)
51+
if not isinstance(x, int):
52+
const_arr = np.array(x, dtype=np.int32)
53+
shape = (x.numel(),)
54+
55+
layer = ctx.net.add_constant(shape, const_arr)
56+
set_layer_name(layer, target, f"{name}_dim{i}_const")
57+
t = layer.get_output(0)
58+
trt_tensors.append(t)
59+
60+
if not has_dynamic and not force_trt_output:
61+
return trt_tensors # all ints
62+
63+
final_dtype = None
64+
if cast_dtype:
65+
# Explicit cast requested
66+
if isinstance(cast_dtype, _enums.dtype):
67+
final_dtype = cast_dtype.to(trt.DataType)
68+
elif isinstance(cast_dtype, (np.dtype, torch.dtype)):
69+
final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
70+
else:
71+
final_dtype = cast_dtype # already trt.DataType
72+
else:
73+
# Automatic promotion
74+
promoted_type = None
75+
for t in trt_tensors:
76+
if isinstance(t, TRTTensor):
77+
if promoted_type is None:
78+
promoted_type = t.dtype
79+
else:
80+
promoted_type = _enums.dtype._from(
81+
torch.promote_types(
82+
_enums.dtype._from(promoted_type).to(torch.dtype),
83+
_enums.dtype._from(t.dtype).to(torch.dtype),
84+
)
85+
).to(trt.DataType)
86+
final_dtype = promoted_type
87+
88+
# promote remaining ints to TRT consts before concat
89+
for i, t in enumerate(trt_tensors):
90+
if isinstance(t, int):
91+
const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
92+
set_layer_name(const, target, f"{name}_static_{i}_const")
93+
trt_tensors[i] = const.get_output(0)
94+
95+
# final cast
96+
if final_dtype is not None:
97+
casted = []
98+
for i, t in enumerate(trt_tensors):
99+
if isinstance(t, TRTTensor):
100+
t = cast_trt_tensor(ctx, t, final_dtype, f"{name}_cast_{i}")
101+
casted.append(t)
102+
trt_tensors = casted
103+
104+
concat = ctx.net.add_concatenation(trt_tensors)
105+
concat.axis = concat_axis
106+
set_layer_name(concat, target, f"{name}_concat")
107+
return concat.get_output(0)
108+
109+
19110
def cat(
20111
ctx: ConversionContext,
21112
target: Target,
@@ -25,38 +116,17 @@ def cat(
25116
dim: int,
26117
cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
27118
) -> Union[TRTTensor, Sequence[TRTTensor]]:
28-
trt_inputs = []
29-
for i, each_input in enumerate(input):
30-
if not isinstance(each_input, TRTTensor):
31-
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
32-
if cast_dtype:
33-
each_input = cast_trt_tensor(
34-
ctx, each_input, cast_dtype, f"{name}_tensor_int32_cast_{i}"
35-
)
36-
trt_inputs.append(each_input)
37-
38-
if len(trt_inputs) > 1:
39-
# Cast to promoted type for all inputs
40-
promoted_type = trt_inputs[0].dtype
41-
for each_input in trt_inputs[1:]:
42-
promoted_type = _enums.dtype._from(
43-
torch.promote_types(
44-
_enums.dtype._from(promoted_type).to(torch.dtype),
45-
_enums.dtype._from(each_input.dtype).to(torch.dtype),
46-
)
47-
)
48-
trt_promoted_type = promoted_type.to(trt.DataType)
49-
50-
trt_casted_inputs = []
51-
for i, each_input in enumerate(trt_inputs):
52-
casted_input = cast_trt_tensor(
53-
ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
54-
)
55-
trt_casted_inputs.append(casted_input)
56-
trt_inputs = trt_casted_inputs
57-
58-
concat_layer = ctx.net.add_concatenation(trt_inputs)
59-
dim = get_positive_dim(dim, len(trt_inputs[0].shape))
60-
concat_layer.axis = dim
61-
set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
62-
return concat_layer.get_output(0)
119+
# int is only when cat called in other ops like pad
120+
if not isinstance(input[0], int):
121+
dim = get_positive_dim(dim, len(input[0].shape))
122+
else:
123+
dim = 0
124+
return unify_and_concat_trt_tensors(
125+
ctx,
126+
target,
127+
name,
128+
input,
129+
concat_axis=dim,
130+
cast_dtype=cast_dtype,
131+
force_trt_output=True,
132+
)

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
has_dynamic_shape,
1010
set_layer_name,
1111
)
12-
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
12+
from torch_tensorrt.dynamo.conversion.impl.cat import (
13+
unify_and_concat_trt_tensors as unify_trt_shape_tensors,
14+
)
15+
from torch_tensorrt.dynamo.conversion.impl.shape import (
16+
get_shape_with_dynamic_shape,
17+
)
1318

1419

1520
def upsample(
@@ -28,14 +33,22 @@ def upsample(
2833
if scale_factor is not None:
2934
layer.scales = [1.0, 1.0] + list(scale_factor)
3035
else:
31-
shape = list(input.shape)[:2] + list(size)
36+
shape = list(input.shape)[:2]
37+
if size is not None:
38+
shape += list(size)
3239
if has_dynamic_shape(shape):
3340
shape = get_shape_with_dynamic_shape(
3441
ctx, target, source_ir, name, shape, input
3542
)
3643
layer.set_input(1, shape)
3744
else:
38-
layer.shape = shape
45+
trt_shape = unify_trt_shape_tensors(
46+
ctx, target, name, shape, concat_axis=0, force_trt_output=False
47+
)
48+
if isinstance(trt_shape, list):
49+
layer.shape = trt_shape
50+
else:
51+
layer.set_input(1, trt_shape)
3952

4053
if mode == "nearest":
4154
layer.resize_mode = trt.InterpolationMode.NEAREST

tests/py/dynamo/conversion/test_upsample_aten.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,50 @@ def forward(self, x):
296296
]
297297
self.run_test_with_dynamic_shape(TestModule(), input_specs)
298298

299+
@parameterized.expand(
300+
[
301+
([torch.tensor(3), 3], None),
302+
(None, [torch.tensor(0.5), 1.5]),
303+
]
304+
)
305+
def test_nearest2d_mixed_dynamic_shape(self, output_size, scale_factors):
306+
class TestModule(torch.nn.Module):
307+
def forward(self, x):
308+
out_size = output_size
309+
scale = scale_factors
310+
311+
return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
312+
313+
input_specs = [
314+
Input(
315+
min_shape=(1, 1, 1, 1),
316+
opt_shape=(5, 5, 5, 5),
317+
max_shape=(9, 9, 9, 9),
318+
dtype=torch.float32,
319+
)
320+
]
321+
self.run_test_with_dynamic_shape(TestModule(), input_specs)
322+
323+
@parameterized.expand(
324+
[
325+
# Mix of Tensor and int in output_size
326+
([torch.tensor(3), 3], None),
327+
# Mix of Tensor and float in scale_factors
328+
(None, [torch.tensor(0.5), 1.5]),
329+
]
330+
)
331+
def test_nearest2d_mixed_static_input(self, output_size, scale_factors):
332+
class TestModule(torch.nn.Module):
333+
def forward(self, x):
334+
out_size = output_size
335+
scale = scale_factors
336+
return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
337+
338+
input_size = [7, 7] # H, W
339+
inputs = [torch.randn([1, 1] + input_size)] # shape [1, 1, 7, 7]
340+
341+
self.run_test(TestModule(), inputs)
342+
299343

300344
if __name__ == "__main__":
301345
run_tests()

0 commit comments

Comments
 (0)