Skip to content

Commit

Permalink
Dynamic shape index
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Aug 29, 2024
1 parent 39f8255 commit 5dd825e
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 154 deletions.
64 changes: 51 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
set_layer_name,
Expand Down Expand Up @@ -111,17 +113,18 @@ def index(
else:
input_shape = input.shape
_LOGGER.debug(f"The input shape is {input.shape}")
if dynamic_shape:
input_shape = get_shape_with_dynamic_shape(
ctx.net, target, source_ir, name, input_shape, input
)
rank = len(input_shape)
adv_indx_count = len(adv_indx_indices)
dim_tensor_list = []

for i in range(rank):
dim = input_shape[i]
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
if input_shape[i] != DYNAMIC_DIM:
dim = input_shape[i]
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
else:
dim_tensor = get_shape(
ctx, target, source_ir, name + f"_individual_dim_dyn_{i}", input, i
)
# dim_tensor_list is a list of tensors
dim_tensor_list.append(dim_tensor)

Expand Down Expand Up @@ -150,12 +153,51 @@ def index(
# transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
transpose_tensor_shape = transpose_tensor.shape
_LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}")

mult_d0 = 1
for i in range(adv_indx_count):
mult_d0 = mult_d0 * transpose_tensor_shape[i]
mult_d0 = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
mult_d0,
(
get_shape(
ctx,
target,
source_ir,
name + f"_transpose_tensor_shape_mult_d0_{i}",
transpose_tensor,
i,
)
if transpose_tensor_shape[i] == DYNAMIC_DIM
else transpose_tensor_shape[i]
),
)
mult_d1 = 1
for i in range(adv_indx_count, rank):
mult_d1 = mult_d1 * transpose_tensor_shape[i]
mult_d1 = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
mult_d1,
(
get_shape(
ctx,
target,
source_ir,
name + f"_transpose_tensor_shape_mult_d1_{i}",
transpose_tensor,
i,
)
if transpose_tensor_shape[i] == DYNAMIC_DIM
else transpose_tensor_shape[i]
),
)

concat_tensor_layer = ctx.net.add_concatenation(
[
Expand Down Expand Up @@ -185,11 +227,7 @@ def index(
ctx, cum_adv_index, name + "_index_sum_intermediate"
)
else:
multiplier = get_trt_tensor(
ctx,
dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
name + "_dim_last",
)
multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = convert_binary_elementwise(
Expand Down
Loading

0 comments on commit 5dd825e

Please sign in to comment.