@@ -1391,9 +1391,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
13911391 return success ();
13921392 }
13931393
1394- if (numSpatialDims != 2 )
1394+ if (numSpatialDims != 2 && numSpatialDims != 3 )
13951395 return rewriter.notifyMatchFailure (
1396- op, " unimplemented: only 1D and 2D grouped convolution supported" );
1396+ op, " unimplemented: only 2D and 3D grouped convolution supported" );
1397+ if (numSpatialDims == 3 && inputZp) {
1398+ return rewriter.notifyMatchFailure (
1399+ op, " unimplemented: quantized 3D grouped convolution not supported" );
1400+ }
13971401
13981402 // Grouped case, use the grouped conv linalg op
13991403 auto expandGroups = [&](Value tensor, size_t dim) {
@@ -1435,21 +1439,101 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
14351439 weight = transposed ? weight : expandWeight (weight);
14361440 auto expandOutputTensor = expandGroups (outputTensor, 1 );
14371441
1438- // TODO: add 1D and 3D case
1439- if (!inputZp) {
1440- conv = rewriter
1441- .create <linalg::Conv2DNgchwGfchwOp>(
1442- loc, expandOutputTensor.getResultType (),
1443- ValueRange{paddedInputExpanded, weight},
1444- expandOutputTensor.getResult (), stridesAttr, dilationAttr)
1445- .getResult (0 );
1446- } else {
1447- conv = rewriter
1448- .create <linalg::Conv2DNgchwGfchwQOp>(
1449- loc, expandOutputTensor.getResultType (),
1450- ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
1451- expandOutputTensor.getResult (), stridesAttr, dilationAttr)
1452- .getResult (0 );
1442+ if (numSpatialDims == 2 ) {
1443+ // 2D grouped convolution
1444+ if (!inputZp) {
1445+ conv =
1446+ rewriter
1447+ .create <linalg::Conv2DNgchwGfchwOp>(
1448+ loc, expandOutputTensor.getResultType (),
1449+ ValueRange{paddedInputExpanded, weight},
1450+ expandOutputTensor.getResult (), stridesAttr, dilationAttr)
1451+ .getResult (0 );
1452+ } else {
1453+ conv =
1454+ rewriter
1455+ .create <linalg::Conv2DNgchwGfchwQOp>(
1456+ loc, expandOutputTensor.getResultType (),
1457+ ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
1458+ expandOutputTensor.getResult (), stridesAttr, dilationAttr)
1459+ .getResult (0 );
1460+ }
1461+ } else if (numSpatialDims == 3 ) {
1462+ // MLIR does not have a named 3D grouped convolution op, so we use
1463+ // linalg.generic instead.
1464+ AffineExpr d0, d1, d2, d3, d4, d5, d6, d7, d8, d9;
1465+ bindDims (context, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9);
1466+
1467+ SmallVector<AffineExpr> inputExprs = {
1468+ d0, // N
1469+ d1, // G
1470+ d6, // C/G
1471+ d3 * strideInts[0 ] + d7 * dilationInts[0 ], // D
1472+ d4 * strideInts[1 ] + d8 * dilationInts[1 ], // H
1473+ d5 * strideInts[2 ] + d9 * dilationInts[2 ] // W
1474+ };
1475+
1476+ SmallVector<AffineExpr> weightExprs = {
1477+ d1, // G
1478+ d2, // F/G
1479+ d6, // C/G
1480+ d7, // KD
1481+ d8, // KH
1482+ d9 // KW
1483+ };
1484+
1485+ SmallVector<AffineExpr> outputExprs = {
1486+ d0, // N
1487+ d1, // G
1488+ d2, // F/G
1489+ d3, // OD
1490+ d4, // OH
1491+ d5, // OW
1492+ };
1493+
1494+ SmallVector<AffineMap> indexingMaps = {
1495+ AffineMap::get (10 , 0 , inputExprs, rewriter.getContext ()),
1496+ AffineMap::get (10 , 0 , weightExprs, rewriter.getContext ()),
1497+ AffineMap::get (10 , 0 , outputExprs, rewriter.getContext ())};
1498+
1499+ SmallVector<utils::IteratorType> iteratorTypes = {
1500+ utils::IteratorType::parallel, // N
1501+ utils::IteratorType::parallel, // G
1502+ utils::IteratorType::parallel, // F/G
1503+ utils::IteratorType::parallel, // OD
1504+ utils::IteratorType::parallel, // OH
1505+ utils::IteratorType::parallel, // OW
1506+ utils::IteratorType::reduction, // C/G
1507+ utils::IteratorType::reduction, // KD
1508+ utils::IteratorType::reduction, // KH
1509+ utils::IteratorType::reduction // KW
1510+ };
1511+
1512+ conv =
1513+ rewriter
1514+ .create <linalg::GenericOp>(
1515+ loc, expandOutputTensor.getResultType (),
1516+ ValueRange{paddedInputExpanded, weight},
1517+ expandOutputTensor.getResult (), indexingMaps, iteratorTypes,
1518+ [&](OpBuilder &b, Location loc, ValueRange args) {
1519+ Value input = args[0 ];
1520+ Value weight = args[1 ];
1521+ Value output = args[2 ];
1522+
1523+ // Convert input and weight to accumulator type if needed
1524+ Type accType = output.getType ();
1525+ if (input.getType () != accType) {
1526+ input = b.create <arith::ExtFOp>(loc, accType, input);
1527+ }
1528+ if (weight.getType () != accType) {
1529+ weight = b.create <arith::ExtFOp>(loc, accType, weight);
1530+ }
1531+
1532+ Value mul = b.create <arith::MulFOp>(loc, input, weight);
1533+ Value add = b.create <arith::AddFOp>(loc, mul, output);
1534+ b.create <linalg::YieldOp>(loc, add);
1535+ })
1536+ .getResult (0 );
14531537 }
14541538 conv = rewriter.create <tensor::CollapseShapeOp>(
14551539 loc, outputTensor.getType (), conv,
0 commit comments