Skip to content

Commit 3226241

Browse files
[MLIR][ONNX] Add OnnxToTorch support for Conv and ConvTranspose op.
This commit adds the OnnxToTorch support for Conv and ConvTranspose op. Signed-Off By: [email protected]
1 parent d75cff6 commit 3226241

File tree

2 files changed

+466
-0
lines changed

2 files changed

+466
-0
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,342 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
426426
}
427427
return failure();
428428
});
429+
patterns.onOp(
430+
"Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
431+
std::string autoPad;
432+
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
433+
return failure();
434+
if (autoPad != "NOTSET") {
435+
// TODO: Add support for `auto_pad` != "NOTSET"
436+
return rewriter.notifyMatchFailure(
437+
binder.op, "unsupported conversion: auto_pad != NOTSET");
438+
}
439+
440+
Torch::ValueTensorType resultType;
441+
Value input, weight;
442+
int64_t group;
443+
if (binder.tensorOperands(input, weight) ||
444+
binder.s64IntegerAttr(group, "group", 1) ||
445+
binder.tensorResultType(resultType))
446+
return failure();
447+
448+
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
449+
if (!weightTensorType || !weightTensorType.hasSizes()) {
450+
return rewriter.notifyMatchFailure(
451+
binder.op, "Expected weight type having sizes");
452+
}
453+
ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
454+
SmallVector<int64_t> kernelShape;
455+
if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
456+
return failure();
457+
if (kernelShape.size()) {
458+
if (kernelShape.size() != weightShape.size() - 2) {
459+
return rewriter.notifyMatchFailure(
460+
binder.op,
461+
"unsupported conversion: kernel_shape list size should have "
462+
"number of values equal to weight_rank - 2");
463+
} else {
464+
for (unsigned i = 0; i < kernelShape.size(); i++) {
465+
if (weightShape[i + 2] != kernelShape[i]) {
466+
return rewriter.notifyMatchFailure(
467+
binder.op, "unsupported conversion: kernel_shape value "
468+
"should be equal to the weight tensor shape");
469+
}
470+
}
471+
}
472+
}
473+
474+
// Determine the rank of input tensor.
475+
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
476+
if (!maybeRank)
477+
return rewriter.notifyMatchFailure(binder.op,
478+
"Unimplemented: unranked tensor");
479+
unsigned rank = *maybeRank;
480+
481+
SmallVector<int64_t> padding, strides, dilations;
482+
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations;
483+
for (unsigned i = 0; i < rank - 2; i++) {
484+
defaultPadding.push_back(0);
485+
defaultStrides.push_back(1);
486+
defaultDilations.push_back(1);
487+
}
488+
// Padding for the beginning and ending along each spatial axis, it can
489+
// take any value greater than or equal to 0. The value represent the
490+
// number of pixels added to the beginning and end part of the
491+
// corresponding axis. pads format should be as follow [x1_begin,
492+
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
493+
// at the beginning of axis i and xi_end, the number of pixels added at
494+
// the end of axis i.
495+
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
496+
return failure();
497+
}
498+
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
499+
return rewriter.notifyMatchFailure(
500+
binder.op, "padding list size does not match the number of axes");
501+
}
502+
if (binder.s64IntegerArrayAttr(dilations, "dilations",
503+
defaultDilations)) {
504+
return failure();
505+
}
506+
if (dilations.size() != rank - 2) {
507+
return rewriter.notifyMatchFailure(
508+
binder.op,
509+
"dilations list size does not match the number of axes");
510+
}
511+
if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
512+
return failure();
513+
}
514+
if (strides.size() != rank - 2) {
515+
return rewriter.notifyMatchFailure(
516+
binder.op, "strides list size does not match the number of axes");
517+
}
518+
519+
SmallVector<Value> cstPadding, cstStrides, cstDilations,
520+
cstOutputPadding;
521+
if (padding.size() != 2 * (rank - 2)) {
522+
for (int64_t i : padding) {
523+
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
524+
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
525+
}
526+
} else {
527+
for (unsigned i = 0; i < padding.size() / 2; i++) {
528+
if (padding[i] != padding[i + (padding.size() / 2)]) {
529+
// TODO: Add support for different padding values for the
530+
// beginning and ending along each spatial axis
531+
return rewriter.notifyMatchFailure(
532+
binder.op,
533+
"unsupported conversion: padding values for the beginning "
534+
"and ending along each spatial axis must be equal");
535+
}
536+
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
537+
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
538+
}
539+
}
540+
for (int64_t i : dilations) {
541+
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
542+
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
543+
}
544+
for (int64_t i : strides) {
545+
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
546+
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
547+
}
548+
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
549+
binder.getLoc(), rewriter.getI64IntegerAttr(0));
550+
cstOutputPadding = {cstZero, cstZero};
551+
552+
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
553+
binder.getLoc(),
554+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
555+
cstPadding);
556+
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
557+
binder.getLoc(),
558+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
559+
cstDilations);
560+
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
561+
binder.getLoc(),
562+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
563+
cstStrides);
564+
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
565+
binder.getLoc(),
566+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
567+
cstOutputPadding);
568+
Value transposed =
569+
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
570+
Value bias;
571+
if (binder.op->getNumOperands() == 3) {
572+
if (binder.tensorOperandAtIndex(bias, 2)) {
573+
return failure();
574+
}
575+
} else {
576+
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
577+
}
578+
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
579+
binder.getLoc(), rewriter.getI64IntegerAttr(group));
580+
581+
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
582+
binder.op, resultType, input, weight, bias, stridesList,
583+
paddingList, dilationsList, transposed, outputPaddingList,
584+
cstGroup);
585+
return success();
586+
});
587+
patterns.onOp(
588+
"ConvTranspose", 11,
589+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
590+
std::string autoPad;
591+
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
592+
return failure();
593+
if (autoPad != "NOTSET") {
594+
// TODO: Add support for `auto_pad` != "NOTSET"
595+
return rewriter.notifyMatchFailure(
596+
binder.op, "unsupported conversion: auto_pad != NOTSET");
597+
}
598+
SmallVector<int64_t> outputShape;
599+
if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {}))
600+
return failure();
601+
if (outputShape.size()) {
602+
// TODO: Add support for non-None output_shape value.
603+
return rewriter.notifyMatchFailure(
604+
binder.op,
605+
"unsupported conversion: output_shape should be absent");
606+
}
607+
Torch::ValueTensorType resultType;
608+
Value input, weight;
609+
int64_t group;
610+
if (binder.tensorOperands(input, weight) ||
611+
binder.s64IntegerAttr(group, "group", 1) ||
612+
binder.tensorResultType(resultType))
613+
return failure();
614+
615+
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
616+
if (!weightTensorType || !weightTensorType.hasSizes()) {
617+
return rewriter.notifyMatchFailure(
618+
binder.op, "Expected weight type having sizes");
619+
}
620+
ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
621+
SmallVector<int64_t> kernelShape;
622+
if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
623+
return failure();
624+
if (kernelShape.size()) {
625+
if (kernelShape.size() != weightShape.size() - 2) {
626+
return rewriter.notifyMatchFailure(
627+
binder.op,
628+
"unsupported conversion: kernel_shape list size should have "
629+
"number of values equal to weight_rank - 2");
630+
} else {
631+
for (unsigned i = 0; i < kernelShape.size(); i++) {
632+
if (weightShape[i + 2] != kernelShape[i]) {
633+
return rewriter.notifyMatchFailure(
634+
binder.op, "unsupported conversion: kernel_shape value "
635+
"should be equal to the weight tensor shape");
636+
}
637+
}
638+
}
639+
}
640+
641+
// Determine the rank of input tensor.
642+
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
643+
if (!maybeRank)
644+
return rewriter.notifyMatchFailure(binder.op,
645+
"Unimplemented: unranked tensor");
646+
unsigned rank = *maybeRank;
647+
648+
SmallVector<int64_t> padding, strides, dilations, outputPadding;
649+
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
650+
for (unsigned i = 0; i < rank - 2; i++) {
651+
defaultPadding.push_back(0);
652+
defaultStrides.push_back(1);
653+
defaultDilations.push_back(1);
654+
defaultOutputPadding.push_back(0);
655+
}
656+
// Padding for the beginning and ending along each spatial axis, it can
657+
// take any value greater than or equal to 0. The value represent the
658+
// number of pixels added to the beginning and end part of the
659+
// corresponding axis. pads format should be as follow [x1_begin,
660+
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
661+
// at the beginning of axis i and xi_end, the number of pixels added at
662+
// the end of axis i.
663+
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
664+
return failure();
665+
}
666+
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
667+
return rewriter.notifyMatchFailure(
668+
binder.op, "padding list size does not match the number of axes");
669+
}
670+
if (binder.s64IntegerArrayAttr(dilations, "dilations",
671+
defaultDilations)) {
672+
return failure();
673+
}
674+
if (dilations.size() != rank - 2) {
675+
return rewriter.notifyMatchFailure(
676+
binder.op,
677+
"dilations list size does not match the number of axes");
678+
}
679+
if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
680+
return failure();
681+
}
682+
if (strides.size() != rank - 2) {
683+
return rewriter.notifyMatchFailure(
684+
binder.op, "strides list size does not match the number of axes");
685+
}
686+
if (binder.s64IntegerArrayAttr(outputPadding, "output_padding",
687+
defaultOutputPadding)) {
688+
return failure();
689+
}
690+
if (outputPadding.size() != rank - 2) {
691+
return rewriter.notifyMatchFailure(
692+
binder.op,
693+
"output_padding list size does not match the number of axes");
694+
}
695+
696+
SmallVector<Value> cstPadding, cstStrides, cstDilations,
697+
cstOutputPadding;
698+
if (padding.size() != 2 * (rank - 2)) {
699+
for (int64_t i : padding) {
700+
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
701+
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
702+
}
703+
} else {
704+
for (unsigned i = 0; i < padding.size() / 2; i++) {
705+
if (padding[i] != padding[i + (padding.size() / 2)]) {
706+
// TODO: Add support for different padding values for the
707+
// beginning and ending along each spatial axis
708+
return rewriter.notifyMatchFailure(
709+
binder.op,
710+
"unsupported conversion: padding values for the beginning "
711+
"and ending along each spatial axis must be equal");
712+
}
713+
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
714+
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
715+
}
716+
}
717+
for (int64_t i : dilations) {
718+
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
719+
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
720+
}
721+
for (int64_t i : strides) {
722+
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
723+
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
724+
}
725+
for (int64_t i : outputPadding) {
726+
cstOutputPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
727+
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
728+
}
729+
730+
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
731+
binder.getLoc(),
732+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
733+
cstPadding);
734+
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
735+
binder.getLoc(),
736+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
737+
cstDilations);
738+
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
739+
binder.getLoc(),
740+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
741+
cstStrides);
742+
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
743+
binder.getLoc(),
744+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
745+
cstOutputPadding);
746+
Value transposed =
747+
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
748+
Value bias;
749+
if (binder.op->getNumOperands() == 3) {
750+
if (binder.tensorOperandAtIndex(bias, 2)) {
751+
return failure();
752+
}
753+
} else {
754+
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
755+
}
756+
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
757+
binder.getLoc(), rewriter.getI64IntegerAttr(group));
758+
759+
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
760+
binder.op, resultType, input, weight, bias, stridesList,
761+
paddingList, dilationsList, transposed, outputPaddingList,
762+
cstGroup);
763+
return success();
764+
});
429765
patterns.onOp("Cos", 7,
430766
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
431767
Torch::ValueTensorType resultType;

0 commit comments

Comments
 (0)