
Commit bd11877

[onnx] Support lowering quantize linear to torch (#2751)
We can map the per-tensor case to the `torch.aten.quantize_per_tensor` operation. In this case we extract the `scale` and `zeropoint` values, invoke the quantization directly, and then return the integer representation of the result.
1 parent 77a03f2 commit bd11877
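
In other words, a per-tensor `onnx.QuantizeLinear` such as

  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>

is rewritten, roughly (see the new tests below for the exact FileCheck patterns), into

  %int12 = torch.constant.int 12
  %scale = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
  %zp = torch.aten.item %arg2 : !torch.vtensor<[],si8> -> !torch.int
  %quant = torch.aten.quantize_per_tensor %arg0, %scale, %zp, %int12
  %repr = torch.aten.int_repr %quant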

File tree

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

2 files changed: +95 -0 lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 51 additions & 0 deletions
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -41,6 +42,56 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
 
 void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
     OnnxCustomOpConversionPattern &patterns) {
+  patterns.onOp("QuantizeLinear", 1,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  llvm::SmallVector<Value> operands;
+                  if (binder.tensorOperands(operands, 3) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+
+                  Value operand = operands[0];
+                  Value scale = operands[1];
+                  Value zeropoint = operands[2];
+
+                  auto scaleTy =
+                      scale.getType().dyn_cast<Torch::ValueTensorType>();
+                  if (!scaleTy || !scaleTy.hasSizes())
+                    return rewriter.notifyMatchFailure(binder.op,
+                                                       "requires known rank");
+                  if (!resultType.hasDtype())
+                    return rewriter.notifyMatchFailure(
+                        binder.op, "requires known result dtype");
+
+                  if (scaleTy.getSizes().size() == 0) {
+                    Type qTy = resultType.getDtype();
+
+                    if (qTy.isUnsignedInteger(8)) {
+                      qTy = rewriter.getType<Torch::QUInt8Type>();
+                    } else if (qTy.isSignedInteger(8)) {
+                      qTy = rewriter.getType<Torch::QInt8Type>();
+                    } else if (qTy.isSignedInteger(32)) {
+                      qTy = rewriter.getType<Torch::QInt32Type>();
+                    } else {
+                      return rewriter.notifyMatchFailure(
+                          binder.op, "unsupported result dtype");
+                    }
+
+                    auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
+                        resultType.getOptionalSizes(), qTy);
+                    auto torchqTy = Torch::getScalarTypeForType(qTy);
+
+                    Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+                        binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                        rewriter.getIntegerAttr(
+                            rewriter.getIntegerType(64),
+                            static_cast<int64_t>(torchqTy)));
+
+                    scale = rewriter.create<Torch::AtenItemOp>(
+                        binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                        scale);
+                    zeropoint = rewriter.create<Torch::AtenItemOp>(
+                        binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                        zeropoint);
+
+                    auto quantize =
+                        rewriter.create<Torch::AtenQuantizePerTensorOp>(
+                            binder.getLoc(), qTensorTy, operand, scale,
+                            zeropoint, tyConst);
+                    rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(
+                        binder.op, resultType, quantize);
+                    return success();
+                  }
+
+                  return failure();
+                });
   patterns.onOp("Reciprocal", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
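
A note on the dtype constant fed to `torch.aten.quantize_per_tensor`: it is the torch ScalarType code for the quantized element type chosen above, so the tests below expect 12 for si8 (QInt8), 13 for ui8 (QUInt8), and 14 for si32 (QInt32). The value between the quantize and the `int_repr` carries the quantized tensor type, which the FileCheck lines elide; as a sketch for the si8 case (type spellings assumed, not taken from the tests):

  %quant = torch.aten.quantize_per_tensor %arg0, %scale, %zp, %int12 : !torch.vtensor<[6],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[6],!torch.qint8>
  %repr = torch.aten.int_repr %quant : !torch.vtensor<[6],!torch.qint8> -> !torch.vtensor<[6],si8>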

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 44 additions & 0 deletions
@@ -4,6 +4,50 @@
 // level constants. This is a pragmatic choice which lets us have a lot
 // of tests in this file, whereas the others tend to be more bespoke.
 
+
+// CHECK-LABEL: @test_quantizelinear_si8
+func.func @test_quantizelinear_si8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>
+
+  // CHECK: %[[C12:.+]] = torch.constant.int 12
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si8> -> !torch.int
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C12]]
+  // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]]
+  // CHECK: return %[[REPR]]
+  return %0 : !torch.vtensor<[6],si8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_quantizelinear_ui8
+func.func @test_quantizelinear_ui8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],ui8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],ui8>
+  // CHECK: %[[C13:.+]] = torch.constant.int 13
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C13]]
+  // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]]
+  // CHECK: return %[[REPR]]
+  return %0 : !torch.vtensor<[6],ui8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_quantizelinear_i32
+func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[6],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[6],si32>
+  // CHECK: %[[C14:.+]] = torch.constant.int 14
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C14]]
+  // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]]
+  // CHECK: return %[[REPR]]
+  return %0 : !torch.vtensor<[6],si32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_reciprocal
 func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
