
Commit 8ecc809

chore: dynamic shape support for pdist ops (#3068)
Parent: 19f671d

3 files changed: +260 lines, -10 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+3, -1 lines
@@ -3573,7 +3573,9 @@ def aten_ops_any(
     )


-@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten._pdist_forward.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
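For context, `aten._pdist_forward` computes the condensed pairwise-distance vector: for an [N, D] input it returns the N*(N-1)/2 upper-triangle entries of the full N x N distance matrix. A small eager-mode PyTorch sketch of that semantics (illustration only, not part of the commit):

import torch

x = torch.randn(4, 3)
# pdist returns the upper triangle (diagonal offset 1) of the N x N
# pairwise distance matrix, flattened to N * (N - 1) // 2 values.
condensed = torch.nn.functional.pdist(x, p=2)   # shape [6] for N = 4
full = torch.cdist(x, x, p=2)                   # shape [4, 4]
rows, cols = torch.triu_indices(4, 4, offset=1)
assert torch.allclose(condensed, full[rows, cols], atol=1e-6)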

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

+199, -9 lines
@@ -10,15 +10,18 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    create_constant,
     get_axes_for_reduce_op,
     get_positive_dim,
     get_trt_tensor,
-    to_numpy,
-)
-from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
+    to_numpy,
 )
+from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims

@@ -417,20 +420,21 @@ def pdist(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     shape = input.shape
     # Extend input from shape [N, D] to [N, 1, D]
-    extend_input = impl.shuffle.reshape(
+    extend_input = impl.unsqueeze.unsqueeze(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape",
+        f"{name}_unsqueeze",
         input,
-        shape=shape[0:1] + (1,) + shape[1:],
+        1,
     )
+
     # Expand the input from [N, 1, D] to [N, N, D]
     x = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_sub",
+        f"{name}_expand",
         extend_input,
         (shape[0], shape[0]) + shape[1:],
     )
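Switching from an explicit reshape to unsqueeze means the converter no longer has to build a concrete shape tuple for the [N, 1, D] step, which is presumably what lets the first dimension stay symbolic. The [N, 1, D] / [N, N, D] layout is the usual broadcasting trick for pairwise differences; a minimal eager-mode sketch (illustration only, not the converter code):

import torch

x = torch.randn(5, 3)                    # [N, D]
diff = x.unsqueeze(1) - x.unsqueeze(0)   # [N, 1, D] - [1, N, D] -> [N, N, D]
dists = diff.norm(p=2, dim=-1)           # full N x N distance matrix for p = 2
assert torch.allclose(dists, torch.cdist(x, x, p=2), atol=1e-5)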
@@ -482,8 +486,194 @@ def pdist(
         raise RuntimeError(
             f"p should between [0, inf], currently p={p} is not supported!"
         )
-    indices = np.triu_indices(shape[0], k=1)
-    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+    if shape[0] == DYNAMIC_DIM:
+        dim = get_shape(ctx, target, source_ir, f"{name}_get_shape", input, 0)
+        shuffle_layer = ctx.net.add_shuffle(dim)
+        shuffle_layer.reshape_dims = trt.Dims()
+        set_layer_name(shuffle_layer, target, f"{name}_shuffle", source_ir)
+        dim_tensor = shuffle_layer.get_output(0)
+        indices_tensor = tri_upper_indices(
+            ctx, target, source_ir, f"{name}_triu_indices", dim_tensor
+        )
+        gather_layer = ctx.net.add_gather_v2(
+            norm, indices_tensor, mode=trt.GatherMode.ND
+        )
+        set_layer_name(gather_layer, target, f"{name}_gather_layer", source_ir)
+        gather_layer.axis = 2
+        return gather_layer.get_output(0)
+    else:
+        indices = np.triu_indices(shape[0], k=1)
+        return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+
+
+def tri_upper_indices(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    size_tensor: TRTTensor,
+) -> TRTTensor:
+    """
+    Return the indices of the upper-triangle part of a square matrix of the given size,
+    as an N-by-2 tensor, where the diagonal offset = 1. A single loop is used to
+    calculate the indices, as below.
+        x = 0, y = 0, y_start = 1
+        out_size = size * (size - 1) // 2
+        for _ in range(out_size):
+            y_out.append(y_start + y)
+            x_out.append(x)
+            y += 1
+            if (y_start + y) >= size:
+                x += 1
+                y_start += 1
+                y = 0
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network.
+        target (Target): Target of calling node.
+        source_ir (Optional[SourceIR]): SourceIR of calling converter.
+        name (str): Name of the calling layer.
+        size_tensor (TRTTensor): number of rows in the 2-D square matrix, as a scalar tensor.
+
+    Example:
+        if size_tensor is 4, it will return [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
+    """
+    constant_0 = create_constant(ctx, 0, f"{name}_zero", np.int32, 0)
+    constant_1 = create_constant(ctx, 1, f"{name}_one", np.int32, 0)
+    constant_2 = create_constant(ctx, 2, f"{name}_two", np.int32, 0)
+
+    size_minus_one = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_size_minus_one", size_tensor, constant_1
+    )
+
+    size_mult_prev = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_size_mult_prev", size_tensor, size_minus_one
+    )
+
+    num_loop = impl.elementwise.floor_divide(
+        ctx, target, source_ir, f"{name}_num_loop", size_mult_prev, constant_2
+    )
+
+    loop = ctx.net.add_loop()
+    loop.add_trip_limit(num_loop, trt.TripLimit.COUNT)
+
+    x_recurrence = loop.add_recurrence(constant_0)
+    set_layer_name(x_recurrence, target, f"{name}_x_recurrence", source_ir)
+    x_tensor = x_recurrence.get_output(0)
+
+    y_recurrence = loop.add_recurrence(constant_0)
+    set_layer_name(y_recurrence, target, f"{name}_y_recurrence", source_ir)
+    y_tensor = y_recurrence.get_output(0)
+
+    y_start_recurrence = loop.add_recurrence(constant_1)
+    set_layer_name(y_start_recurrence, target, f"{name}_y_start_recurrence", source_ir)
+    y_start_tensor = y_start_recurrence.get_output(0)
+
+    x_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_inc",
+        x_tensor,
+        constant_1,
+    )
+
+    y_current = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_current",
+        y_start_tensor,
+        y_tensor,
+    )
+
+    y_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_inc",
+        y_tensor,
+        constant_1,
+    )
+
+    next_y = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_next_y",
+        y_start_tensor,
+        y_inc,
+    )
+
+    y_start_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_start_inc",
+        y_start_tensor,
+        constant_1,
+    )
+    cond = ge(ctx, target, source_ir, f"{name}_cond", next_y, size_tensor)
+    x_output = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_output",
+        x_inc,
+        x_tensor,
+        cond,
+    )
+    x_recurrence.set_input(1, x_output)
+
+    y_start_current = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_start_current",
+        y_start_inc,
+        y_start_tensor,
+        cond,
+    )
+    y_start_recurrence.set_input(1, y_start_current)
+
+    y_val = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_val",
+        constant_0,
+        y_inc,
+        cond,
+    )
+    y_recurrence.set_input(1, y_val)
+
+    loop_output_x = loop.add_loop_output(x_tensor, trt.LoopOutput.CONCATENATE)
+    loop_output_y = loop.add_loop_output(y_current, trt.LoopOutput.CONCATENATE)
+    loop_output_x.set_input(1, num_loop)
+    loop_output_y.set_input(1, num_loop)
+
+    # Cat two N tensors into 2 x N. [0, 0, 0], [1, 2, 3] -> [[0, 0, 0], [1, 2, 3]]
+    x_index = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_x_index", loop_output_x.get_output(0), (1, -1)
+    )
+    y_index = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_y_index", loop_output_y.get_output(0), (1, -1)
+    )
+
+    x_y_tensor = cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_y_tensor",
+        [x_index, y_index],
+        0,
+    )
+
+    # Reshape 2 x N output to N x 2. [[0, 0, 0], [1, 2, 3]] -> [[0, 1], [0, 2], [0, 3]]
+    indices_tensor = ctx.net.add_shuffle(x_y_tensor)
+    set_layer_name(indices_tensor, target, f"{name}_indices_tensor", source_ir)
+    indices_tensor.first_transpose = trt.Permutation([1, 0])
+    indices_tensor.reshape_dims = (-1, 2)
+
+    return indices_tensor.get_output(0)
 
 
 def cdist_forward(
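The single-loop scheme from the tri_upper_indices docstring is easier to sanity-check outside of TensorRT. Below is a pure-Python mirror of it (a hypothetical reference helper, not part of the commit), validated against numpy's triu_indices:

import numpy as np

def tri_upper_indices_ref(size: int) -> np.ndarray:
    """Pure-Python mirror of the loop described in the tri_upper_indices docstring."""
    x, y, y_start = 0, 0, 1
    x_out, y_out = [], []
    for _ in range(size * (size - 1) // 2):
        y_out.append(y_start + y)
        x_out.append(x)
        y += 1
        if y_start + y >= size:
            x += 1
            y_start += 1
            y = 0
    return np.stack([x_out, y_out], axis=1)  # N*(N-1)//2 rows of [row, col] pairs

assert (tri_upper_indices_ref(4) == np.stack(np.triu_indices(4, k=1), axis=1)).all()

These [row, col] pairs feed the ND-mode gather, which selects each upper-triangle entry of the N x N norm matrix, roughly the NumPy equivalent of norm[idx[:, 0], idx[:, 1]].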

tests/py/dynamo/conversion/test_pdist_aten.py

+58 lines
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -32,5 +33,62 @@ def forward(self, input):
         )
 
 
+class TestDynamicShapePdistConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "dim0_dynamic_dim1_static_p_0",
+                (1, 4),
+                (2, 4),
+                (4, 4),
+                0,
+            ),
+            (
+                "dim0_static_dim1_dynamic_p_1",
+                (3, 1),
+                (3, 2),
+                (3, 4),
+                1,
+            ),
+            (
+                "dim0_dynamic_dim1_static_p_other",
+                (1, 5),
+                (2, 5),
+                (6, 5),
+                0.4,
+            ),
+            (
+                "dim0_dynamic_dim1_dynamic_p_inf",
+                (1, 1),
+                (2, 2),
+                (5, 4),
+                float("inf"),
+            ),
+            (
+                "dim0_dynamic_dim1_dynamic_p_other",
+                (2, 1),
+                (3, 2),
+                (4, 7),
+                1.7,
+            ),
+        ]
+    )
+    def test_pdist_float(self, _, min_shape, opt_shape, max_shape, p):
+        class Pdist(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten._pdist_forward.default(input, p)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(Pdist(), input_specs)
+
+
 if __name__ == "__main__":
     run_tests()
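Beyond the unit tests, the dynamic-shape path can be exercised end to end roughly as below. This is an illustrative sketch: the shapes, module, and ir="dynamo" setting are assumptions, not taken from the commit.

import torch
import torch_tensorrt

class Pdist(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten._pdist_forward.default(x, 2.0)

# One engine covers any batch size N in [2, 32] thanks to the dynamic first dim.
spec = torch_tensorrt.Input(
    min_shape=(2, 8), opt_shape=(16, 8), max_shape=(32, 8), dtype=torch.float
)
trt_mod = torch_tensorrt.compile(Pdist().eval().cuda(), ir="dynamo", inputs=[spec])
print(trt_mod(torch.randn(10, 8).cuda()).shape)  # torch.Size([45]) == 10 * 9 // 2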
