[onnx] Lower onnx.QLinearMatMul to torch operators (#2776)
We can plumb the linear matmul into PyTorch using its quantized tensor
types, carrying the scale and zero point as side-channel information. To
produce the final int8 result we dequantize and requantize; a scalar
sketch of that arithmetic follows below.
rsuderman authored Jan 24, 2024
1 parent 894805d commit 60bf6c2
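
Numerically, the dequantize/requantize pair implements the standard
requantization identity: the si32 matmul accumulator carries the combined
scale aScale * bScale, which is rescaled to the output's cScale and cZp. A
minimal scalar sketch of that arithmetic, as a hypothetical standalone
helper rather than code from this patch:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical scalar model of the dequantize -> quantize_per_tensor pair
// emitted by this lowering. `acc` is one si32 matmul accumulator entry.
uint8_t requantize(int32_t acc, float aScale, float bScale, float cScale,
                   int32_t cZp) {
  // Dequantize: the accumulator's effective scale is aScale * bScale,
  // with zero point 0.
  float real = static_cast<float>(acc) * aScale * bScale;
  // Requantize to the output scale / zero point and saturate to ui8.
  int32_t q = static_cast<int32_t>(std::lround(real / cScale)) + cZp;
  return static_cast<uint8_t>(std::clamp(q, int32_t{0}, int32_t{255}));
}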
Showing 2 changed files with 195 additions and 4 deletions.
134 changes: 131 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -55,7 +55,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value zeropoint = operands[2];

auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
if (!scaleTy || !scaleTy.hasSizes())
return rewriter.notifyMatchFailure(binder.op,
"requires known rank");
if (!resultType.hasDtype())
return rewriter.notifyMatchFailure(
binder.op, "requires known result dtype");
@@ -89,9 +91,135 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
}

return failure();
});
patterns.onOp(
"QLinearMatMul", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
if (binder.tensorOperands(operands, 8) ||
binder.tensorResultType(resultType))
return failure();
Value a = operands[0];
Value aScale = operands[1];
Value aZp = operands[2];
Value b = operands[3];
Value bScale = operands[4];
Value bZp = operands[5];
Value cScale = operands[6];
Value cZp = operands[7];

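// Per-tensor quantization: every scale and zero-point operand must hold
// exactly one element (all dimensions equal 1).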
auto check = [](Value v) {
auto vTy = v.getType().cast<Torch::ValueTensorType>();
for (auto dim : vTy.getSizes())
if (dim != 1)
return false;
return true;
};
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
!check(cScale) || !check(cZp))
return rewriter.notifyMatchFailure(
binder.op, "not supported for non per-tensor quantization");

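// Reshape each rank-1 scale / zero-point tensor to rank 0 and extract it
// as a torch scalar: float for scales, int for zero points.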
Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
ValueRange{});
auto extract = [&rewriter, &binder, &emptyList](Value v) {
auto vTy = v.getType().cast<Torch::ValueTensorType>();
if (!vTy.getSizes().empty()) {
vTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), vTy.getOptionalDtype());
v = rewriter.create<Torch::AtenReshapeOp>(binder.getLoc(), vTy, v,
emptyList);
}

Type extractTy = rewriter.getType<Torch::FloatType>();
if (isa<IntegerType>(vTy.getDtype()))
extractTy = rewriter.getType<Torch::IntType>();

return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
v);
};

aZp = extract(aZp);
bZp = extract(bZp);
cZp = extract(cZp);
aScale = extract(aScale);
bScale = extract(bScale);
cScale = extract(cScale);

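// Map a builtin integer element type to the matching torch quantized
// dtype: ui8 -> quint8, si8 -> qint8, si32 -> qint32.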
auto getQTy =
[&rewriter](Torch::ValueTensorType ty) -> Torch::ValueTensorType {
auto dt = ty.getDtype();
Type newDt;
if (dt.isUnsignedInteger(8)) {
newDt = rewriter.getType<Torch::QUInt8Type>();
} else if (dt.isSignedInteger(8)) {
newDt = rewriter.getType<Torch::QInt8Type>();
} else if (dt.isSignedInteger(32)) {
newDt = rewriter.getType<Torch::QInt32Type>();
} else {
return nullptr;
}

return rewriter.getType<Torch::ValueTensorType>(ty.getOptionalSizes(),
newDt);
};

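// Attach per-tensor quantization parameters to a raw integer tensor.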
auto make = [&rewriter, &binder, &getQTy](Value v, Value scale,
Value zp) -> Value {
auto ty = v.getType().cast<Torch::ValueTensorType>();
auto newTy = getQTy(ty);
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), newTy, v, scale, zp);
};

a = make(a, aScale, aZp);
b = make(b, bScale, bZp);

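// Multiply the quantized operands; the accumulator element type is si32.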
auto cTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(),
rewriter.getIntegerType(32, /*isSigned=*/true));

Value c;
if (cTy.getSizes().size() == 2) {
c = rewriter.create<Torch::AtenMmOp>(binder.getLoc(), cTy, a, b);
} else {
c = rewriter.create<Torch::AtenBmmOp>(binder.getLoc(), cTy, a, b);
}

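// Reinterpret the raw si32 accumulator as qint32 with the combined scale
// aScale * bScale and zero point 0, then dequantize to f32.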
cTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(),
rewriter.getType<Torch::QInt32Type>());

Value mmScale = rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
bScale);
Value mmZp = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), cTy, c, mmScale, mmZp);
cTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), rewriter.getF32Type());

c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
c);
cTy = getQTy(resultType);
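// Encode the target quantized dtype as its torch_upstream::ScalarType
// integer value for aten.quantize_per_tensor.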
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64),
static_cast<int64_t>(
Torch::getScalarTypeForType(cTy.getDtype()))));
c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
c);
return success();
});
patterns.onOp("Reciprocal", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
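For reference, the torch.constant.int 13 checked in the tests below is the
quantized dtype constant computed via Torch::getScalarTypeForType; it
mirrors PyTorch's c10::ScalarType numbering. An assumed excerpt of that
numbering, not code from this patch:

enum class ScalarType : int64_t {
  QInt8 = 12,
  QUInt8 = 13, // the %[[DTY]] constant in the QLinearMatMul tests
  QInt32 = 14,
};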
65 changes: 64 additions & 1 deletion test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -4,7 +4,6 @@
// level constants. This is a pragmatic choice which lets us have a lot
// of tests in this file, whereas the others tend to be more bespoke.


// CHECK-LABEL: @test_quantizelinear_si8
func.func @test_quantizelinear_si8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>
@@ -48,6 +47,70 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch

// -----

// CHECK-LABEL: @test_qlinearmatmul_2D
func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8>
// CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
// CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
// CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
// CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int
// CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int
// CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int
// CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4],!torch.quint8>
// CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
// CHECK: %[[MM:.+]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,4],!torch.quint8>, !torch.vtensor<[4,3],!torch.quint8> -> !torch.vtensor<[2,3],si32>
// CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
// CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,3],!torch.qint32>
// CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,3],!torch.qint32> -> !torch.vtensor<[2,3],f32>
// CHECK: %[[DTY:.+]] = torch.constant.int 13
// CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,3],!torch.quint8>
// CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,3],!torch.quint8> -> !torch.vtensor<[2,3],ui8>
// CHECK: return %[[OUT]]
return %0 : !torch.vtensor<[2,3],ui8>
}

// -----

// CHECK-LABEL: @test_qlinearmatmul_3D
func.func @test_qlinearmatmul_3D(%arg0: !torch.vtensor<[2,2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[2,4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[2,4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8>
// CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
// CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
// CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
// CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int
// CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int
// CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int
// CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,2,4],!torch.quint8>
// CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[2,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4,3],!torch.quint8>
// CHECK: %[[MM:.+]] = torch.aten.bmm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,2,4],!torch.quint8>, !torch.vtensor<[2,4,3],!torch.quint8> -> !torch.vtensor<[2,2,3],si32>
// CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
// CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,2,3],!torch.qint32>
// CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,2,3],!torch.qint32> -> !torch.vtensor<[2,2,3],f32>
// CHECK: %[[DTY:.+]] = torch.constant.int 13
// CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,2,3],!torch.quint8>
// CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,2,3],!torch.quint8> -> !torch.vtensor<[2,2,3],ui8>
// CHECK: return %[[OUT]]
return %0 : !torch.vtensor<[2,2,3],ui8>
}

// -----

// CHECK-LABEL: func.func @test_reciprocal
func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
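These FileCheck cases run under the in-tree lit suite; locally, piping the
test file through torch-mlir-opt with the -convert-torch-onnx-to-torch pass
into FileCheck should reproduce them. The exact RUN line sits at the top of
the file, outside this hunk, so the flag spelling here is an assumption.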
