
Commit 6ebadc1

IanWood1 authored and raayandhar committed
Add support for 3D Grouped Conv (#4354)
This PR adds support for 3D grouped convolutions by implementing the lowering from torch.aten.convolution to linalg.generic operations. A `linalg.generic` is used because there is no linalg named op for 3D grouped convolution.

---------

Signed-off-by: Ian Wood <[email protected]>
1 parent cb1ad5c commit 6ebadc1
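
For background, a convolution with groups=G splits the input channels and the filters into G independent groups, convolves each group separately, and concatenates the per-group results along the channel dimension; this is the contract the new lowering must match. A minimal eager-PyTorch sketch of that equivalence for the 3D case (illustrative only, not part of this commit; shapes are assumed):

import torch
import torch.nn.functional as F

N, C, F_out, G = 2, 4, 8, 2
x = torch.rand(N, C, 6, 6, 6)           # input:  N x C x D x H x W
w = torch.rand(F_out, C // G, 3, 3, 3)  # weight: F x C/G x KD x KH x KW

grouped = F.conv3d(x, w, groups=G)

# Reference: convolve each group separately, then concatenate on channels.
per_group = torch.cat(
    [
        F.conv3d(
            x[:, g * (C // G) : (g + 1) * (C // G)],
            w[g * (F_out // G) : (g + 1) * (F_out // G)],
        )
        for g in range(G)
    ],
    dim=1,
)
assert torch.allclose(grouped, per_group, atol=1e-5)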

File tree

3 files changed: +203 −17 lines changed


lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 101 additions & 17 deletions

@@ -1391,9 +1391,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
-    if (numSpatialDims != 2)
+    if (numSpatialDims != 2 && numSpatialDims != 3)
       return rewriter.notifyMatchFailure(
-          op, "unimplemented: only 1D and 2D grouped convolution supported");
+          op, "unimplemented: only 2D and 3D grouped convolution supported");
+    if (numSpatialDims == 3 && inputZp) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: quantized 3D grouped convolution not supported");
+    }
 
     // Grouped case, use the grouped conv linalg op
     auto expandGroups = [&](Value tensor, size_t dim) {
@@ -1435,21 +1439,101 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     weight = transposed ? weight : expandWeight(weight);
     auto expandOutputTensor = expandGroups(outputTensor, 1);
 
-    // TODO: add 1D and 3D case
-    if (!inputZp) {
-      conv = rewriter
-                 .create<linalg::Conv2DNgchwGfchwOp>(
-                     loc, expandOutputTensor.getResultType(),
-                     ValueRange{paddedInputExpanded, weight},
-                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
-                 .getResult(0);
-    } else {
-      conv = rewriter
-                 .create<linalg::Conv2DNgchwGfchwQOp>(
-                     loc, expandOutputTensor.getResultType(),
-                     ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
-                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
-                 .getResult(0);
+    if (numSpatialDims == 2) {
+      // 2D grouped convolution
+      if (!inputZp) {
+        conv =
+            rewriter
+                .create<linalg::Conv2DNgchwGfchwOp>(
+                    loc, expandOutputTensor.getResultType(),
+                    ValueRange{paddedInputExpanded, weight},
+                    expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                .getResult(0);
+      } else {
+        conv =
+            rewriter
+                .create<linalg::Conv2DNgchwGfchwQOp>(
+                    loc, expandOutputTensor.getResultType(),
+                    ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
+                    expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                .getResult(0);
+      }
+    } else if (numSpatialDims == 3) {
+      // MLIR does not have a named 3D grouped convolution op, so we use
+      // linalg.generic instead.
+      AffineExpr d0, d1, d2, d3, d4, d5, d6, d7, d8, d9;
+      bindDims(context, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9);
+
+      SmallVector<AffineExpr> inputExprs = {
+          d0,                                        // N
+          d1,                                        // G
+          d6,                                        // C/G
+          d3 * strideInts[0] + d7 * dilationInts[0], // D
+          d4 * strideInts[1] + d8 * dilationInts[1], // H
+          d5 * strideInts[2] + d9 * dilationInts[2]  // W
+      };
+
+      SmallVector<AffineExpr> weightExprs = {
+          d1, // G
+          d2, // F/G
+          d6, // C/G
+          d7, // KD
+          d8, // KH
+          d9  // KW
+      };
+
+      SmallVector<AffineExpr> outputExprs = {
+          d0, // N
+          d1, // G
+          d2, // F/G
+          d3, // OD
+          d4, // OH
+          d5, // OW
+      };
+
+      SmallVector<AffineMap> indexingMaps = {
+          AffineMap::get(10, 0, inputExprs, rewriter.getContext()),
+          AffineMap::get(10, 0, weightExprs, rewriter.getContext()),
+          AffineMap::get(10, 0, outputExprs, rewriter.getContext())};
+
+      SmallVector<utils::IteratorType> iteratorTypes = {
+          utils::IteratorType::parallel,  // N
+          utils::IteratorType::parallel,  // G
+          utils::IteratorType::parallel,  // F/G
+          utils::IteratorType::parallel,  // OD
+          utils::IteratorType::parallel,  // OH
+          utils::IteratorType::parallel,  // OW
+          utils::IteratorType::reduction, // C/G
+          utils::IteratorType::reduction, // KD
+          utils::IteratorType::reduction, // KH
+          utils::IteratorType::reduction  // KW
+      };
+
+      conv =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, expandOutputTensor.getResultType(),
+                  ValueRange{paddedInputExpanded, weight},
+                  expandOutputTensor.getResult(), indexingMaps, iteratorTypes,
+                  [&](OpBuilder &b, Location loc, ValueRange args) {
+                    Value input = args[0];
+                    Value weight = args[1];
+                    Value output = args[2];
+
+                    // Convert input and weight to accumulator type if needed
+                    Type accType = output.getType();
+                    if (input.getType() != accType) {
+                      input = b.create<arith::ExtFOp>(loc, accType, input);
+                    }
+                    if (weight.getType() != accType) {
+                      weight = b.create<arith::ExtFOp>(loc, accType, weight);
+                    }
+
+                    Value mul = b.create<arith::MulFOp>(loc, input, weight);
+                    Value add = b.create<arith::AddFOp>(loc, mul, output);
+                    b.create<linalg::YieldOp>(loc, add);
+                  })
+              .getResult(0);
     }
     conv = rewriter.create<tensor::CollapseShapeOp>(
         loc, outputTensor.getType(), conv,
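
The indexing maps in the linalg.generic above encode the loop nest out[n, g, f, od, oh, ow] += in[n, g, c, od*s_d + kd*dil_d, oh*s_h + kh*dil_h, ow*s_w + kw*dil_w] * w[g, f, c, kd, kh, kw], with d0..d9 bound to (n, g, f, od, oh, ow, c, kd, kh, kw); the first six iterators are parallel and the last four are reductions. Below is a Python reference for that loop nest, as a sketch that assumes the input has already been padded and expanded to N x G x C/G x D x H x W (illustrative, not the PR's code):

import torch

def grouped_conv3d_reference(x, w, stride, dilation):
    # x: (N, G, C/G, D, H, W) after the group expand; input padding applied.
    # w: (G, F/G, C/G, KD, KH, KW)
    N, G, Cg, D, H, W = x.shape
    _, Fg, _, KD, KH, KW = w.shape
    OD = (D - dilation[0] * (KD - 1) - 1) // stride[0] + 1
    OH = (H - dilation[1] * (KH - 1) - 1) // stride[1] + 1
    OW = (W - dilation[2] * (KW - 1) - 1) // stride[2] + 1
    out = torch.zeros(N, G, Fg, OD, OH, OW)
    for od in range(OD):
        for oh in range(OH):
            for ow in range(OW):
                for kd in range(KD):
                    for kh in range(KH):
                        for kw in range(KW):
                            # Input access mirrors d3*stride + d7*dilation, etc.
                            patch = x[:, :, :,
                                      od * stride[0] + kd * dilation[0],
                                      oh * stride[1] + kh * dilation[1],
                                      ow * stride[2] + kw * dilation[2]]
                            # Contract over C/G: (N,G,Cg) x (G,Fg,Cg) -> (N,G,Fg)
                            out[:, :, :, od, oh, ow] += torch.einsum(
                                "ngc,gfc->ngf", patch, w[:, :, :, kd, kh, kw])
    return out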

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 9 additions & 0 deletions

@@ -2914,6 +2914,9 @@
     "Conv3dModule_basic",
     "Conv3dWithSamePaddingModule_basic",
     "Conv3dWithValidPaddingModule_basic",
+    "ConvolutionModule3DGroups_basic",
+    "ConvolutionModule3DGroupsStrided_basic",
+    "ConvolutionModule3DGroupsDilated_basic",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "Conv_Transpose2dModule_basic",
@@ -3721,6 +3724,9 @@
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
     "ConvolutionModule2DGroupedTranspose_basic",
+    "ConvolutionModule3DGroups_basic",
+    "ConvolutionModule3DGroupsStrided_basic",
+    "ConvolutionModule3DGroupsDilated_basic",
     "CumsumInputDtypeInt32Module_basic",
     "CumsumWithDtypeModule_basic",
     "CumsumModule_basic",
@@ -4369,6 +4375,9 @@
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
     "ConvolutionModule2DGroupedTranspose_basic",
+    "ConvolutionModule3DGroups_basic",
+    "ConvolutionModule3DGroupsStrided_basic",
+    "ConvolutionModule3DGroupsDilated_basic",
     "CopyModule_basic",
     "CopyWithDifferentDTypesAndSizesModule_basic",
     "CopyWithDifferentDTypesModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 93 additions & 0 deletions

@@ -679,6 +679,99 @@ def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 32, 4, 4), tu.rand(32, 8, 3, 3))
 
 
+class ConvolutionModule3DGroups(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[1, 1, 1],
+            padding=[0, 0, 0],
+            dilation=[1, 1, 1],
+            transposed=False,
+            output_padding=[0, 0, 0],
+            groups=2,
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule3DGroups())
+def ConvolutionModule3DGroups_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 6, 6, 6), tu.rand(8, 2, 3, 3, 3))
+
+
+class ConvolutionModule3DGroupsStrided(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2, 2, 2],
+            padding=[1, 1, 1],
+            dilation=[1, 1, 1],
+            transposed=False,
+            output_padding=[0, 0, 0],
+            groups=4,
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule3DGroupsStrided())
+def ConvolutionModule3DGroupsStrided_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 8, 8, 8, 8), tu.rand(16, 2, 3, 3, 3))
+
+
+class ConvolutionModule3DGroupsDilated(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[1, 1, 1],
+            padding=[2, 2, 2],
+            dilation=[2, 2, 2],
+            transposed=False,
+            output_padding=[0, 0, 0],
+            groups=2,
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule3DGroupsDilated())
+def ConvolutionModule3DGroupsDilated_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 8, 8), tu.rand(8, 2, 3, 3, 3))
+
+
 # ==============================================================================
 
 
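As a quick sanity check on the strided case above: with an 8x8x8 input, a 3x3x3 kernel, stride 2, and padding 1, each spatial dimension maps to (8 + 2*1 - (3 - 1) - 1) // 2 + 1 = 4. An eager-mode shape check (illustrative, not part of the test suite):

import torch

out = torch.ops.aten.convolution(
    torch.rand(2, 8, 8, 8, 8),   # input:  N=2, C=8, D=H=W=8
    torch.rand(16, 2, 3, 3, 3),  # weight: F=16, C/G=2, 3x3x3 kernel
    None,                        # bias
    [2, 2, 2],                   # stride
    [1, 1, 1],                   # padding
    [1, 1, 1],                   # dilation
    False,                       # transposed
    [0, 0, 0],                   # output_padding
    4,                           # groups
)
assert out.shape == (2, 16, 4, 4, 4)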
