Skip to content

Commit

Permalink
addressing review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Aug 29, 2024
1 parent 5dd825e commit a64c29a
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,48 +155,50 @@ def index(
_LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}")

mult_d0 = 1
dim_tensor_shape_mult_d0 = 1
for i in range(adv_indx_count):
if transpose_tensor_shape[i] == DYNAMIC_DIM:
dim_tensor_shape_mult_d0 = get_shape(
ctx,
target,
source_ir,
name + f"_transpose_tensor_shape_mult_d0_{i}",
transpose_tensor,
i,
)
else:
dim_tensor_shape_mult_d0 = transpose_tensor_shape[i]
mult_d0 = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
mult_d0,
(
get_shape(
ctx,
target,
source_ir,
name + f"_transpose_tensor_shape_mult_d0_{i}",
transpose_tensor,
i,
)
if transpose_tensor_shape[i] == DYNAMIC_DIM
else transpose_tensor_shape[i]
),
dim_tensor_shape_mult_d0,
)
mult_d1 = 1
dim_tensor_shape_mult_d1 = 1
for i in range(adv_indx_count, rank):
if transpose_tensor_shape[i] == DYNAMIC_DIM:
dim_tensor_shape_mult_d1 = get_shape(
ctx,
target,
source_ir,
name + f"_transpose_tensor_shape_mult_d0_{i}",
transpose_tensor,
i,
)
else:
dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
mult_d1 = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
mult_d1,
(
get_shape(
ctx,
target,
source_ir,
name + f"_transpose_tensor_shape_mult_d1_{i}",
transpose_tensor,
i,
)
if transpose_tensor_shape[i] == DYNAMIC_DIM
else transpose_tensor_shape[i]
),
dim_tensor_shape_mult_d1,
)

concat_tensor_layer = ctx.net.add_concatenation(
Expand Down

0 comments on commit a64c29a

Please sign in to comment.