@@ -17,6 +17,8 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
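For context, `DYNAMIC_DIM` (imported above from `torch_tensorrt.dynamo.utils`) is the sentinel for a dimension whose size is only known at runtime; TensorRT reports such dimensions as -1 in a tensor's build-time shape. A minimal sketch of the check the hunks below rely on (the example shape is hypothetical, and the sentinel value mirrors what the utils module defines):

```python
DYNAMIC_DIM = -1  # sentinel, mirroring torch_tensorrt.dynamo.utils

# A tensor declared with shape (-1, 3, 224, 224) has a dynamic batch dim:
build_time_shape = (-1, 3, 224, 224)
static = [i for i, d in enumerate(build_time_shape) if d != DYNAMIC_DIM]
assert static == [1, 2, 3]  # only dim 0 must be read at runtime
```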
@@ -111,17 +113,18 @@ def index(
     else:
         input_shape = input.shape
         _LOGGER.debug(f"The input shape is {input.shape}")
-        if dynamic_shape:
-            input_shape = get_shape_with_dynamic_shape(
-                ctx.net, target, source_ir, name, input_shape, input
-            )
     rank = len(input_shape)
     adv_indx_count = len(adv_indx_indices)
     dim_tensor_list = []

     for i in range(rank):
-        dim = input_shape[i]
-        dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
+        if input_shape[i] != DYNAMIC_DIM:
+            dim = input_shape[i]
+            dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
+        else:
+            dim_tensor = get_shape(
+                ctx, target, source_ir, name + f"_individual_dim_dyn_{i}", input, i
+            )
         # dim_tensor_list is a list of tensors
         dim_tensor_list.append(dim_tensor)

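With this change the per-dimension loop no longer needs the up-front `get_shape_with_dynamic_shape` call: statically known sizes still become constants via `get_trt_tensor`, while dynamic ones are read individually through `get_shape` (roughly a shape layer plus a gather of element `i`). A runnable scalar analog of the branch, with made-up shapes:

```python
DYNAMIC_DIM = -1  # sentinel for a runtime-only dimension

def gather_dim_sizes(build_time_shape, runtime_shape):
    """Per-dim analog of the loop: keep the static size when it is known,
    otherwise fall back to the runtime shape (what get_shape reads)."""
    return [
        runtime_shape[i] if d == DYNAMIC_DIM else d
        for i, d in enumerate(build_time_shape)
    ]

# Build-time shape (-1, 3, 4): dim 0 is only known once an input arrives.
assert gather_dim_sizes((-1, 3, 4), (8, 3, 4)) == [8, 3, 4]
```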
@@ -150,12 +153,53 @@ def index(
     # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
     transpose_tensor_shape = transpose_tensor.shape
     _LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}")
+
     mult_d0 = 1
+    dim_tensor_shape_mult_d0 = 1
     for i in range(adv_indx_count):
-        mult_d0 = mult_d0 * transpose_tensor_shape[i]
+        if transpose_tensor_shape[i] == DYNAMIC_DIM:
+            dim_tensor_shape_mult_d0 = get_shape(
+                ctx,
+                target,
+                source_ir,
+                name + f"_transpose_tensor_shape_mult_d0_{i}",
+                transpose_tensor,
+                i,
+            )
+        else:
+            dim_tensor_shape_mult_d0 = transpose_tensor_shape[i]
+        mult_d0 = convert_binary_elementwise(
+            ctx,
+            target,
+            source_ir,
+            name + f"_shape_{i}",
+            trt.ElementWiseOperation.PROD,
+            mult_d0,
+            dim_tensor_shape_mult_d0,
+        )
     mult_d1 = 1
+    dim_tensor_shape_mult_d1 = 1
     for i in range(adv_indx_count, rank):
-        mult_d1 = mult_d1 * transpose_tensor_shape[i]
+        if transpose_tensor_shape[i] == DYNAMIC_DIM:
+            dim_tensor_shape_mult_d1 = get_shape(
+                ctx,
+                target,
+                source_ir,
+                name + f"_transpose_tensor_shape_mult_d1_{i}",
+                transpose_tensor,
+                i,
+            )
+        else:
+            dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
+        mult_d1 = convert_binary_elementwise(
+            ctx,
+            target,
+            source_ir,
+            name + f"_shape_{i}",
+            trt.ElementWiseOperation.PROD,
+            mult_d1,
+            dim_tensor_shape_mult_d1,
+        )

     concat_tensor_layer = ctx.net.add_concatenation(
         [
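Since any factor of the two products may now be a runtime tensor, the plain Python multiplications are rebuilt as chains of `ElementWiseOperation.PROD` layers (the `_shape_{i}` layer names do not collide because the two loops iterate over disjoint ranges of `i`). Numerically, the loops still compute the two axes of the flattened view; a scalar sketch, with an illustrative function name:

```python
from functools import reduce
from operator import mul

def flatten_products(dims, adv_indx_count):
    """mult_d0 collapses the advanced-index dims, mult_d1 the remaining
    ones, so the transposed tensor can be viewed as (mult_d0, mult_d1)."""
    mult_d0 = reduce(mul, dims[:adv_indx_count], 1)
    mult_d1 = reduce(mul, dims[adv_indx_count:], 1)
    return mult_d0, mult_d1

# A (2, 5, 3, 4) tensor with 2 advanced dims flattens to (10, 12).
assert flatten_products([2, 5, 3, 4], 2) == (10, 12)
```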
@@ -185,11 +229,7 @@ def index(
             ctx, cum_adv_index, name + "_index_sum_intermediate"
         )
     else:
-        multiplier = get_trt_tensor(
-            ctx,
-            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-            name + "_dim_last",
-        )
+        multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
         cum_adv_index = tensor_indices[adv_indx_count - 1]
         for i in range(adv_indx_count - 2, -1, -1):
             adv_index = convert_binary_elementwise(
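The `get_trt_tensor` wrapper becomes redundant here because, after the loop change above, every entry of `dim_tensor_list` is already a TRT tensor. For reference, the loop this feeds linearizes the advanced indices in row-major order; a scalar sketch of that accumulation, with hypothetical names mirroring the converter's variables:

```python
def linearize(indices, dims):
    """Row-major flattening: cum = i_{k-1} + i_{k-2}*d_{k-1} + ..."""
    k = len(indices)
    multiplier = dims[k - 1]   # size of the last advanced axis
    cum = indices[k - 1]       # start from the innermost index
    for i in range(k - 2, -1, -1):
        cum += indices[i] * multiplier
        multiplier *= dims[i]
    return cum

# Element (1, 2) of a 3x4 block sits at flat offset 1*4 + 2 == 6.
assert linearize([1, 2], [3, 4]) == 6
```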