diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 95fc744afa09..7fc683107219 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -105,7 +105,7 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, FailureOr getConvBiasForNoneType(Operation *op, PatternRewriter &rewriter, Type inputElemTy, Type outputElemTy, - ArrayRef weightShape); + int64_t numOutputChannels); // Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later // avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index bb4a4b75fc31..b429fb087a7f 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" @@ -26,6 +27,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -46,6 +48,183 @@ namespace mlir::torch { namespace { +static SmallVector permuteShape(ArrayRef originalShape, + ArrayRef permutation) { + SmallVector result; + result.reserve(permutation.size()); + for (int32_t dim : permutation) + result.push_back(originalShape[dim]); + return result; +} + +struct ZeroInsertionResult { + Value value; + bool trimmedTail; +}; + +static std::pair +getOrCreateConvZeroPoints(PatternRewriter &rewriter, Location loc, Value input, + Type inputElemTy, Value weight, Type weightElemTy) { + auto zps = tosa::createZPsAsConst(rewriter, input, weight); + Value inputZp = zps.first; + if (!inputZp) + inputZp = + tosa::createZeroPointTensor(rewriter, loc, inputElemTy, 0).value(); + Value weightZp = zps.second; + if (!weightZp) + weightZp = + tosa::createZeroPointTensor(rewriter, loc, weightElemTy, 0).value(); + return {inputZp, weightZp}; +} + +static FailureOr +insertZerosAlongAxis(Value input, int axis, int64_t stride, + ConversionPatternRewriter &rewriter, Location loc) { + if (stride == 1) + return ZeroInsertionResult{input, /*trimmedTail=*/true}; + + if (stride <= 0) + return failure(); + + auto inputType = dyn_cast(input.getType()); + if (!inputType) + return failure(); + + auto elementType = inputType.getElementType(); + // Work on a mutable copy of the shape since we insert/drop singleton dims + // and update the axis extent below. + SmallVector shape(inputType.getShape().begin(), + inputType.getShape().end()); + if (axis < 0 || axis >= static_cast(shape.size())) + return failure(); + + int64_t dim = shape[axis]; + // The slice at the end requires a static trip count, so we can only upsample + // axes with a known length when we later trim the padded tail. 
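+  // For example, with dim = 3 and stride = 2 the reshape/pad/reshape below
+  // turns [a, b, c] into [a, 0, b, 0, c, 0]; the trailing zero run is later
+  // sliced off so the axis ends up with (dim - 1) * stride + 1 = 5 elements.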
+ if (stride > 1 && ShapedType::isDynamic(dim)) + return failure(); + + SmallVector expandedShape; + expandedShape.reserve(shape.size() + 1); + for (int i = 0; i < static_cast(shape.size()); ++i) { + expandedShape.push_back(shape[i]); + if (i == axis) + expandedShape.push_back(1); + } + + auto expandedType = RankedTensorType::get( + makeShapeLLVMCompatible(expandedShape), elementType); + Value reshapeToExpanded = tosa::ReshapeOp::create( + rewriter, loc, expandedType, input, + tosa::getTosaConstShape(rewriter, loc, expandedShape)); + + SmallVector paddedShape = expandedShape; + paddedShape[axis + 1] = stride; + SmallVector pads(2 * expandedShape.size(), 0); + pads[2 * (axis + 1) + 1] = stride - 1; + + Value padsConst = tosa::getTosaConstShape(rewriter, loc, pads); + + // Torch IR does not convey quantization params via tensor element types, so + // we use a literal zero here. Quantized frontends will insert the necessary + // rescale ops before we hit this lowering. + auto padValueOr = + tosa::createZeroPointTensor(rewriter, loc, elementType, /*zeroPoint=*/0); + if (!padValueOr.has_value()) + return failure(); + Value padValue = *padValueOr; + + auto paddedType = + RankedTensorType::get(makeShapeLLVMCompatible(paddedShape), elementType); + Value padded = tosa::PadOp::create(rewriter, loc, paddedType, + reshapeToExpanded, padsConst, padValue); + + SmallVector collapsedShape = shape; + collapsedShape[axis] = + ShapedType::isDynamic(dim) ? ShapedType::kDynamic : dim * stride; + auto collapsedType = RankedTensorType::get( + makeShapeLLVMCompatible(collapsedShape), elementType); + + Value result = tosa::ReshapeOp::create( + rewriter, loc, collapsedType, padded, + tosa::getTosaConstShape(rewriter, loc, collapsedShape)); + + bool trimmedTail = stride > 1; + if (stride > 1) { + // Padding adds (stride - 1) zeros after every element, so the collapsed + // tensor has an extra run of zeros at the tail. Slice those zeros to get + // the `(dim - 1) * stride + 1` elements that PyTorch expects. 
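+    // Continuing the example above, the slice keeps the first
+    // (3 - 1) * 2 + 1 = 5 entries, i.e. [a, 0, b, 0, c].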
+ int64_t trimmedLength = (dim - 1) * stride + 1; + if (trimmedLength < collapsedShape[axis]) { + SmallVector startIndices(collapsedShape.size(), 0); + SmallVector sliceSizes = collapsedShape; + sliceSizes[axis] = trimmedLength; + SmallVector trimmedShape = + llvm::to_vector(collapsedType.getShape()); + trimmedShape[axis] = trimmedLength; + auto trimmedType = RankedTensorType::get( + makeShapeLLVMCompatible(trimmedShape), elementType); + result = tosa::SliceOp::create( + rewriter, loc, trimmedType, result, + tosa::getTosaConstShape(rewriter, loc, startIndices), + tosa::getTosaConstShape(rewriter, loc, sliceSizes)); + } + + trimmedTail = true; + } + + return ZeroInsertionResult{result, trimmedTail}; +} + +static LogicalResult +getTorchToTosaPermutations(Location loc, int64_t rank, + SmallVectorImpl &torchToTosa, + SmallVectorImpl &tosaToTorch) { + if (rank < 3) + return emitError(loc) << "expected convolution tensor rank >= 3, got " + << rank; + + torchToTosa.clear(); + tosaToTorch.clear(); + + torchToTosa.push_back(0); // batch dim stays first + for (int64_t dim = 2; dim < rank; ++dim) + torchToTosa.push_back(dim); // spatial dims in order + torchToTosa.push_back(1); // channel moves to last position + + tosaToTorch.resize(torchToTosa.size()); + for (auto pair : llvm::enumerate(torchToTosa)) + tosaToTorch[pair.value()] = pair.index(); + + return success(); +} + +static LogicalResult +getTorchConvWeightPermutation(Location loc, int64_t rank, bool isTransposed, + SmallVectorImpl &permutation) { + if (rank < 3) + return emitError(loc) << "expected convolution weight rank >= 3, got " + << rank; + + permutation.clear(); + + if (!isTransposed) { + // Torch weight layout: [O, I, spatial...]; TOSA expects [O, spatial..., I]. + permutation.push_back(0); + for (int64_t dim = 2; dim < rank; ++dim) + permutation.push_back(dim); + permutation.push_back(1); + } else { + // Transposed layout: [I, O, spatial...] -> [O, spatial..., I]. + permutation.push_back(1); + for (int64_t dim = 2; dim < rank; ++dim) + permutation.push_back(dim); + permutation.push_back(0); + } + + return success(); +} + // These legalizations are for unary ops with promoting input to floating-point // datatypes only. There is no supported quantized integer mode for these. template @@ -2384,39 +2563,44 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto weightShape = makeShapeTorchCompatible(weightTy.getShape()); auto outputElemTy = outputTy.getElementType(); - if (inputTy.getRank() != 4) + int64_t inputRank = inputTy.getRank(); + int64_t weightRank = weightTy.getRank(); + int64_t outputRank = outputTy.getRank(); + + if (inputRank != weightRank || outputRank != inputRank) + return rewriter.notifyMatchFailure( + op, "Input, weight and output ranks must match for convolution"); + + if (inputRank != 4 && inputRank != 5) return rewriter.notifyMatchFailure( - op, "Unimplemented: only 2D convolutions supported"); + op, "Unimplemented: only 2D or 3D convolutions supported"); + + bool is3D = inputRank == 5; + int64_t spatialRank = inputRank - 2; if (!weightTy.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Unimplemented: TOSA only supports static weight"); - // Bias is optional. TOSA mandates a zero tensor here, so construct one if - // required. - auto bias = adaptor.getBias(); - - if (isa(bias.getType())) { - // ConvTranspose weights use IOHW; the helper expects OIHW, so swap - // dims 0/1 before we synthesize the bias. - SmallVector biasWeightShape = - transposed ? 
SmallVector{weightShape[1], weightShape[0], - weightShape[2], weightShape[3]} - : weightShape; - - auto biasResult = tosa::getConvBiasForNoneType( - op, rewriter, inputElemTy, outputElemTy, biasWeightShape); - if (failed(biasResult)) - return rewriter.notifyMatchFailure( - op, "Failed to create bias tensor for none type."); - bias = biasResult.value(); - } else { - if (!isa(bias.getType())) - return rewriter.notifyMatchFailure( + Value biasArg = adaptor.getBias(); + auto getOrCreateBias = [&](int64_t outChannels) -> FailureOr { + if (isa(biasArg.getType())) { + auto biasResult = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy, + outputElemTy, outChannels); + if (failed(biasResult)) { + (void)rewriter.notifyMatchFailure( + op, "Failed to create bias tensor for none type."); + return failure(); + } + return biasResult.value(); + } + if (!isa(biasArg.getType())) { + (void)rewriter.notifyMatchFailure( op, "Bias provided but not a ranked tensor"); - } - - Type biasElemTy = cast(bias.getType()).getElementType(); + return failure(); + } + return biasArg; + }; int64_t groups; if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) { @@ -2427,26 +2611,38 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "(depthwise convolution)"); } - SmallVector stride; + SmallVector stride; if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride))) return rewriter.notifyMatchFailure(op, "non-const stride list unsupported"); + if (static_cast(stride.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "stride rank mismatch"); - SmallVector padding_2d; + SmallVector paddingList; if (!matchPattern(adaptor.getPadding(), - m_TorchListOfConstantInts(padding_2d))) + m_TorchListOfConstantInts(paddingList))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); + if (static_cast(paddingList.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "padding rank mismatch"); + // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D // padding {height, width}. The PyTorch OFM computation uses 2*pad in each // spatial direction, implying the same top=bottom=height and left=right=width // values for TOSA. - SmallVector padding( - {padding_2d[0], padding_2d[0], padding_2d[1], padding_2d[1]}); + SmallVector padding; + if (is3D) { + padding = {paddingList[0], paddingList[0], paddingList[1], + paddingList[1], paddingList[2], paddingList[2]}; + } else { + padding = {paddingList[0], paddingList[0], paddingList[1], paddingList[1]}; + } - SmallVector dilation; + SmallVector dilation; if (!matchPattern(adaptor.getDilation(), m_TorchListOfConstantInts(dilation))) return rewriter.notifyMatchFailure(op, "non-const dilation list unsupported"); + if (static_cast(dilation.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "dilation rank mismatch"); TypeAttr accType; if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy, @@ -2454,27 +2650,57 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "failed to get accumulator type for convolution ops"); - // Weight layout reference: - // Conv : PyTorch OIHW -> TOSA OHWI - // Depthwise : PyTorch OIHW* -> TOSA HWIM - // (PyTorch depthwise uses out_ch=in_ch*depth_multiplier) - // Grouped : PyTorch O(I/G)HW -> N/A - // Transposed : PyTorch IOHW -> TOSA OHWI - // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. - // Perform the necessary transformations. 
- SmallVector nchwToNhwcDims({0, 2, 3, 1}); - SmallVector nhwcToNchwDims({0, 3, 1, 2}); - SmallVector transposedInputShape; - for (int32_t dim : nchwToNhwcDims) - transposedInputShape.push_back(inputShape[dim]); + // TOSA works in NHWC (2D) / NDHWC (3D) and takes OHWI / ODHWI weights for + // convolution. Perform the necessary transformations. + SmallVector torchToTosaDims; + SmallVector tosaToTorchDims; + if (failed(getTorchToTosaPermutations(op->getLoc(), inputRank, + torchToTosaDims, tosaToTorchDims))) + return rewriter.notifyMatchFailure(op, + "unsupported convolution input rank"); + + SmallVector transposedInputShape = + permuteShape(inputShape, torchToTosaDims); auto transposedInputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), inputElemTy); - auto createTransposedInput = [&]() { - return tosa::TransposeOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(transposedInputType), input, - rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) - .getResult(); + Value transposedInput = + tosa::TransposeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(transposedInputType), input, + rewriter.getDenseI32ArrayAttr(torchToTosaDims)) + .getResult(); + + auto adjustSpatialDim = [&](SmallVector &shapeVec, Value &tensor, + int axis, int padBeforeIdx, int padAfterIdx, + int64_t weightDim, int64_t strideVal, + int64_t dilationVal, + int64_t &outputDim) -> LogicalResult { + int nhwcAxis = axis + 1; + int64_t inputDim = shapeVec[nhwcAxis]; + int64_t fullDim = inputDim + padding[padBeforeIdx] + padding[padAfterIdx] - + dilationVal * (weightDim - 1) - 1; + int64_t remainder = fullDim % strideVal; + if (remainder != 0) { + if (remainder > padding[padAfterIdx]) { + SmallVector startSlice(shapeVec.size(), 0); + SmallVector sizeSlice = shapeVec; + sizeSlice[nhwcAxis] = inputDim - (remainder - padding[padAfterIdx]); + tensor = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), + tensor, tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice)); + if (auto updatedType = dyn_cast(tensor.getType())) + shapeVec = llvm::to_vector(updatedType.getShape()); + fullDim = fullDim - padding[padAfterIdx]; + padding[padAfterIdx] = 0; + } else { + fullDim = fullDim - padding[padAfterIdx]; + padding[padAfterIdx] = padding[padAfterIdx] - remainder; + fullDim = fullDim + padding[padAfterIdx]; + } + } + outputDim = fullDim / strideVal + 1; + return success(); }; if (transposed) { @@ -2482,26 +2708,176 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Unimplemented: grouped transposed convolution not supported by " "TOSA"); - if (dilation[0] != 1 || dilation[1] != 1) + + SmallVector outPaddingList; + if (!matchPattern(adaptor.getOutputPadding(), + m_TorchListOfConstantInts(outPaddingList))) return rewriter.notifyMatchFailure( - op, "Unimplemented: dilated transposed convolution not supported by " - "TOSA"); + op, "non-const output_padding list unsupported for transposed conv"); + if (static_cast(outPaddingList.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "output_padding rank mismatch"); - SmallVector iohwToOhwi({1, 2, 3, 0}); + SmallVector transposedWeightPermutation; + if (failed(getTorchConvWeightPermutation(op->getLoc(), weightRank, + /*isTransposed=*/true, + transposedWeightPermutation))) + return rewriter.notifyMatchFailure( + op, "unsupported convolution weight rank for transpose conv"); + + if 
(is3D) { + if (groups != 1) + return rewriter.notifyMatchFailure( + op, "Unimplemented: grouped transposed 3D convolution not " + "supported by TOSA"); + + SmallVector transformedWeightShape = + permuteShape(weightShape, transposedWeightPermutation); + auto transformedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); + Value transformedWeight = + tosa::TransposeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), weight, + rewriter.getDenseI32ArrayAttr(transposedWeightPermutation)) + .getResult(); + + // Reverse spatial dims of the kernel. + Value flippedWeight = transformedWeight; + for (int reverseAxis = 1; reverseAxis <= 3; ++reverseAxis) { + flippedWeight = tosa::ReverseOp::create( + rewriter, op->getLoc(), flippedWeight.getType(), flippedWeight, + rewriter.getI32IntegerAttr(reverseAxis)); + } + + Value upsampledInput = transposedInput; + SmallVector tailTrimmed(/*N=*/3, /*Value=*/true); + Location loc = op->getLoc(); + for (int axis = 0; axis < 3; ++axis) { + auto insertedResult = insertZerosAlongAxis(upsampledInput, axis + 1, + stride[axis], rewriter, loc); + if (failed(insertedResult)) + return rewriter.notifyMatchFailure( + op, "Unsupported parameters for transposed 3D convolution"); + upsampledInput = insertedResult->value; + tailTrimmed[axis] = insertedResult->trimmedTail; + } + + auto upsampledType = cast(upsampledInput.getType()); + SmallVector upsampledShape = + llvm::to_vector(upsampledType.getShape()); + + SmallVector padVec(2 * upsampledShape.size(), 0); + SmallVector paddedShape = upsampledShape; + for (int axis = 0; axis < 3; ++axis) { + int spatialIndex = axis + 1; // NDHWC ordering + int64_t kernel = weightShape[2 + axis]; + int64_t dil = dilation[axis]; + int64_t pad = paddingList[axis]; + int64_t outPad = outPaddingList[axis]; + int64_t before = dil * (kernel - 1) - pad; + int64_t after = before + outPad; + if (!tailTrimmed[axis]) { + after -= (stride[axis] - 1); + } + if (before < 0 || after < 0) + return rewriter.notifyMatchFailure( + op, "Unsupported padding combination for transposed 3D " + "convolution"); + padVec[2 * spatialIndex] = before; + padVec[2 * spatialIndex + 1] = after; + if (!ShapedType::isDynamic(paddedShape[spatialIndex])) + paddedShape[spatialIndex] += before + after; + else + paddedShape[spatialIndex] = ShapedType::kDynamic; + } + + Value paddedInput = upsampledInput; + if (llvm::any_of(padVec, [](int64_t v) { return v != 0; })) { + Value padsConst = tosa::getTosaConstShape(rewriter, loc, padVec); + Value padTensor = + tosa::createZeroPointTensor(rewriter, loc, inputElemTy, 0).value(); + auto paddedType = RankedTensorType::get( + makeShapeLLVMCompatible(paddedShape), inputElemTy); + paddedInput = tosa::PadOp::create(rewriter, loc, paddedType, + upsampledInput, padsConst, padTensor); + } + + auto biasValueOr = getOrCreateBias(weightShape[1]); + if (failed(biasValueOr)) + return failure(); + Value bias = *biasValueOr; + Type biasElemTy = cast(bias.getType()).getElementType(); + + auto outTorchShape = makeShapeTorchCompatible(outputTy.getShape()); + SmallVector outTosaShape = + permuteShape(outTorchShape, torchToTosaDims); + auto convOpTy = RankedTensorType::get( + makeShapeLLVMCompatible(outTosaShape), biasElemTy); + + auto [inputZp, weightZp] = getOrCreateConvZeroPoints( + rewriter, loc, input, inputElemTy, weight, weightElemTy); + + auto convResult = + tosa::Conv3DOp::create( + rewriter, loc, getTypeConverter()->convertType(convOpTy), + paddedInput, 
flippedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr({0, 0, 0, 0, 0, 0}), + rewriter.getDenseI64ArrayAttr({1, 1, 1}), + rewriter.getDenseI64ArrayAttr(dilation), accType) + .getResult(); + + SmallVector transposedOutputShape = + permuteShape(outTosaShape, tosaToTorchDims); + auto transposedOutputType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); + Value transposedOutput = + tosa::TransposeOp::create( + rewriter, loc, + getTypeConverter()->convertType(transposedOutputType), convResult, + rewriter.getDenseI32ArrayAttr(tosaToTorchDims)) + .getResult(); + + Value rescaledResult = transposedOutput; + if (isa(inputElemTy)) { + rescaledResult = tosa::buildRescaleOpConvOutput( + rewriter, op, transposedOutput, inputTy, weightTy, outputTy); + } + + rewriter.replaceOp( + op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy) + .value()}); + return success(); + } + + if (dilation[0] != 1 || dilation[1] != 1) + return rewriter.notifyMatchFailure(op, + "Unimplemented: dilated transposed " + "convolution not supported by TOSA"); + + auto biasValueOr = getOrCreateBias(weightShape[1]); + if (failed(biasValueOr)) + return failure(); + Value bias = *biasValueOr; + Type biasElemTy = cast(bias.getType()).getElementType(); + + SmallVector ohwiWeightShape = + permuteShape(weightShape, transposedWeightPermutation); + auto ohwiWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy); + Value transformedWeight = + tosa::TransposeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(ohwiWeightType), weight, + rewriter.getDenseI32ArrayAttr(transposedWeightPermutation)) + .getResult(); // TOSA 'out_pad' is a 4D array {top,bottom,left,right}. // Map from PyTorch's (padding, output_padding): // out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W) // Negative values are allowed and will be handled by the TOSA // decomposition. - SmallVector outPadding2D; - if (!matchPattern(adaptor.getOutputPadding(), - m_TorchListOfConstantInts(outPadding2D))) - return rewriter.notifyMatchFailure( - op, "non-const output_padding list unsupported for transposed conv"); - - int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0]; - int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1]; + int64_t outPadH = outPaddingList[0] - 2 * paddingList[0]; + int64_t outPadW = outPaddingList[1] - 2 * paddingList[1]; int64_t outPadTop = outPadH / 2; int64_t outPadBottom = outPadH - outPadTop; int64_t outPadLeft = outPadW / 2; @@ -2509,59 +2885,44 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector outPad( {outPadTop, outPadBottom, outPadLeft, outPadRight}); - Value nhwcInput = createTransposedInput(); - SmallVector ohwiWeightShape; - for (int32_t dim : iohwToOhwi) - ohwiWeightShape.push_back(weightShape[dim]); - auto ohwiWeightType = RankedTensorType::get( - makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy); - Value transformedWeight = - tosa::TransposeOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(ohwiWeightType), weight, - rewriter.getDenseI32ArrayAttr(iohwToOhwi)) - .getResult(); - // Result type is NHWC (we'll transpose back). - auto outNCHW = makeShapeTorchCompatible(outputTy.getShape()); - SmallVector outNHWC; - for (int32_t dim : nchwToNhwcDims) - outNHWC.push_back(outNCHW[dim]); - auto transConvOpTy = - RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy); - - // Zero-points. 
- auto zps = tosa::createZPsAsConst(rewriter, input, weight); - Value inputZp = zps.first ? zps.first - : tosa::createZeroPointTensor( - rewriter, op->getLoc(), inputElemTy, 0) - .value(); - Value weightZp = zps.second ? zps.second - : tosa::createZeroPointTensor( - rewriter, op->getLoc(), weightElemTy, 0) - .value(); + auto outTorchShape = makeShapeTorchCompatible(outputTy.getShape()); + SmallVector outTosaShape = + permuteShape(outTorchShape, torchToTosaDims); + auto transConvOpTy = RankedTensorType::get( + makeShapeLLVMCompatible(outTosaShape), biasElemTy); + + // Zero-points (same helpers as conv path). + auto [inputZp, weightZp] = getOrCreateConvZeroPoints( + rewriter, op->getLoc(), input, inputElemTy, weight, weightElemTy); + // Build tosa.transpose_conv2d. Value convTOut = tosa::TransposeConv2DOp::create( rewriter, op->getLoc(), getTypeConverter()->convertType(transConvOpTy), - nhwcInput, transformedWeight, bias, inputZp, weightZp, - rewriter.getDenseI64ArrayAttr(outPad), - rewriter.getDenseI64ArrayAttr(stride), accType) + /*input*/ transposedInput, + /*weight*/ transformedWeight, + /*bias*/ bias, + /*input_zp*/ inputZp, + /*weight_zp*/ weightZp, + /*out_pad*/ rewriter.getDenseI64ArrayAttr(outPad), + /*stride*/ rewriter.getDenseI64ArrayAttr(stride), + /*acc_type*/ accType) .getResult(); - SmallVector transposedOutputShape; - for (int32_t dim : nhwcToNchwDims) - transposedOutputShape.push_back(outNHWC[dim]); + // NHWC -> NCHW + SmallVector transposedOutputShape = + permuteShape(outTosaShape, tosaToTorchDims); auto transposedOutputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); Value transposedOutput = tosa::TransposeOp::create( rewriter, op->getLoc(), getTypeConverter()->convertType(transposedOutputType), convTOut, - rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) + rewriter.getDenseI32ArrayAttr(tosaToTorchDims)) .getResult(); - // Quantized rescale. + // Quantized rescale (reuse existing helper). 
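+    // For quantized inputs the transpose_conv2d result is an i32 accumulator;
+    // rescale it back down to the quantized output range before the final cast.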
Value rescaledResult = transposedOutput; if (isa(inputElemTy)) { rescaledResult = tosa::buildRescaleOpConvOutput( @@ -2574,25 +2935,30 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value()}); return success(); } - SmallVector transformedWeightShape; RankedTensorType transformedWeightType; Value transformedWeight; int64_t outputCDim; + SmallVector weightPermutation; + if (failed(getTorchConvWeightPermutation(op->getLoc(), weightRank, + /*isTransposed=*/false, + weightPermutation))) + return rewriter.notifyMatchFailure(op, + "unsupported convolution weight rank"); + if (groups == 1) { - // full convolution: O(I/G)HW-> OHWI - transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3], - weightShape[1]}; + // full convolution: Torch: O(I/G)spatial -> TOSA: O spatial I + transformedWeightShape = permuteShape(weightShape, weightPermutation); transformedWeightType = RankedTensorType::get( makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); transformedWeight = tosa::TransposeOp::create( rewriter, op->getLoc(), getTypeConverter()->convertType(transformedWeightType), weight, - rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) + rewriter.getDenseI32ArrayAttr(weightPermutation)) .getResult(); outputCDim = transformedWeightShape[0]; - } else if (weightShape[1] == 1) { + } else if (!is3D && weightShape[1] == 1) { // depthwise convolution: O(I/G)HW-> HWIM) // transpose: O(I/G)HW -> HWO(I/G) SmallVector transposedDims({2, 3, 0, 1}); @@ -2633,114 +2999,168 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transformedWeightShape)) .getResult(); } else { - llvm_unreachable("Unhandled convolution type"); - } - - Value transposedInput = createTransposedInput(); - - int64_t outputHDim, outputWDim; - int64_t inputHDim = inputShape[2]; - int64_t inputWDim = inputShape[3]; - - bool isStaticSpatialDims = - !ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim); - if (isStaticSpatialDims) { - - int64_t weightHDim = weightShape[2]; - int64_t weightWDim = weightShape[3]; - - // fullDim = - // inputDim + padBefore + padAfter - dilation * (weightDim - 1) - 1 - // According to TOSA spec: - // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d, fullDim values - // must be divisible by stride values. - int64_t fullHDim = inputHDim + padding[0] + padding[1] - - dilation[0] * (weightHDim - 1) - 1; - int64_t remainderHDim = fullHDim % stride[0]; - if (remainderHDim != 0) { - if (remainderHDim > padding[1]) { - SmallVector startHSlice(inputTy.getRank(), 0); - SmallVector sizeHSlice(transposedInputShape); - // TOSA uses NHWC, so we will slice dim 1 for Height value - sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); - transposedInput = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), - transposedInput, - tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice), - tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice)); - fullHDim = fullHDim - padding[1]; - padding[1] = 0; - } else { - fullHDim = fullHDim - padding[1]; - padding[1] = padding[1] - remainderHDim; - fullHDim = fullHDim + padding[1]; + return rewriter.notifyMatchFailure( + op, is3D ? 
"Unimplemented: grouped or depthwise 3D convolution " + "not supported by TOSA" + : "Unhandled convolution type"); + } + + SmallVector outputShape; + if (!is3D) { + int64_t outputHDim, outputWDim; + int64_t inputHDim = inputShape[2]; + int64_t inputWDim = inputShape[3]; + + bool isStaticSpatialDims = + !ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim); + if (isStaticSpatialDims) { + int64_t weightHDim = weightShape[2]; + int64_t weightWDim = weightShape[3]; + + // fullDim = + // inputDim + padBefore + padAfter - dilation * (weightDim - 1) - 1 + // According to TOSA spec fullDim values must be divisible by the + // stride values. + int64_t fullHDim = inputHDim + padding[0] + padding[1] - + dilation[0] * (weightHDim - 1) - 1; + int64_t remainderHDim = fullHDim % stride[0]; + if (remainderHDim != 0) { + if (remainderHDim > padding[1]) { + SmallVector startHSlice(inputTy.getRank(), 0); + SmallVector sizeHSlice(transposedInputShape); + // TOSA uses NHWC, so slice dim 1 for Height. + sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); + transposedInput = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), + transposedInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice)); + fullHDim = fullHDim - padding[1]; + padding[1] = 0; + } else { + fullHDim = fullHDim - padding[1]; + padding[1] = padding[1] - remainderHDim; + fullHDim = fullHDim + padding[1]; + } } - } - outputHDim = fullHDim / stride[0] + 1; - - int64_t fullWDim = inputWDim + padding[2] + padding[3] - - dilation[1] * (weightWDim - 1) - 1; - int64_t remainderWDim = fullWDim % stride[1]; - if (remainderWDim != 0) { - if (remainderWDim > padding[3]) { - SmallVector startWSlice(inputTy.getRank(), 0); - SmallVector sizeWSlice( - dyn_cast(transposedInput.getType()).getShape()); - // TOSA uses NHWC, so we will slice dim 2 for Width value - sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]); - transposedInput = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), - transposedInput, - tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice), - tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice)); - fullHDim = fullHDim - padding[3]; - padding[3] = 0; - } else { - fullWDim = fullWDim - padding[3]; - padding[3] = padding[3] - remainderWDim; - fullWDim = fullWDim + padding[3]; + outputHDim = fullHDim / stride[0] + 1; + + int64_t fullWDim = inputWDim + padding[2] + padding[3] - + dilation[1] * (weightWDim - 1) - 1; + int64_t remainderWDim = fullWDim % stride[1]; + if (remainderWDim != 0) { + if (remainderWDim > padding[3]) { + SmallVector startWSlice(inputTy.getRank(), 0); + SmallVector sizeWSlice( + dyn_cast(transposedInput.getType()).getShape()); + // TOSA uses NHWC, so slice dim 2 for Width. 
+ sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]); + transposedInput = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), + transposedInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice)); + fullWDim = fullWDim - padding[3]; + padding[3] = 0; + } else { + fullWDim = fullWDim - padding[3]; + padding[3] = padding[3] - remainderWDim; + fullWDim = fullWDim + padding[3]; + } } + outputWDim = fullWDim / stride[1] + 1; + } else { + outputHDim = kUnknownSize; + outputWDim = kUnknownSize; } - outputWDim = fullWDim / stride[1] + 1; - } else { - outputHDim = kUnknownSize; - outputWDim = kUnknownSize; - } - - // Output shape is NHWC, to be transposed back to NCHW. Output elemTy for - // quantized input is i32, which gets rescaled down to quantized output range. - SmallVector outputShape = {transposedInputShape[0], outputHDim, - outputWDim, outputCDim}; - auto convOpTy = - RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); - // create zero-point tensors for input and weight - auto zps = tosa::createZPsAsConst(rewriter, input, weight); - // for i8 input/weight, zero-points are returned as un-initialized - Value inputZp = - zps.first - ? zps.first - : tosa::createZeroPointTensor(rewriter, op->getLoc(), inputElemTy, 0) - .value(); + outputShape = {transposedInputShape[0], outputHDim, outputWDim, outputCDim}; + } else { + int64_t outputDDim, outputHDim, outputWDim; + int64_t inputDDim = inputShape[2]; + int64_t inputHDim = inputShape[3]; + int64_t inputWDim = inputShape[4]; + + bool isStaticSpatialDims = !ShapedType::isDynamic(inputDDim) && + !ShapedType::isDynamic(inputHDim) && + !ShapedType::isDynamic(inputWDim); + SmallVector currentShape = transposedInputShape; + + if (isStaticSpatialDims) { + int64_t weightDDim = weightShape[2]; + int64_t weightHDim = weightShape[3]; + int64_t weightWDim = weightShape[4]; + + if (failed(adjustSpatialDim(currentShape, transposedInput, + /*axis=*/0, /*padBeforeIdx=*/0, + /*padAfterIdx=*/1, weightDDim, stride[0], + dilation[0], outputDDim))) + return failure(); + + if (failed(adjustSpatialDim(currentShape, transposedInput, + /*axis=*/1, /*padBeforeIdx=*/2, + /*padAfterIdx=*/3, weightHDim, stride[1], + dilation[1], outputHDim))) + return failure(); + + if (failed(adjustSpatialDim(currentShape, transposedInput, + /*axis=*/2, /*padBeforeIdx=*/4, + /*padAfterIdx=*/5, weightWDim, stride[2], + dilation[2], outputWDim))) + return failure(); + } else { + outputDDim = kUnknownSize; + outputHDim = kUnknownSize; + outputWDim = kUnknownSize; + } - Value weightZp = - zps.second - ? 
zps.second - : tosa::createZeroPointTensor(rewriter, op->getLoc(), weightElemTy, 0) - .value(); + outputShape = {currentShape[0], outputDDim, outputHDim, outputWDim, + outputCDim}; + } + auto [inputZp, weightZp] = getOrCreateConvZeroPoints( + rewriter, op->getLoc(), input, inputElemTy, weight, weightElemTy); + Value bias; + Type biasElemTy; Value convOpResult; if (groups == 1) { + auto biasValueOr = getOrCreateBias(weightShape[0]); + if (failed(biasValueOr)) + return failure(); + bias = *biasValueOr; + biasElemTy = cast(bias.getType()).getElementType(); + + auto convOpTy = + RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); // full convolution - convOpResult = - tosa::Conv2DOp::create( - rewriter, op->getLoc(), getTypeConverter()->convertType(convOpTy), - transposedInput, transformedWeight, bias, inputZp, weightZp, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation), accType) - .getResult(); - } else if (weightShape[1] == 1) { + if (is3D) { + convOpResult = + tosa::Conv3DOp::create( + rewriter, op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) + .getResult(); + } else { + convOpResult = + tosa::Conv2DOp::create( + rewriter, op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) + .getResult(); + } + } else if (!is3D && weightShape[1] == 1) { + auto biasValueOr = getOrCreateBias(weightShape[0]); + if (failed(biasValueOr)) + return failure(); + bias = *biasValueOr; + biasElemTy = cast(bias.getType()).getElementType(); + + auto convOpTy = + RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); // depthwise convolution convOpResult = tosa::DepthwiseConv2DOp::create( @@ -2751,18 +3171,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else { - llvm_unreachable("Unhandled convolution type"); + return rewriter.notifyMatchFailure( + op, is3D ? 
"Unimplemented: grouped or depthwise 3D convolution " + "not supported by TOSA" + : "Unhandled convolution type"); } - SmallVector transposedOutputShape( - {outputShape[0], outputShape[3], outputShape[1], outputShape[2]}); + SmallVector transposedOutputShape = + permuteShape(outputShape, tosaToTorchDims); auto transposedOutputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); auto transposedOutput = tosa::TransposeOp::create( rewriter, op->getLoc(), getTypeConverter()->convertType(transposedOutputType), convOpResult, - rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) + rewriter.getDenseI32ArrayAttr(tosaToTorchDims)) .getResult(); Value rescaledResult = transposedOutput; diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 7bf88663cd37..31f8a6e7d267 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -584,7 +584,7 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, FailureOr getConvBiasForNoneType(Operation *op, PatternRewriter &rewriter, Type inputElemTy, Type outputElemTy, - ArrayRef weightShape) { + int64_t numOutputChannels) { Type biasElemTy; @@ -610,16 +610,18 @@ FailureOr getConvBiasForNoneType(Operation *op, biasElemTy = outputElemTy; } + if (ShapedType::isDynamic(numOutputChannels)) + return rewriter.notifyMatchFailure( + op, "cannot synthesize conv bias with dynamic output channels"); + + int32_t oc = static_cast(numOutputChannels); + if (biasElemTy.isInteger()) { - SmallVector zeroVec(weightShape[0], 0); - return tosa::getConstTensor(rewriter, op, zeroVec, - {static_cast(weightShape[0])}) - .value(); + SmallVector zeroVec(oc, 0); + return tosa::getConstTensor(rewriter, op, zeroVec, {oc}).value(); } else { - SmallVector zeroVec(weightShape[0], 0); - return tosa::getConstTensor(rewriter, op, zeroVec, - {static_cast(weightShape[0])}, - biasElemTy) + SmallVector zeroVec(oc, 0); + return tosa::getConstTensor(rewriter, op, zeroVec, {oc}, biasElemTy) .value(); } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2fcfbb1ef772..54cef69aba4f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -70,9 +70,6 @@ "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", "Conv_Transpose2dModule_basic", - "Conv_Transpose2dStaticModule_basic", - "Conv_Transpose3dModule_basic", - "Conv_Transpose3dStaticModule_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "ConvolutionModule2DTransposeStrided_basic", "GridSamplerBasic1_basic", @@ -1136,8 +1133,6 @@ "ConvolutionBackwardModule2DStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "Conv_Transpose1dStaticModule_basic", - "Conv_Transpose2dStaticModule_basic", - "Conv_Transpose3dStaticModule_basic", "ConstantPad2dStaticModule_basic", "ConstantPadNdModule_basic", "ConstantPadNdPartialStaticModule_basic", @@ -2168,6 +2163,9 @@ "Conv2dWithValidPaddingModule_basic", "Conv2dWithSamePaddingModule_basic", "Convolution2DStaticModule_basic", + "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", "CosineSimilarityStaticModule_basic", "DetachModule_basic", "DropoutEvalFloatModule_basic", @@ -2758,6 +2756,13 @@ "TupleModule_basic", "ThresholdStaticModule_basic", "VarCorrectionLargeInputModule_basic", + "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", + 
"Conv_Transpose3dModule_basic", + "ConvolutionModule3DGroups_basic", + "ConvolutionModule3DGroupsStrided_basic", + "ConvolutionModule3DGroupsDilated_basic", # Failure - incorrect shape "ArangeStartOutDtypeModule_basic", "ArangeStartOutViewModule_basic", @@ -2907,12 +2912,6 @@ "Conv2dWithPaddingModule_basic", "Conv2dWithSamePaddingModule_basic", "Conv2dWithValidPaddingModule_basic", - "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", - "ConvolutionModule3DGroups_basic", - "ConvolutionModule3DGroupsStrided_basic", - "ConvolutionModule3DGroupsDilated_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -3373,7 +3372,6 @@ # Failure - unknown "BernoulliModule_basic", "Conv_Transpose1dModule_basic", - "Conv_Transpose3dModule_basic", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", @@ -3587,8 +3585,6 @@ "AvgPool3dSingleIntTupleStrideModule_basic", "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", - "Conv_Transpose3dModule_basic", - "Conv_Transpose3dStaticModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3690,9 +3686,6 @@ "Conv2dQInt8PerChannelModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", @@ -4109,9 +4102,6 @@ "AvgPool3dStaticModule_basic", "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", - "Conv_Transpose2dStaticModule_basic", - "Conv_Transpose3dModule_basic", - "Conv_Transpose3dStaticModule_basic", "ElementwiseFmaxModule_basic", "ElementwiseFminModule_basic", "ElementwiseGeluApproximateTanhModule_basic", @@ -4326,9 +4316,6 @@ "Conv2dWithPaddingModule_basic", "Conv2dWithSamePaddingModule_basic", "Conv2dWithValidPaddingModule_basic", - "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index ee691146e419..9d28d3f8592b 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -3700,26 +3700,26 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4 // ----- // CHECK-LABEL: func.func @torch.aten.convolution$basic( -// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,10,20],f32> -> tensor<5x2x10x20xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.bool false -// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<10x2x3x3xf32>}> : () -> tensor<10x2x3x3xf32> -// CHECK: %[[VAL_5:.*]] = torch.constant.none -// CHECK: %[[VAL_6:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// 
CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32> -// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor<5x14x24x10xf32>) -> tensor<5x10x14x24xf32> -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> -// CHECK: return %[[VAL_18]] : !torch.vtensor<[5,10,14,24],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[5,2,10,20],f32> -> tensor<5x2x10x20xf32> +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[KERNEL:.*]] = torch.constant.int 3 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<10x2x3x3xf32>}> : () -> tensor<10x2x3x3xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[ONE:.*]] = torch.constant.int 1 +// CHECK: %[[STRIDE_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[KERNEL]], %[[KERNEL]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32> +// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32> +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32> +// CHECK: %[[CONV:.*]] = tosa.conv2d %[[NHWC_INPUT]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32> +// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor<5x14x24x10xf32>) -> tensor<5x10x14x24xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCHW]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> +// CHECK: return %[[RESULT]] : 
!torch.vtensor<[5,10,14,24],f32> // CHECK: } func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { %false = torch.constant.bool false @@ -3738,29 +3738,29 @@ func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) // ----- // CHECK-LABEL: func.func @torch.aten.convolution$depthwise( -// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4,10,20],f32> -> tensor<5x4x10x20xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.bool false -// CHECK: %[[VAL_3:.*]] = torch.constant.int 4 -// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense_resource : tensor<4x1x3x3xf32>}> : () -> tensor<4x1x3x3xf32> -// CHECK: %[[VAL_6:.*]] = torch.constant.none -// CHECK: %[[VAL_7:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32> -// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.depthwise_conv2d %[[VAL_16]], %[[VAL_15]], %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32> -// CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_19]] {perms = array} : (tensor<5x5x10x4xf32>) -> tensor<5x4x5x10xf32> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[5,4,5,10],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[5,4,10,20],f32> -> tensor<5x4x10x20xf32> +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[OUT_CHANS:.*]] = torch.constant.int 4 +// CHECK: %[[KERNEL:.*]] = torch.constant.int 3 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<4x1x3x3xf32>}> : () -> tensor<4x1x3x3xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[STRIDE:.*]] = torch.constant.int 2 +// CHECK: %[[STRIDE_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE]], %[[STRIDE]] : (!torch.int, !torch.int) -> 
!torch.list +// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[KERNEL]], %[[KERNEL]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[KERNEL]], %[[KERNEL]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32> +// CHECK: %[[HW_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[RESHAPE_SHAPE:.*]] = tosa.const_shape {values = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[FILTER:.*]] = tosa.reshape %[[HW_WEIGHT]], %[[RESHAPE_SHAPE]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32> +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> +// CHECK: %[[DEPTHWISE:.*]] = tosa.depthwise_conv2d %[[NHWC_INPUT]], %[[FILTER]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32> +// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[DEPTHWISE]] {perms = array} : (tensor<5x5x10x4xf32>) -> tensor<5x4x5x10xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCHW]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[5,4,5,10],f32> // CHECK: } func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { %false = torch.constant.bool false @@ -3777,36 +3777,34 @@ func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f3 return %5 : !torch.vtensor<[5,4,5,10],f32> } -// ----- - // CHECK-LABEL: func.func @torch.aten.convolution$zero_pad_with_sliced_input( -// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,128,28,28],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,56,56],f32> -> tensor<1x64x56x56xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.bool false -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense_resource : tensor<128x64x1x1xf32>}> : () -> tensor<128x64x1x1xf32> -// CHECK: %[[VAL_6:.*]] = torch.constant.none -// CHECK: %[[VAL_7:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32> -// CHECK: %[[VAL_14:.*]] = 
tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32> -// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {values = dense<[1, 55, 56, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_17:.*]] = tosa.slice %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : (tensor<1x56x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x56x64xf32> -// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 55, 55, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : (tensor<1x55x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x55x64xf32> -// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_23:.*]] = tosa.conv2d %[[VAL_20]], %[[VAL_13]], %[[VAL_12]], %[[VAL_21]], %[[VAL_22]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x28x28x128xf32> -// CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_23]] {perms = array} : (tensor<1x28x28x128xf32>) -> tensor<1x128x28x28xf32> -// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<1x128x28x28xf32> -> !torch.vtensor<[1,128,28,28],f32> -// CHECK: return %[[VAL_25]] : !torch.vtensor<[1,128,28,28],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,128,28,28],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,64,56,56],f32> -> tensor<1x64x56x56xf32> +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[ONE:.*]] = torch.constant.int 1 +// CHECK: %[[ZERO:.*]] = torch.constant.int 0 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<128x64x1x1xf32>}> : () -> tensor<128x64x1x1xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[STRIDE:.*]] = torch.constant.int 2 +// CHECK: %[[STRIDE_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE]], %[[STRIDE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32> +// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32> +// CHECK-DAG: %[[SLICE0_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[SLICE0_SIZE:.*]] = tosa.const_shape {values = dense<[1, 55, 56, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[TRIMMED_H:.*]] = tosa.slice %[[NHWC_INPUT]], %[[SLICE0_START]], %[[SLICE0_SIZE]] : (tensor<1x56x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x56x64xf32> +// CHECK-DAG: %[[SLICE1_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> 
!tosa.shape<4> +// CHECK-DAG: %[[SLICE1_SIZE:.*]] = tosa.const_shape {values = dense<[1, 55, 55, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[TRIMMED_HW:.*]] = tosa.slice %[[TRIMMED_H]], %[[SLICE1_START]], %[[SLICE1_SIZE]] : (tensor<1x55x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x55x64xf32> +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> +// CHECK: %[[CONV:.*]] = tosa.conv2d %[[TRIMMED_HW]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x28x28x128xf32> +// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor<1x28x28x128xf32>) -> tensor<1x128x28x28xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCHW]] : tensor<1x128x28x28xf32> -> !torch.vtensor<[1,128,28,28],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[1,128,28,28],f32> // CHECK: } func.func @torch.aten.convolution$zero_pad_with_sliced_input(%arg0: !torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,128,28,28],f32> { %false = torch.constant.bool false @@ -3826,26 +3824,26 @@ func.func @torch.aten.convolution$zero_pad_with_sliced_input(%arg0: !torch.vtens // ----- // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input( -// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,32,112,112],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,3,224,224],f32> -> tensor<1x3x224x224xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.bool false -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_5:.*]] = torch.constant.none -// CHECK: %[[VAL_6:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, 
tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32> -// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor<1x112x112x32xf32>) -> tensor<1x32x112x112xf32> -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x32x112x112xf32> -> !torch.vtensor<[1,32,112,112],f32> -// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,32,112,112],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,32,112,112],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,3,224,224],f32> -> tensor<1x3x224x224xf32> +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[ONE:.*]] = torch.constant.int 1 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[STRIDE:.*]] = torch.constant.int 2 +// CHECK: %[[STRIDE_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE]], %[[STRIDE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32> +// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> +// CHECK: %[[CONV:.*]] = tosa.conv2d %[[NHWC_INPUT]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32> +// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor<1x112x112x32xf32>) -> tensor<1x32x112x112xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCHW]] : tensor<1x32x112x112xf32> -> !torch.vtensor<[1,32,112,112],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[1,32,112,112],f32> // CHECK: } func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,32,112,112],f32> { %false = torch.constant.bool false @@ -3864,32 +3862,32 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_ // ----- // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input( -// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,3,225,225],f32>) -> !torch.vtensor<[1,32,75,75],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,3,225,225],f32> -> tensor<1x3x225x225xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.bool false -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_5:.*]] 
= torch.constant.none -// CHECK: %[[VAL_6:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32> -// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x225x3xf32> -// CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<1x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x224x3xf32> -// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_12]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32> -// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array} : (tensor<1x75x75x32xf32>) -> tensor<1x32x75x75xf32> -// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<1x32x75x75xf32> -> !torch.vtensor<[1,32,75,75],f32> -// CHECK: return %[[VAL_24]] : !torch.vtensor<[1,32,75,75],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,3,225,225],f32>) -> !torch.vtensor<[1,32,75,75],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,3,225,225],f32> -> tensor<1x3x225x225xf32> +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[ONE:.*]] = torch.constant.int 1 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[STRIDE:.*]] = torch.constant.int 3 +// CHECK: %[[STRIDE_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE]], %[[STRIDE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : 
(tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32> +// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> +// CHECK-DAG: %[[SLICE0_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[SLICE0_SIZE:.*]] = tosa.const_shape {values = dense<[1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[TRIMMED_H:.*]] = tosa.slice %[[NHWC_INPUT]], %[[SLICE0_START]], %[[SLICE0_SIZE]] : (tensor<1x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x225x3xf32> +// CHECK-DAG: %[[SLICE1_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[SLICE1_SIZE:.*]] = tosa.const_shape {values = dense<[1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[TRIMMED_HW:.*]] = tosa.slice %[[TRIMMED_H]], %[[SLICE1_START]], %[[SLICE1_SIZE]] : (tensor<1x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x224x3xf32> +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> +// CHECK: %[[CONV:.*]] = tosa.conv2d %[[TRIMMED_HW]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32> +// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor<1x75x75x32xf32>) -> tensor<1x32x75x75xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCHW]] : tensor<1x32x75x75xf32> -> !torch.vtensor<[1,32,75,75],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[1,32,75,75],f32> // CHECK: } func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input(%arg0: !torch.vtensor<[1,3,225,225],f32>) -> !torch.vtensor<[1,32,75,75],f32> { %false = torch.constant.bool false @@ -3909,26 +3907,27 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp // ----- // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch( -// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = torch.constant.bool false -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_5:.*]] = torch.constant.none -// CHECK: %[[VAL_6:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> 
-// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor) -> tensor -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor -// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor) -> tensor -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor -> !torch.vtensor<[?,32,112,112],f32> -// CHECK: return %[[VAL_18]] +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,3,224,224],f32> -> tensor +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[ONE:.*]] = torch.constant.int 1 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[STRIDE:.*]] = torch.constant.int 2 +// CHECK: %[[STRIDE_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE]], %[[STRIDE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor) -> tensor +// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> +// CHECK: %[[CONV:.*]] = tosa.conv2d %[[NHWC_INPUT]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor +// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor) -> tensor +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCHW]] : tensor -> !torch.vtensor<[?,32,112,112],f32> +// CHECK: return %[[RESULT]] +// CHECK: } func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> { %false = torch.constant.bool false @@ -3948,32 +3947,33 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_ // ----- // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch( -// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> 
!torch.vtensor<[?,32,75,75],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = torch.constant.bool false -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_5:.*]] = torch.constant.none -// CHECK: %[[VAL_6:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor) -> tensor -// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor -// CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor -// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_12]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor -// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array} : (tensor) -> tensor -// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor -> !torch.vtensor<[?,32,75,75],f32> -// CHECK: return %[[VAL_24]] +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,3,225,225],f32> -> tensor +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[ONE:.*]] = torch.constant.int 1 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[STRIDE:.*]] = torch.constant.int 3 +// CHECK: %[[STRIDE_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE]], %[[STRIDE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> 
!torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor) -> tensor +// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> +// CHECK-DAG: %[[SLICE0_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[SLICE0_SIZE:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[TRIMMED_H:.*]] = tosa.slice %[[NHWC_INPUT]], %[[SLICE0_START]], %[[SLICE0_SIZE]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor +// CHECK-DAG: %[[SLICE1_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[SLICE1_SIZE:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[TRIMMED_HW:.*]] = tosa.slice %[[TRIMMED_H]], %[[SLICE1_START]], %[[SLICE1_SIZE]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> +// CHECK: %[[CONV:.*]] = tosa.conv2d %[[TRIMMED_HW]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor +// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor) -> tensor +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCHW]] : tensor -> !torch.vtensor<[?,32,75,75],f32> +// CHECK: return %[[RESULT]] +// CHECK: } func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> { %false = torch.constant.bool false %int1 = torch.constant.int 1 diff --git a/test/Conversion/TorchToTosa/conv3d_transpose.mlir b/test/Conversion/TorchToTosa/conv3d_transpose.mlir new file mode 100644 index 000000000000..8b65e52087c0 --- /dev/null +++ b/test/Conversion/TorchToTosa/conv3d_transpose.mlir @@ -0,0 +1,68 @@ +// RUN: torch-mlir-opt %s -convert-torch-to-tosa -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.convolution$3d_basic( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,5,6,7],f32>) -> !torch.vtensor<[2,4,5,6,7],f32> { +// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,5,6,7],f32> -> tensor<2x3x5x6x7xf32> +// CHECK: %[[USE_BIAS:.*]] = torch.constant.bool false +// CHECK: %[[ONE:.*]] = torch.constant.int 1 +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense_resource : tensor<4x3x3x3x3xf32>}> : () -> tensor<4x3x3x3x3xf32> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]], %[[ONE]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]], %[[ONE]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]], %[[ONE]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PAD:.*]] = 
torch.prim.ListConstruct : () -> !torch.list
+// CHECK: %[[NDHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor<2x3x5x6x7xf32>) -> tensor<2x5x6x7x3xf32>
+// CHECK: %[[NDHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<4x3x3x3x3xf32>) -> tensor<4x3x3x3x3xf32>
+// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[BIAS_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
+// CHECK: %[[CONV:.*]] = tosa.conv3d %[[NDHWC_INPUT]], %[[NDHWC_WEIGHT]], %[[BIAS_CONST]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x5x6x7x3xf32>, tensor<4x3x3x3x3xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x6x7x4xf32>
+// CHECK: %[[RESULT_NCDHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor<2x5x6x7x4xf32>) -> tensor<2x4x5x6x7xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCDHW]] : tensor<2x4x5x6x7xf32> -> !torch.vtensor<[2,4,5,6,7],f32>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[2,4,5,6,7],f32>
+func.func @torch.aten.convolution$3d_basic(%arg0: !torch.vtensor<[2,3,5,6,7],f32>) -> !torch.vtensor<[2,4,5,6,7],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %weight = torch.vtensor.literal(dense_resource : tensor<4x3x3x3x3xf32>) : !torch.vtensor<[4,3,3,3,3],f32>
+  %none = torch.constant.none
+  %stride = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+  %padding = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+  %dilation = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+  %output_padding = torch.prim.ListConstruct : () -> !torch.list
+  %result = torch.aten.convolution %arg0, %weight, %none, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[2,3,5,6,7],f32>, !torch.vtensor<[4,3,3,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[2,4,5,6,7],f32>
+  return %result : !torch.vtensor<[2,4,5,6,7],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$3d_transpose(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,1,2,2,2],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> {
+// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,1,2,2,2],f32> -> tensor<1x1x2x2x2xf32>
+// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x3x3x3xf32>}> : () -> tensor<1x1x3x3x3xf32>
+// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[NDHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array} : (tensor<1x1x2x2x2xf32>) -> tensor<1x2x2x2x1xf32>
+// CHECK: %[[NDHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<1x1x3x3x3xf32>) -> tensor<1x3x3x3x1xf32>
+// CHECK: %[[REV_D:.*]] = tosa.reverse %[[NDHWC_WEIGHT]] {axis = 1 : i32}
+// CHECK: %[[REV_H:.*]] = tosa.reverse %[[REV_D]] {axis = 2 : i32}
+// CHECK: %[[REV_W:.*]] = tosa.reverse %[[REV_H]] {axis = 3 : i32}
+// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[CONV:.*]] = tosa.conv3d {{.*}}, %[[REV_W]], {{.*}}, {{.*}}, {{.*}} {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x6x6x6x1xf32>, tensor<1x3x3x3x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x4x1xf32>
+// CHECK: %[[RESULT_NCDHW:.*]] = tosa.transpose %[[CONV]] {perms = array} : (tensor<1x4x4x4x1xf32>) -> tensor<1x1x4x4x4xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_NCDHW]] : tensor<1x1x4x4x4xf32> -> !torch.vtensor<[1,1,4,4,4],f32>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32>
+// CHECK: }
+func.func @torch.aten.convolution$3d_transpose(%arg0: !torch.vtensor<[1,1,2,2,2],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> {
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %weight = torch.vtensor.literal(dense<1.000000e+00> : tensor<1x1x3x3x3xf32>) : !torch.vtensor<[1,1,3,3,3],f32>
+  %bias = torch.vtensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32>
+  %stride = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+  %padding = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+  %dilation = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+  %out_padding = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+  %groups = torch.constant.int 1
+  %result = torch.aten.convolution %arg0, %weight, %bias, %stride, %padding, %dilation, %true, %out_padding, %groups : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,3,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,4,4],f32>
+  return %result : !torch.vtensor<[1,1,4,4,4],f32>
+}