@@ -10,15 +10,18 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    create_constant,
     get_axes_for_reduce_op,
     get_positive_dim,
     get_trt_tensor,
-    to_numpy,
-)
-from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
+    to_numpy,
 )
+from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims
 
@@ -417,20 +420,21 @@ def pdist(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     shape = input.shape
     # Extend input from shape [N, D] to [N, 1, D]
-    extend_input = impl.shuffle.reshape(
+    extend_input = impl.unsqueeze.unsqueeze(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape",
+        f"{name}_unsqueeze",
         input,
-        shape=shape[0:1] + (1,) + shape[1:],
+        1,
     )
+
     # Expand the input from [N, 1, D] to [N, N, D]
     x = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_sub",
+        f"{name}_expand",
         extend_input,
         (shape[0], shape[0]) + shape[1:],
     )
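Note on the hunk above: `unsqueeze` plus `expand` is the usual broadcasting trick for pairwise distances. A minimal plain-PyTorch sketch of the same pattern (illustrative only; the tensor names here are made up, not part of the converter):

```python
import torch

# [N, D] -> [N, 1, D]; broadcasting against [1, N, D] yields all pairwise
# differences with shape [N, N, D], mirroring the unsqueeze + expand above.
a = torch.randn(4, 3)
diff = a.unsqueeze(1) - a.unsqueeze(0)      # (N, N, D)
norm = diff.norm(p=2, dim=-1)               # dense (N, N) distance matrix
i, j = torch.triu_indices(4, 4, offset=1)   # upper-triangular pairs, k=1
assert torch.allclose(norm[i, j], torch.pdist(a, p=2))
```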
@@ -482,8 +486,194 @@ def pdist(
         raise RuntimeError(
             f"p should be between [0, inf], currently p={p} is not supported!"
         )
-    indices = np.triu_indices(shape[0], k=1)
-    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+    if shape[0] == DYNAMIC_DIM:
+        dim = get_shape(ctx, target, source_ir, f"{name}_get_shape", input, 0)
+        shuffle_layer = ctx.net.add_shuffle(dim)
+        shuffle_layer.reshape_dims = trt.Dims()
+        set_layer_name(shuffle_layer, target, f"{name}_shuffle", source_ir)
+        dim_tensor = shuffle_layer.get_output(0)
+        indices_tensor = tri_upper_indices(
+            ctx, target, source_ir, f"{name}_triu_indices", dim_tensor
+        )
+        gather_layer = ctx.net.add_gather_v2(
+            norm, indices_tensor, mode=trt.GatherMode.ND
+        )
+        set_layer_name(gather_layer, target, f"{name}_gather_layer", source_ir)
+        gather_layer.axis = 2
+        return gather_layer.get_output(0)
+    else:
+        indices = np.triu_indices(shape[0], k=1)
+        return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+
+
+def tri_upper_indices(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    size_tensor: TRTTensor,
+) -> TRTTensor:
+    """
+    Return the indices of the upper-triangular part of a square matrix as an N-by-2 tensor,
+    where the diagonal offset = 1. A single loop is used to calculate the indices, like below:
+        x = 0, y = 0, y_start = 1
+        out_size = size * (size - 1) // 2
+        for _ in range(out_size):
+            y_out.append(y_start + y)
+            x_out.append(x)
+            y += 1
+            if (y_start + y) >= size:
+                x += 1
+                y_start += 1
+                y = 0
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network.
+        target (Target): Target of calling node.
+        source_ir (Optional[SourceIR]): SourceIR of calling converter.
+        name (str): Name of the calling layer.
+        size_tensor (TRTTensor): Scalar tensor giving the number of rows of the 2-D square matrix.
+
+    Example:
+        If size_tensor is 4, the result is [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]].
+    """
+    constant_0 = create_constant(ctx, 0, f"{name}_zero", np.int32, 0)
+    constant_1 = create_constant(ctx, 1, f"{name}_one", np.int32, 0)
+    constant_2 = create_constant(ctx, 2, f"{name}_two", np.int32, 0)
+
+    size_minus_one = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_size_minus_one", size_tensor, constant_1
+    )
+
+    size_mult_prev = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_size_mult_prev", size_tensor, size_minus_one
+    )
+
+    num_loop = impl.elementwise.floor_divide(
+        ctx, target, source_ir, f"{name}_num_loop", size_mult_prev, constant_2
+    )
+
+    loop = ctx.net.add_loop()
+    loop.add_trip_limit(num_loop, trt.TripLimit.COUNT)
+
+    x_recurrence = loop.add_recurrence(constant_0)
+    set_layer_name(x_recurrence, target, f"{name}_x_recurrence", source_ir)
+    x_tensor = x_recurrence.get_output(0)
+
+    y_recurrence = loop.add_recurrence(constant_0)
+    set_layer_name(y_recurrence, target, f"{name}_y_recurrence", source_ir)
+    y_tensor = y_recurrence.get_output(0)
+
+    y_start_recurrence = loop.add_recurrence(constant_1)
+    set_layer_name(y_start_recurrence, target, f"{name}_y_start_recurrence", source_ir)
+    y_start_tensor = y_start_recurrence.get_output(0)
+
+    x_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_inc",
+        x_tensor,
+        constant_1,
+    )
+
+    y_current = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_current",
+        y_start_tensor,
+        y_tensor,
+    )
+
+    y_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_inc",
+        y_tensor,
+        constant_1,
+    )
+
+    next_y = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_next_y",
+        y_start_tensor,
+        y_inc,
+    )
+
+    y_start_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_start_inc",
+        y_start_tensor,
+        constant_1,
+    )
+    cond = ge(ctx, target, source_ir, f"{name}_cond", next_y, size_tensor)
+    x_output = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_output",
+        x_inc,
+        x_tensor,
+        cond,
+    )
+    x_recurrence.set_input(1, x_output)
+
+    y_start_current = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_start_current",
+        y_start_inc,
+        y_start_tensor,
+        cond,
+    )
+    y_start_recurrence.set_input(1, y_start_current)
+
+    y_val = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_val",
+        constant_0,
+        y_inc,
+        cond,
+    )
+    y_recurrence.set_input(1, y_val)
+
+    loop_output_x = loop.add_loop_output(x_tensor, trt.LoopOutput.CONCATENATE)
+    loop_output_y = loop.add_loop_output(y_current, trt.LoopOutput.CONCATENATE)
+    loop_output_x.set_input(1, num_loop)
+    loop_output_y.set_input(1, num_loop)
+
+    # Cat two N tensors into 2 x N. [0, 0, 0], [1, 2, 3] -> [[0, 0, 0], [1, 2, 3]]
+    x_index = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_x_index", loop_output_x.get_output(0), (1, -1)
+    )
+    y_index = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_y_index", loop_output_y.get_output(0), (1, -1)
+    )
+
+    x_y_tensor = cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_y_tensor",
+        [x_index, y_index],
+        0,
+    )
+
+    # Reshape 2 x N output to N x 2. [[0, 0, 0], [1, 2, 3]] -> [[0, 1], [0, 2], [0, 3]]
+    indices_tensor = ctx.net.add_shuffle(x_y_tensor)
+    set_layer_name(indices_tensor, target, f"{name}_indices_tensor", source_ir)
+    indices_tensor.first_transpose = trt.Permutation([1, 0])
+    indices_tensor.reshape_dims = (-1, 2)
+
+    return indices_tensor.get_output(0)
 
 
 def cdist_forward(
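For reference, the index-generation loop that `tri_upper_indices` builds out of TensorRT recurrence layers corresponds to this plain-Python sketch (the helper name `tri_upper_indices_ref` is made up here, illustrative only); it agrees with `np.triu_indices(n, k=1)`:

```python
import numpy as np

def tri_upper_indices_ref(size: int) -> np.ndarray:
    # Same loop as in the docstring above, run eagerly in Python.
    x, y, y_start = 0, 0, 1
    x_out, y_out = [], []
    for _ in range(size * (size - 1) // 2):
        y_out.append(y_start + y)
        x_out.append(x)
        y += 1
        if (y_start + y) >= size:  # finished a row of the upper triangle
            x += 1
            y_start += 1
            y = 0
    return np.stack([x_out, y_out], axis=1)  # shape (M, 2)

rows, cols = np.triu_indices(4, k=1)
assert (tri_upper_indices_ref(4) == np.stack([rows, cols], axis=1)).all()
```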
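The dynamic-shape branch in `pdist` then feeds that N-by-2 index tensor to an ND-mode gather to pull the upper-triangular entries out of the dense distance matrix. A NumPy emulation of the intended result (illustrative sketch, not the TensorRT API):

```python
import numpy as np

# norm plays the role of the dense (N, N) matrix of pairwise p-norms;
# indices is the (M, 2) output of tri_upper_indices, M = N * (N - 1) // 2.
n = 4
norm = np.random.rand(n, n).astype(np.float32)
indices = np.stack(np.triu_indices(n, k=1), axis=1)
pdist_out = norm[indices[:, 0], indices[:, 1]]  # flattened upper triangle, shape (M,)
```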