@@ -426,6 +426,342 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
426
426
}
427
427
return failure ();
428
428
});
429
+ patterns.onOp (
430
+ " Conv" , 11 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
431
+ std::string autoPad;
432
+ if (binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ))
433
+ return failure ();
434
+ if (autoPad != " NOTSET" ) {
435
+ // TODO: Add support for `auto_pad` != "NOTSET"
436
+ return rewriter.notifyMatchFailure (
437
+ binder.op , " unsupported conversion: auto_pad != NOTSET" );
438
+ }
439
+
440
+ Torch::ValueTensorType resultType;
441
+ Value input, weight;
442
+ int64_t group;
443
+ if (binder.tensorOperands (input, weight) ||
444
+ binder.s64IntegerAttr (group, " group" , 1 ) ||
445
+ binder.tensorResultType (resultType))
446
+ return failure ();
447
+
448
+ auto weightTensorType = weight.getType ().cast <Torch::ValueTensorType>();
449
+ if (!weightTensorType || !weightTensorType.hasSizes ()) {
450
+ return rewriter.notifyMatchFailure (
451
+ binder.op , " Expected weight type having sizes" );
452
+ }
453
+ ArrayRef<int64_t > weightShape = weightTensorType.getSizes ();
454
+ SmallVector<int64_t > kernelShape;
455
+ if (binder.s64IntegerArrayAttr (kernelShape, " kernel_shape" , {}))
456
+ return failure ();
457
+ if (kernelShape.size ()) {
458
+ if (kernelShape.size () != weightShape.size () - 2 ) {
459
+ return rewriter.notifyMatchFailure (
460
+ binder.op ,
461
+ " unsupported conversion: kernel_shape list size should have "
462
+ " number of values equal to weight_rank - 2" );
463
+ } else {
464
+ for (unsigned i = 0 ; i < kernelShape.size (); i++) {
465
+ if (weightShape[i + 2 ] != kernelShape[i]) {
466
+ return rewriter.notifyMatchFailure (
467
+ binder.op , " unsupported conversion: kernel_shape value "
468
+ " should be equal to the weight tensor shape" );
469
+ }
470
+ }
471
+ }
472
+ }
473
+
474
+ // Determine the rank of input tensor.
475
+ std::optional<unsigned > maybeRank = Torch::getTensorRank (input);
476
+ if (!maybeRank)
477
+ return rewriter.notifyMatchFailure (binder.op ,
478
+ " Unimplemented: unranked tensor" );
479
+ unsigned rank = *maybeRank;
480
+
481
+ SmallVector<int64_t > padding, strides, dilations;
482
+ SmallVector<int64_t > defaultPadding, defaultStrides, defaultDilations;
483
+ for (unsigned i = 0 ; i < rank - 2 ; i++) {
484
+ defaultPadding.push_back (0 );
485
+ defaultStrides.push_back (1 );
486
+ defaultDilations.push_back (1 );
487
+ }
488
+ // Padding for the beginning and ending along each spatial axis, it can
489
+ // take any value greater than or equal to 0. The value represent the
490
+ // number of pixels added to the beginning and end part of the
491
+ // corresponding axis. pads format should be as follow [x1_begin,
492
+ // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
493
+ // at the beginning of axis i and xi_end, the number of pixels added at
494
+ // the end of axis i.
495
+ if (binder.s64IntegerArrayAttr (padding, " pads" , defaultPadding)) {
496
+ return failure ();
497
+ }
498
+ if (padding.size () != rank - 2 && padding.size () != 2 * (rank - 2 )) {
499
+ return rewriter.notifyMatchFailure (
500
+ binder.op , " padding list size does not match the number of axes" );
501
+ }
502
+ if (binder.s64IntegerArrayAttr (dilations, " dilations" ,
503
+ defaultDilations)) {
504
+ return failure ();
505
+ }
506
+ if (dilations.size () != rank - 2 ) {
507
+ return rewriter.notifyMatchFailure (
508
+ binder.op ,
509
+ " dilations list size does not match the number of axes" );
510
+ }
511
+ if (binder.s64IntegerArrayAttr (strides, " strides" , defaultStrides)) {
512
+ return failure ();
513
+ }
514
+ if (strides.size () != rank - 2 ) {
515
+ return rewriter.notifyMatchFailure (
516
+ binder.op , " strides list size does not match the number of axes" );
517
+ }
518
+
519
+ SmallVector<Value> cstPadding, cstStrides, cstDilations,
520
+ cstOutputPadding;
521
+ if (padding.size () != 2 * (rank - 2 )) {
522
+ for (int64_t i : padding) {
523
+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
524
+ binder.getLoc (), rewriter.getI64IntegerAttr (i)));
525
+ }
526
+ } else {
527
+ for (unsigned i = 0 ; i < padding.size () / 2 ; i++) {
528
+ if (padding[i] != padding[i + (padding.size () / 2 )]) {
529
+ // TODO: Add support for different padding values for the
530
+ // beginning and ending along each spatial axis
531
+ return rewriter.notifyMatchFailure (
532
+ binder.op ,
533
+ " unsupported conversion: padding values for the beginning "
534
+ " and ending along each spatial axis must be equal" );
535
+ }
536
+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
537
+ binder.getLoc (), rewriter.getI64IntegerAttr (padding[i])));
538
+ }
539
+ }
540
+ for (int64_t i : dilations) {
541
+ cstDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
542
+ binder.getLoc (), rewriter.getI64IntegerAttr (i)));
543
+ }
544
+ for (int64_t i : strides) {
545
+ cstStrides.push_back (rewriter.create <Torch::ConstantIntOp>(
546
+ binder.getLoc (), rewriter.getI64IntegerAttr (i)));
547
+ }
548
+ Value cstZero = rewriter.create <Torch::ConstantIntOp>(
549
+ binder.getLoc (), rewriter.getI64IntegerAttr (0 ));
550
+ cstOutputPadding = {cstZero, cstZero};
551
+
552
+ Value paddingList = rewriter.create <Torch::PrimListConstructOp>(
553
+ binder.getLoc (),
554
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
555
+ cstPadding);
556
+ Value dilationsList = rewriter.create <Torch::PrimListConstructOp>(
557
+ binder.getLoc (),
558
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
559
+ cstDilations);
560
+ Value stridesList = rewriter.create <Torch::PrimListConstructOp>(
561
+ binder.getLoc (),
562
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
563
+ cstStrides);
564
+ Value outputPaddingList = rewriter.create <Torch::PrimListConstructOp>(
565
+ binder.getLoc (),
566
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
567
+ cstOutputPadding);
568
+ Value transposed =
569
+ rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), false );
570
+ Value bias;
571
+ if (binder.op ->getNumOperands () == 3 ) {
572
+ if (binder.tensorOperandAtIndex (bias, 2 )) {
573
+ return failure ();
574
+ }
575
+ } else {
576
+ bias = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
577
+ }
578
+ Value cstGroup = rewriter.create <Torch::ConstantIntOp>(
579
+ binder.getLoc (), rewriter.getI64IntegerAttr (group));
580
+
581
+ rewriter.replaceOpWithNewOp <Torch::AtenConvolutionOp>(
582
+ binder.op , resultType, input, weight, bias, stridesList,
583
+ paddingList, dilationsList, transposed, outputPaddingList,
584
+ cstGroup);
585
+ return success ();
586
+ });
587
+ patterns.onOp (
588
+ " ConvTranspose" , 11 ,
589
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
590
+ std::string autoPad;
591
+ if (binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ))
592
+ return failure ();
593
+ if (autoPad != " NOTSET" ) {
594
+ // TODO: Add support for `auto_pad` != "NOTSET"
595
+ return rewriter.notifyMatchFailure (
596
+ binder.op , " unsupported conversion: auto_pad != NOTSET" );
597
+ }
598
+ SmallVector<int64_t > outputShape;
599
+ if (binder.s64IntegerArrayAttr (outputShape, " output_shape" , {}))
600
+ return failure ();
601
+ if (outputShape.size ()) {
602
+ // TODO: Add support for non-None output_shape value.
603
+ return rewriter.notifyMatchFailure (
604
+ binder.op ,
605
+ " unsupported conversion: output_shape should be absent" );
606
+ }
607
+ Torch::ValueTensorType resultType;
608
+ Value input, weight;
609
+ int64_t group;
610
+ if (binder.tensorOperands (input, weight) ||
611
+ binder.s64IntegerAttr (group, " group" , 1 ) ||
612
+ binder.tensorResultType (resultType))
613
+ return failure ();
614
+
615
+ auto weightTensorType = weight.getType ().cast <Torch::ValueTensorType>();
616
+ if (!weightTensorType || !weightTensorType.hasSizes ()) {
617
+ return rewriter.notifyMatchFailure (
618
+ binder.op , " Expected weight type having sizes" );
619
+ }
620
+ ArrayRef<int64_t > weightShape = weightTensorType.getSizes ();
621
+ SmallVector<int64_t > kernelShape;
622
+ if (binder.s64IntegerArrayAttr (kernelShape, " kernel_shape" , {}))
623
+ return failure ();
624
+ if (kernelShape.size ()) {
625
+ if (kernelShape.size () != weightShape.size () - 2 ) {
626
+ return rewriter.notifyMatchFailure (
627
+ binder.op ,
628
+ " unsupported conversion: kernel_shape list size should have "
629
+ " number of values equal to weight_rank - 2" );
630
+ } else {
631
+ for (unsigned i = 0 ; i < kernelShape.size (); i++) {
632
+ if (weightShape[i + 2 ] != kernelShape[i]) {
633
+ return rewriter.notifyMatchFailure (
634
+ binder.op , " unsupported conversion: kernel_shape value "
635
+ " should be equal to the weight tensor shape" );
636
+ }
637
+ }
638
+ }
639
+ }
640
+
641
+ // Determine the rank of input tensor.
642
+ std::optional<unsigned > maybeRank = Torch::getTensorRank (input);
643
+ if (!maybeRank)
644
+ return rewriter.notifyMatchFailure (binder.op ,
645
+ " Unimplemented: unranked tensor" );
646
+ unsigned rank = *maybeRank;
647
+
648
+ SmallVector<int64_t > padding, strides, dilations, outputPadding;
649
+ SmallVector<int64_t > defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
650
+ for (unsigned i = 0 ; i < rank - 2 ; i++) {
651
+ defaultPadding.push_back (0 );
652
+ defaultStrides.push_back (1 );
653
+ defaultDilations.push_back (1 );
654
+ defaultOutputPadding.push_back (0 );
655
+ }
656
+ // Padding for the beginning and ending along each spatial axis, it can
657
+ // take any value greater than or equal to 0. The value represent the
658
+ // number of pixels added to the beginning and end part of the
659
+ // corresponding axis. pads format should be as follow [x1_begin,
660
+ // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
661
+ // at the beginning of axis i and xi_end, the number of pixels added at
662
+ // the end of axis i.
663
+ if (binder.s64IntegerArrayAttr (padding, " pads" , defaultPadding)) {
664
+ return failure ();
665
+ }
666
+ if (padding.size () != rank - 2 && padding.size () != 2 * (rank - 2 )) {
667
+ return rewriter.notifyMatchFailure (
668
+ binder.op , " padding list size does not match the number of axes" );
669
+ }
670
+ if (binder.s64IntegerArrayAttr (dilations, " dilations" ,
671
+ defaultDilations)) {
672
+ return failure ();
673
+ }
674
+ if (dilations.size () != rank - 2 ) {
675
+ return rewriter.notifyMatchFailure (
676
+ binder.op ,
677
+ " dilations list size does not match the number of axes" );
678
+ }
679
+ if (binder.s64IntegerArrayAttr (strides, " strides" , defaultStrides)) {
680
+ return failure ();
681
+ }
682
+ if (strides.size () != rank - 2 ) {
683
+ return rewriter.notifyMatchFailure (
684
+ binder.op , " strides list size does not match the number of axes" );
685
+ }
686
+ if (binder.s64IntegerArrayAttr (outputPadding, " output_padding" ,
687
+ defaultOutputPadding)) {
688
+ return failure ();
689
+ }
690
+ if (outputPadding.size () != rank - 2 ) {
691
+ return rewriter.notifyMatchFailure (
692
+ binder.op ,
693
+ " output_padding list size does not match the number of axes" );
694
+ }
695
+
696
+ SmallVector<Value> cstPadding, cstStrides, cstDilations,
697
+ cstOutputPadding;
698
+ if (padding.size () != 2 * (rank - 2 )) {
699
+ for (int64_t i : padding) {
700
+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
701
+ binder.getLoc (), rewriter.getI64IntegerAttr (i)));
702
+ }
703
+ } else {
704
+ for (unsigned i = 0 ; i < padding.size () / 2 ; i++) {
705
+ if (padding[i] != padding[i + (padding.size () / 2 )]) {
706
+ // TODO: Add support for different padding values for the
707
+ // beginning and ending along each spatial axis
708
+ return rewriter.notifyMatchFailure (
709
+ binder.op ,
710
+ " unsupported conversion: padding values for the beginning "
711
+ " and ending along each spatial axis must be equal" );
712
+ }
713
+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
714
+ binder.getLoc (), rewriter.getI64IntegerAttr (padding[i])));
715
+ }
716
+ }
717
+ for (int64_t i : dilations) {
718
+ cstDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
719
+ binder.getLoc (), rewriter.getI64IntegerAttr (i)));
720
+ }
721
+ for (int64_t i : strides) {
722
+ cstStrides.push_back (rewriter.create <Torch::ConstantIntOp>(
723
+ binder.getLoc (), rewriter.getI64IntegerAttr (i)));
724
+ }
725
+ for (int64_t i : outputPadding) {
726
+ cstOutputPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
727
+ binder.getLoc (), rewriter.getI64IntegerAttr (i)));
728
+ }
729
+
730
+ Value paddingList = rewriter.create <Torch::PrimListConstructOp>(
731
+ binder.getLoc (),
732
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
733
+ cstPadding);
734
+ Value dilationsList = rewriter.create <Torch::PrimListConstructOp>(
735
+ binder.getLoc (),
736
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
737
+ cstDilations);
738
+ Value stridesList = rewriter.create <Torch::PrimListConstructOp>(
739
+ binder.getLoc (),
740
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
741
+ cstStrides);
742
+ Value outputPaddingList = rewriter.create <Torch::PrimListConstructOp>(
743
+ binder.getLoc (),
744
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
745
+ cstOutputPadding);
746
+ Value transposed =
747
+ rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), true );
748
+ Value bias;
749
+ if (binder.op ->getNumOperands () == 3 ) {
750
+ if (binder.tensorOperandAtIndex (bias, 2 )) {
751
+ return failure ();
752
+ }
753
+ } else {
754
+ bias = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
755
+ }
756
+ Value cstGroup = rewriter.create <Torch::ConstantIntOp>(
757
+ binder.getLoc (), rewriter.getI64IntegerAttr (group));
758
+
759
+ rewriter.replaceOpWithNewOp <Torch::AtenConvolutionOp>(
760
+ binder.op , resultType, input, weight, bias, stridesList,
761
+ paddingList, dilationsList, transposed, outputPaddingList,
762
+ cstGroup);
763
+ return success ();
764
+ });
429
765
patterns.onOp (" Cos" , 7 ,
430
766
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
431
767
Torch::ValueTensorType resultType;
0 commit comments