Skip to content

Commit 0f8f23d

Browse files
authored
Dynamic shape index (#3039)
1 parent 95cc532 commit 0f8f23d

File tree

2 files changed

+181
-154
lines changed

2 files changed

+181
-154
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

+53-13
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717
)
1818
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
1919
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
20+
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
21+
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
2022
from torch_tensorrt.fx.converters.converter_utils import (
2123
has_dynamic_shape,
2224
set_layer_name,
@@ -111,17 +113,18 @@ def index(
111113
else:
112114
input_shape = input.shape
113115
_LOGGER.debug(f"The input shape is {input.shape}")
114-
if dynamic_shape:
115-
input_shape = get_shape_with_dynamic_shape(
116-
ctx.net, target, source_ir, name, input_shape, input
117-
)
118116
rank = len(input_shape)
119117
adv_indx_count = len(adv_indx_indices)
120118
dim_tensor_list = []
121119

122120
for i in range(rank):
123-
dim = input_shape[i]
124-
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
121+
if input_shape[i] != DYNAMIC_DIM:
122+
dim = input_shape[i]
123+
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
124+
else:
125+
dim_tensor = get_shape(
126+
ctx, target, source_ir, name + f"_individual_dim_dyn_{i}", input, i
127+
)
125128
# dim_tensor_list is a list of tensors
126129
dim_tensor_list.append(dim_tensor)
127130

@@ -150,12 +153,53 @@ def index(
150153
# transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
151154
transpose_tensor_shape = transpose_tensor.shape
152155
_LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}")
156+
153157
mult_d0 = 1
158+
dim_tensor_shape_mult_d0 = 1
154159
for i in range(adv_indx_count):
155-
mult_d0 = mult_d0 * transpose_tensor_shape[i]
160+
if transpose_tensor_shape[i] == DYNAMIC_DIM:
161+
dim_tensor_shape_mult_d0 = get_shape(
162+
ctx,
163+
target,
164+
source_ir,
165+
name + f"_transpose_tensor_shape_mult_d0_{i}",
166+
transpose_tensor,
167+
i,
168+
)
169+
else:
170+
dim_tensor_shape_mult_d0 = transpose_tensor_shape[i]
171+
mult_d0 = convert_binary_elementwise(
172+
ctx,
173+
target,
174+
source_ir,
175+
name + f"_shape_{i}",
176+
trt.ElementWiseOperation.PROD,
177+
mult_d0,
178+
dim_tensor_shape_mult_d0,
179+
)
156180
mult_d1 = 1
181+
dim_tensor_shape_mult_d1 = 1
157182
for i in range(adv_indx_count, rank):
158-
mult_d1 = mult_d1 * transpose_tensor_shape[i]
183+
if transpose_tensor_shape[i] == DYNAMIC_DIM:
184+
dim_tensor_shape_mult_d1 = get_shape(
185+
ctx,
186+
target,
187+
source_ir,
188+
name + f"_transpose_tensor_shape_mult_d0_{i}",
189+
transpose_tensor,
190+
i,
191+
)
192+
else:
193+
dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
194+
mult_d1 = convert_binary_elementwise(
195+
ctx,
196+
target,
197+
source_ir,
198+
name + f"_shape_{i}",
199+
trt.ElementWiseOperation.PROD,
200+
mult_d1,
201+
dim_tensor_shape_mult_d1,
202+
)
159203

160204
concat_tensor_layer = ctx.net.add_concatenation(
161205
[
@@ -185,11 +229,7 @@ def index(
185229
ctx, cum_adv_index, name + "_index_sum_intermediate"
186230
)
187231
else:
188-
multiplier = get_trt_tensor(
189-
ctx,
190-
dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
191-
name + "_dim_last",
192-
)
232+
multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
193233
cum_adv_index = tensor_indices[adv_indx_count - 1]
194234
for i in range(adv_indx_count - 2, -1, -1):
195235
adv_index = convert_binary_elementwise(

0 commit comments

Comments (0)