Skip to content

Commit

Permalink
fix: get_padded_shape_tensors can now handle dynamic pads (#3123)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiwoong-choi authored Sep 3, 2024
1 parent d75f588 commit ae7e6c8
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/impl/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_padded_shape_tensors(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
pad: Sequence[Union[int, TRTTensor]],
) -> TRTTensor:
rank = len(input.shape)
if len(pad) // 2 > rank:
Expand All @@ -47,11 +47,11 @@ def get_padded_shape_tensors(
start_list = [0] * rank
for i in range(len(pad) // 2):
dim_index = rank - (i + 1)
pad_before = pad[i * 2]
pad_after = pad[i * 2 + 1]
pad_before = get_trt_tensor(ctx, pad[i * 2], f"{name}_pad_before_{i}")
pad_after = get_trt_tensor(ctx, pad[i * 2 + 1], f"{name}_pad_after_{i}")

pad_sum = get_trt_tensor(
ctx, pad_before + pad_after, f"{name}_pad_sum_{i}", dtype=np.int32
pad_sum = impl.elementwise.add(
ctx, target, source_ir, f"{name}_pad_sum_{i}", pad_before, pad_after
)
dim_shape = ctx.net.add_slice(
input_shape_tensor,
Expand All @@ -63,7 +63,9 @@ def get_padded_shape_tensors(
new_dim_shape = impl.elementwise.add(
ctx, target, source_ir, f"{name}_shape_dim_{i}", dim_shape, pad_sum
)
start_list[dim_index] = -pad_before
start_list[dim_index] = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_pad_before_neg_{i}", 0, pad_before
)

slices = []
for j in range(rank):
Expand All @@ -79,14 +81,23 @@ def get_padded_shape_tensors(
).get_output(0)
)
padded_shape_tensor = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_dim_{i}", slices, 0
ctx,
target,
source_ir,
f"{name}_cat_dim_{i}",
slices,
0,
cast_dtype=padded_shape_tensor.dtype,
)

start_indices_tensor = get_trt_tensor(
start_indices_tensor = impl.cat.cat(
ctx,
np.array(start_list, dtype=np.int32),
target,
source_ir,
f"{name}_start_indices_tensor",
dtype=np.int32,
start_list,
0,
cast_dtype=padded_shape_tensor.dtype,
)

return start_indices_tensor, padded_shape_tensor
Expand All @@ -98,7 +109,7 @@ def constant_padNd(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
pad: Sequence[Union[int, TRTTensor]],
value: Union[int, float] = 0,
) -> TRTTensor:

Expand Down

0 comments on commit ae7e6c8

Please sign in to comment.