diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index aa55b23153..0d55c5f014 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -155,7 +155,19 @@ def index(
     _LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}")
 
     mult_d0 = 1
+    dim_tensor_shape_mult_d0 = 1
     for i in range(adv_indx_count):
+        if transpose_tensor_shape[i] == DYNAMIC_DIM:
+            dim_tensor_shape_mult_d0 = get_shape(
+                ctx,
+                target,
+                source_ir,
+                name + f"_transpose_tensor_shape_mult_d0_{i}",
+                transpose_tensor,
+                i,
+            )
+        else:
+            dim_tensor_shape_mult_d0 = transpose_tensor_shape[i]
         mult_d0 = convert_binary_elementwise(
             ctx,
             target,
@@ -163,21 +175,22 @@ def index(
             name + f"_shape_{i}",
             trt.ElementWiseOperation.PROD,
             mult_d0,
-            (
-                get_shape(
-                    ctx,
-                    target,
-                    source_ir,
-                    name + f"_transpose_tensor_shape_mult_d0_{i}",
-                    transpose_tensor,
-                    i,
-                )
-                if transpose_tensor_shape[i] == DYNAMIC_DIM
-                else transpose_tensor_shape[i]
-            ),
+            dim_tensor_shape_mult_d0,
         )
     mult_d1 = 1
+    dim_tensor_shape_mult_d1 = 1
     for i in range(adv_indx_count, rank):
+        if transpose_tensor_shape[i] == DYNAMIC_DIM:
+            dim_tensor_shape_mult_d1 = get_shape(
+                ctx,
+                target,
+                source_ir,
+                name + f"_transpose_tensor_shape_mult_d1_{i}",
+                transpose_tensor,
+                i,
+            )
+        else:
+            dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
         mult_d1 = convert_binary_elementwise(
             ctx,
             target,
@@ -185,18 +198,7 @@ def index(
             name + f"_shape_{i}",
             trt.ElementWiseOperation.PROD,
             mult_d1,
-            (
-                get_shape(
-                    ctx,
-                    target,
-                    source_ir,
-                    name + f"_transpose_tensor_shape_mult_d1_{i}",
-                    transpose_tensor,
-                    i,
-                )
-                if transpose_tensor_shape[i] == DYNAMIC_DIM
-                else transpose_tensor_shape[i]
-            ),
+            dim_tensor_shape_mult_d1,
         )
 
     concat_tensor_layer = ctx.net.add_concatenation(