[torch][quant] Quantized torch.mm for linalg with end-to-end test (#2750)

This includes custom op matching for decomposed operations and fusing
dequantization into dense operations. As validation, we compare the results
against the dequant+mm torch implementation.
rsuderman committed Jan 24, 2024
1 parent 60bf6c2 commit f6f8905
Showing 13 changed files with 577 additions and 8 deletions.
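
The validation strategy mentioned in the commit message can be sketched at the torch level as follows. This is a minimal illustration, not the e2e test added by this commit; the function name, shapes, scale, and zero point are made-up assumptions for per-tensor qint8 quantization.

import torch

def dequant_mm_reference(lhs, rhs, scale, zero_point):
    # Reference path: quantize, explicitly dequantize, then run the dense matmul.
    q_lhs = torch.quantize_per_tensor(lhs, scale, zero_point, torch.qint8)
    q_rhs = torch.quantize_per_tensor(rhs, scale, zero_point, torch.qint8)
    return torch.mm(torch.dequantize(q_lhs), torch.dequantize(q_rhs))

# The compiled quantized path is expected to stay numerically close to this reference.
lhs, rhs = torch.rand(4, 8), torch.rand(8, 3)
expected = dequant_mm_reference(lhs, rhs, scale=0.1, zero_point=0)
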
2 changes: 1 addition & 1 deletion externals/llvm-project
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -106,6 +106,10 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);

std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();

std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createMatchQuantizedCustomOpsPass();

std::unique_ptr<OperationPass<ModuleOp>>
createReifyShapeCalculationsPass(StringRef extraLibrary);

28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -258,6 +258,34 @@ def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
}];
}

def FuseQuantizedOps : Pass<"torch-fuse-quantized-ops", "func::FuncOp"> {
let summary = "QDQ: Fuse recognized QDQ op sequences.";
let constructor = "mlir::torch::Torch::createFuseQuantizedOpsPass()";
let description = [{
Torch models often represent quantized operations as the sequence:
Dequantize
DenseOp
Quantize
This allows the existing dense operations to be used without specifically
representing quantized types. It is more computationally efficient to
perform the dense operation in the quantized domain, so we fuse the
quantization / dequantization behavior together and represent the
computation with purely quantized operations.
}];
}
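
To make the QDQ sequence above concrete, here is a hedged torch-level sketch of the pattern the pass targets, using torch.mm as the DenseOp; the names and dtypes are illustrative assumptions, not taken from this commit.

import torch

def qdq_mm(q_lhs, q_rhs, out_scale, out_zero_point):
    lhs = torch.dequantize(q_lhs)             # Dequantize
    rhs = torch.dequantize(q_rhs)             # Dequantize
    out = torch.mm(lhs, rhs)                  # DenseOp computed in float
    return torch.quantize_per_tensor(         # Quantize
        out, out_scale, out_zero_point, torch.qint8)

After fusion, the dequantize/quantize pair is folded away and the matmul itself operates on the quantized operands.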

def MatchQuantizedCustomOps : Pass<"torch-match-quantized-custom-ops", "func::FuncOp"> {
let summary = "Match quantized operations that occur in a different namespace.";
let constructor = "mlir::torch::Torch::createMatchQuantizedCustomOpsPass()";
let description = [{
Torch quantization utilities generate custom op versions of known aten
quantization operations. We can match these specially named operations and
rewrite them to the corresponding aten quantized operations.

We handle this post-import to keep the import process simple.
}];
}
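
For context, a hedged sketch of the kind of renaming MatchQuantizedCustomOps performs; the actual op coverage lives in MatchQuantizedOps.cpp in this commit, and the specific op names below are assumptions based on PyTorch's decomposed quantization ops rather than quotes from the pass.

# Assumed, illustrative mapping only; not exhaustive and not verified against the pass.
CUSTOM_TO_ATEN = {
    "quantized_decomposed::quantize_per_tensor":   "aten::quantize_per_tensor",
    "quantized_decomposed::dequantize_per_tensor": "aten::dequantize",
}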

def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
let summary = "Reify shape calculations.";
let constructor = [{
45 changes: 43 additions & 2 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -29,6 +29,13 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

// If `value` is produced by aten._make_per_tensor_quantized_tensor, pull out
// its zero point; otherwise leave `zeropoint` unset.
static void getZeroPoint(Value value, Value &zeropoint) {
if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
zeropoint = make.getZeroPoint();
}
}

class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -64,11 +71,27 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
op.getSelf().getType().cast<ValueTensorType>();
ValueTensorType rhsTorchType =
op.getMat2().getType().cast<ValueTensorType>();

Value lhsZeroPoint, rhsZeroPoint;
getZeroPoint(op.getSelf(), lhsZeroPoint);
getZeroPoint(op.getMat2(), rhsZeroPoint);

if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
return rewriter.notifyMatchFailure(
op, "unsupported: aten.mm with mixed quantization");
}

if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
return rewriter.notifyMatchFailure(
op, "unsupported: aten.mm with different input element types");
}

bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
if (lhsZeroPoint && isUnsigned) {
return rewriter.notifyMatchFailure(
op, "unsupported: unsigned quantized matmul not supported");
}

Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);

@@ -89,8 +112,26 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);

Value matmul;
auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
if (intType && intType.isUnsigned()) {
if (lhsZeroPoint && !isUnsigned) {
lhsZeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc,
getTypeConverter()->convertType(lhsZeroPoint.getType()),
lhsZeroPoint);
rhsZeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc,
getTypeConverter()->convertType(rhsZeroPoint.getType()),
rhsZeroPoint);
lhsZeroPoint = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), lhsZeroPoint);
rhsZeroPoint = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), rhsZeroPoint);
matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
loc, zeroFill.getType(),
ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill)
.getResult(0);
} else if (isUnsigned) {
matmul = rewriter
.create<linalg::MatmulUnsignedOp>(
loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
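
For reference on the new path in ConvertAtenMmOp: linalg::QuantizedMatmulOp subtracts the operand zero points before multiplying and accumulates in i32, which is why the lowering above truncates the zero points to i32. A rough Python model of that accumulation, as a sketch under the per-tensor assumption rather than the linalg implementation:

import torch

def quantized_matmul_model(lhs_i8, rhs_i8, lhs_zp, rhs_zp):
    # C[m, n] = sum_k (lhs[m, k] - lhs_zp) * (rhs[k, n] - rhs_zp), accumulated in i32.
    lhs = lhs_i8.to(torch.int32) - lhs_zp
    rhs = rhs_i8.to(torch.int32) - rhs_zp
    return (lhs.unsqueeze(-1) * rhs.unsqueeze(0)).sum(dim=1)
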
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -3,10 +3,12 @@ add_mlir_library(TorchMLIRTorchPasses
DecomposeComplexOps.cpp
DropAbstractInterpCalculations.cpp
EraseModuleInitializer.cpp
FuseQuantizedOps.cpp
Passes.cpp
GlobalizeObjectGraph.cpp
InlineGlobalSlots.cpp
LowerToBackendContract.cpp
MatchQuantizedOps.cpp
MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp
RecomposeComplexOps.cpp
214 changes: 214 additions & 0 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -0,0 +1,214 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

template <typename SrcOp>
class QuantizeOperands : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());

bool dequanted = false;
for (auto &operand : operands) {
if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
}

if (!dequanted) {
return rewriter.notifyMatchFailure(op, "no dequantizations found");
}

rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
};

template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());
if (operands.size() < 3)
return failure();

Value bias = operands[2];
if (bias.getDefiningOp<AtenDequantizeTensorOp>())
return failure();

Value lhsScale;
if (auto qLhs =
operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
lhsScale = qLhs.getScale();

Value rhsScale;
if (auto qRhs =
operands[1].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
rhsScale = qRhs.getScale();

if (!rhsScale || !lhsScale)
return failure();

auto biasTy = bias.getType().cast<ValueTensorType>();
auto biasETy = biasTy.getOptionalDtype();
if (!biasETy || !isa<mlir::FloatType>(biasETy))
return failure();

Value biasScale = rewriter.create<AtenMulFloatOp>(
op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);

Value zero = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

auto qi32Ty = rewriter.getType<QInt32Type>();
auto newBiasTy =
rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
bias = rewriter.create<AtenQuantizePerTensorOp>(
op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);

operands[2] = bias;
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
};

template <typename SrcOp>
class QuantizeAccumulator : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.getOperand(0);
auto rhs = op.getOperand(1);

auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasDtype())
return failure();

Type resultETy = resultTy.getDtype();
if (!resultETy.isa<mlir::FloatType>())
return failure();

Value lhsScale;
if (auto defining =
lhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
lhsScale = defining.getScale();
}

Value rhsScale;
if (auto defining =
rhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
rhsScale = defining.getScale();
}

if (!lhsScale || !rhsScale)
return failure();

// Build the zero point and combined scale for the quantized i32 accumulator:
Value zero = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

auto qi32Ty = rewriter.getType<QInt32Type>();
Value biasScale = rewriter.create<AtenMulFloatOp>(
op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);

// Update the quantized type:
llvm::SmallVector<Value> operands(op.getOperands());

auto newResultTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
auto conv = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);

// Attach the quantization information to the resulting qint32 value:
auto intReprTy = rewriter.getType<ValueTensorType>(
resultTy.getOptionalSizes(),
rewriter.getIntegerType(32, IntegerType::Signed));
auto intRepr = rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, conv);

auto quantTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), quantTy, intRepr, biasScale, zero);
auto dequant =
rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
rewriter.replaceOp(op, dequant);

return success();
}
};

template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
auto result = op.getResult();
if (result.use_empty()) {
op.erase();
return success();
}
return failure();
}
};

class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns
.insert<RemoveUnused<AtenDequantizeSelfOp>,
RemoveUnused<AtenDequantizeTensorOp>,
RemoveUnused<AtenQuantizePerTensorOp>,
QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
QuantizeAccumulator<AtenConvolutionOp>,
QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
context);

GreedyRewriteConfig config;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createFuseQuantizedOpsPass() {
return std::make_unique<FuseQuantizedOpsPass>();
}