Skip to content

Commit

Permalink
[MLIR][ONNX] Add OnnxToTorch support for Conv and ConvTranspose op.
Browse files Browse the repository at this point in the history
This commit adds the OnnxToTorch support for Conv and ConvTranspose op.

Signed-Off By: [email protected]
  • Loading branch information
vivekkhandelwal1 committed Dec 21, 2023
1 parent d75cff6 commit 3226241
Show file tree
Hide file tree
Showing 2 changed files with 466 additions and 0 deletions.
336 changes: 336 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,342 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
}
return failure();
});
patterns.onOp(
"Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return failure();
if (autoPad != "NOTSET") {
// TODO: Add support for `auto_pad` != "NOTSET"
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");
}

Torch::ValueTensorType resultType;
Value input, weight;
int64_t group;
if (binder.tensorOperands(input, weight) ||
binder.s64IntegerAttr(group, "group", 1) ||
binder.tensorResultType(resultType))
return failure();

auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
if (!weightTensorType || !weightTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected weight type having sizes");
}
ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
SmallVector<int64_t> kernelShape;
if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
return failure();
if (kernelShape.size()) {
if (kernelShape.size() != weightShape.size() - 2) {
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: kernel_shape list size should have "
"number of values equal to weight_rank - 2");
} else {
for (unsigned i = 0; i < kernelShape.size(); i++) {
if (weightShape[i + 2] != kernelShape[i]) {
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: kernel_shape value "
"should be equal to the weight tensor shape");
}
}
}
}

// Determine the rank of input tensor.
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
unsigned rank = *maybeRank;

SmallVector<int64_t> padding, strides, dilations;
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations;
for (unsigned i = 0; i < rank - 2; i++) {
defaultPadding.push_back(0);
defaultStrides.push_back(1);
defaultDilations.push_back(1);
}
// Padding for the beginning and ending along each spatial axis, it can
// take any value greater than or equal to 0. The value represent the
// number of pixels added to the beginning and end part of the
// corresponding axis. pads format should be as follow [x1_begin,
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
// at the beginning of axis i and xi_end, the number of pixels added at
// the end of axis i.
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
return failure();
}
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
return rewriter.notifyMatchFailure(
binder.op, "padding list size does not match the number of axes");
}
if (binder.s64IntegerArrayAttr(dilations, "dilations",
defaultDilations)) {
return failure();
}
if (dilations.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op,
"dilations list size does not match the number of axes");
}
if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
return failure();
}
if (strides.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");
}

SmallVector<Value> cstPadding, cstStrides, cstDilations,
cstOutputPadding;
if (padding.size() != 2 * (rank - 2)) {
for (int64_t i : padding) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
} else {
for (unsigned i = 0; i < padding.size() / 2; i++) {
if (padding[i] != padding[i + (padding.size() / 2)]) {
// TODO: Add support for different padding values for the
// beginning and ending along each spatial axis
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: padding values for the beginning "
"and ending along each spatial axis must be equal");
}
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
}
}
for (int64_t i : dilations) {
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
for (int64_t i : strides) {
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
cstOutputPadding = {cstZero, cstZero};

Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstPadding);
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstDilations);
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstStrides);
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstOutputPadding);
Value transposed =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value bias;
if (binder.op->getNumOperands() == 3) {
if (binder.tensorOperandAtIndex(bias, 2)) {
return failure();
}
} else {
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
}
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(group));

rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
binder.op, resultType, input, weight, bias, stridesList,
paddingList, dilationsList, transposed, outputPaddingList,
cstGroup);
return success();
});
patterns.onOp(
"ConvTranspose", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return failure();
if (autoPad != "NOTSET") {
// TODO: Add support for `auto_pad` != "NOTSET"
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");
}
SmallVector<int64_t> outputShape;
if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {}))
return failure();
if (outputShape.size()) {
// TODO: Add support for non-None output_shape value.
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: output_shape should be absent");
}
Torch::ValueTensorType resultType;
Value input, weight;
int64_t group;
if (binder.tensorOperands(input, weight) ||
binder.s64IntegerAttr(group, "group", 1) ||
binder.tensorResultType(resultType))
return failure();

auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
if (!weightTensorType || !weightTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected weight type having sizes");
}
ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
SmallVector<int64_t> kernelShape;
if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
return failure();
if (kernelShape.size()) {
if (kernelShape.size() != weightShape.size() - 2) {
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: kernel_shape list size should have "
"number of values equal to weight_rank - 2");
} else {
for (unsigned i = 0; i < kernelShape.size(); i++) {
if (weightShape[i + 2] != kernelShape[i]) {
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: kernel_shape value "
"should be equal to the weight tensor shape");
}
}
}
}

// Determine the rank of input tensor.
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
unsigned rank = *maybeRank;

SmallVector<int64_t> padding, strides, dilations, outputPadding;
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
for (unsigned i = 0; i < rank - 2; i++) {
defaultPadding.push_back(0);
defaultStrides.push_back(1);
defaultDilations.push_back(1);
defaultOutputPadding.push_back(0);
}
// Padding for the beginning and ending along each spatial axis, it can
// take any value greater than or equal to 0. The value represent the
// number of pixels added to the beginning and end part of the
// corresponding axis. pads format should be as follow [x1_begin,
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
// at the beginning of axis i and xi_end, the number of pixels added at
// the end of axis i.
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
return failure();
}
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
return rewriter.notifyMatchFailure(
binder.op, "padding list size does not match the number of axes");
}
if (binder.s64IntegerArrayAttr(dilations, "dilations",
defaultDilations)) {
return failure();
}
if (dilations.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op,
"dilations list size does not match the number of axes");
}
if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
return failure();
}
if (strides.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");
}
if (binder.s64IntegerArrayAttr(outputPadding, "output_padding",
defaultOutputPadding)) {
return failure();
}
if (outputPadding.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op,
"output_padding list size does not match the number of axes");
}

SmallVector<Value> cstPadding, cstStrides, cstDilations,
cstOutputPadding;
if (padding.size() != 2 * (rank - 2)) {
for (int64_t i : padding) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
} else {
for (unsigned i = 0; i < padding.size() / 2; i++) {
if (padding[i] != padding[i + (padding.size() / 2)]) {
// TODO: Add support for different padding values for the
// beginning and ending along each spatial axis
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: padding values for the beginning "
"and ending along each spatial axis must be equal");
}
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
}
}
for (int64_t i : dilations) {
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
for (int64_t i : strides) {
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
for (int64_t i : outputPadding) {
cstOutputPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}

Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstPadding);
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstDilations);
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstStrides);
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstOutputPadding);
Value transposed =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value bias;
if (binder.op->getNumOperands() == 3) {
if (binder.tensorOperandAtIndex(bias, 2)) {
return failure();
}
} else {
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
}
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(group));

rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
binder.op, resultType, input, weight, bias, stridesList,
paddingList, dilationsList, transposed, outputPaddingList,
cstGroup);
return success();
});
patterns.onOp("Cos", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
Loading

0 comments on commit 3226241

Please sign in to comment.