From 8295117eb168b284c2e06446fde86710b7b203a8 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 15 Aug 2024 17:18:24 -0700 Subject: [PATCH] addressing review comments --- .../dynamo/conversion/impl/select.py | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 2ea34acb95..a788699bcd 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -156,7 +156,19 @@ def index( _LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}") mult_d0 = 1 + dim_tensor_shape_mult_d0 = 1 for i in range(adv_indx_count): + if transpose_tensor_shape[i] == DYNAMIC_DIM: + dim_tensor_shape_mult_d0 = get_shape( + ctx, + target, + source_ir, + name + f"_transpose_tensor_shape_mult_d0_{i}", + transpose_tensor, + i, + ) + else: + dim_tensor_shape_mult_d0 = transpose_tensor_shape[i] mult_d0 = convert_binary_elementwise( ctx, target, @@ -164,21 +176,22 @@ def index( name + f"_shape_{i}", trt.ElementWiseOperation.PROD, mult_d0, - ( - get_shape( - ctx, - target, - source_ir, - name + f"_transpose_tensor_shape_mult_d0_{i}", - transpose_tensor, - i, - ) - if transpose_tensor_shape[i] == DYNAMIC_DIM - else transpose_tensor_shape[i] - ), + dim_tensor_shape_mult_d0, ) mult_d1 = 1 + dim_tensor_shape_mult_d1 = 1 for i in range(adv_indx_count, rank): + if transpose_tensor_shape[i] == DYNAMIC_DIM: + dim_tensor_shape_mult_d1 = get_shape( + ctx, + target, + source_ir, + name + f"_transpose_tensor_shape_mult_d0_{i}", + transpose_tensor, + i, + ) + else: + dim_tensor_shape_mult_d1 = transpose_tensor_shape[i] mult_d1 = convert_binary_elementwise( ctx, target, @@ -186,18 +199,7 @@ def index( name + f"_shape_{i}", trt.ElementWiseOperation.PROD, mult_d1, - ( - get_shape( - ctx, - target, - source_ir, - name + f"_transpose_tensor_shape_mult_d1_{i}", - transpose_tensor, - i, - ) - if transpose_tensor_shape[i] == DYNAMIC_DIM - else transpose_tensor_shape[i] - ), + dim_tensor_shape_mult_d1, ) concat_tensor_layer = ctx.net.add_concatenation(