[CPU EP] Int4 support for QuantizeLinear, DequantizeLinear, and Transpose (#20362)

### Description
- 4-bit QuantizeLinear(21). **Blocked quantization is not yet supported (i.e., the new `block_size` attribute is not handled).** A scalar reference sketch of 4-bit quantization follows this list.
- 4-bit DequantizeLinear(21). **Blocked dequantization is not yet supported (i.e., the new `block_size` attribute is not handled).**
- 4-bit Transpose(21).
- Update quantization tool with int4 types.
- Disable QDQ fusions for 4-bit types. See:
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
- MLAS 4-bit quantization kernels for Intel, NEON, and PowerPC.
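
The MLAS kernels themselves are not shown in this diff; as a reference point, here is a minimal scalar sketch of 4-bit linear quantization with byte packing (all names are illustrative, and MLAS's vectorized implementations differ):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference: quantize floats to signed int4 (range [-8, 7]) and pack
// two values per byte, first element in the low nibble (the ONNX/ORT Int4x2
// packing convention).
std::vector<uint8_t> QuantizeToInt4(const std::vector<float>& x,
                                    float scale, int zero_point) {
  std::vector<uint8_t> packed((x.size() + 1) / 2, 0);
  for (size_t i = 0; i < x.size(); ++i) {
    // Round-to-nearest-even via the default FP rounding mode.
    int q = static_cast<int>(std::nearbyintf(x[i] / scale)) + zero_point;
    uint8_t nibble = static_cast<uint8_t>(std::clamp(q, -8, 7)) & 0x0F;
    if (i % 2 == 0) {
      packed[i / 2] |= nibble;                             // low nibble
    } else {
      packed[i / 2] |= static_cast<uint8_t>(nibble << 4);  // high nibble
    }
  }
  return packed;
}
```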

##### Notes
To calculate a tensor's storage size, we normally take the number of elements from the shape (i.e., `tensor_shape.Size()`) and multiply it by the size of a single element. This does not work directly for sub-byte element types like int4, because each element of a `Tensor<Int4x2>` packs **two** int4 values into a single byte. Instead, `Tensor::CalculateTensorStorageSize` should be called; it performs the correct calculation for any tensor element type.
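
As an illustration, here is a minimal standalone sketch of the rounded-up calculation (not the actual `Tensor::CalculateTensorStorageSize` implementation; parameter names are ours):

```cpp
#include <cstddef>

// num_elems:     logical element count, e.g. from tensor_shape.Size()
// unit_size:     bytes per storage unit (1 for Int4x2)
// num_sub_elems: packed sub-elements per storage unit (2 for Int4x2, 1 otherwise)
size_t CalcStorageSize(size_t num_elems, size_t unit_size, size_t num_sub_elems) {
  // Round up so an odd int4 element count still gets a whole final byte:
  // a [3, 5] Tensor<Int4x2> holds 15 int4 elements -> (15 + 1) / 2 = 8 bytes.
  size_t num_units = (num_elems + num_sub_elems - 1) / num_sub_elems;
  return num_units * unit_size;
}
```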

### Motivation and Context
ONNX 1.16 added the int4 and uint4 types. This initial PR adds the int4 and uint4 types to ORT and adds int4 implementations for the Quant, Dequant, and Transpose ops on the CPU EP. Int4 support still needs to be added for many other ops and execution providers. See the ONNX 1.16 release notes: https://github.com/onnx/onnx/releases.
adrianlizarraga authored May 31, 2024
1 parent a508130 commit b02d5e6
Showing 58 changed files with 2,472 additions and 281 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/lint.yml
@@ -93,7 +93,10 @@ jobs:
github_token: ${{ secrets.github_token }}
reporter: github-pr-check
level: warning
flags: --linelength=120 --exclude=java/src/main/native/*.c
flags: --linelength=120
--exclude=java/src/main/native/*.c
--exclude=onnxruntime/core/mlas/inc/*
--exclude=onnxruntime/core/mlas/lib/*
filter: "-runtime/references"

lint-js:
4 changes: 2 additions & 2 deletions docs/ContribOperators.md
@@ -1373,7 +1373,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)</dt>
<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32), tensor(int4), tensor(uint4)</dt>
<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.</dd>
<dt><tt>T2</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'y', 'x_scale' to float tensors.</dd>
@@ -4832,7 +4832,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>T1</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'x', 'y_scale' to float tensors.</dd>
<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)</dt>
<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int4), tensor(uint4)</dt>
<dd>Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.</dd>
</dl>

10 changes: 5 additions & 5 deletions docs/OperatorKernels.md
@@ -87,7 +87,7 @@ Do not modify directly.*
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)|
@@ -259,7 +259,7 @@ Do not modify directly.*
|||[7, 11]|**T** = tensor(double), tensor(float)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|||[19, 20]|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|||[10, 12]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
@@ -418,7 +418,7 @@ Do not modify directly.*
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float)|
|||[1, 9]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float)|
|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
@@ -468,7 +468,7 @@ Do not modify directly.*
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int32)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br/> **T2** = tensor(float)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br/> **T1** = tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float)|
@@ -504,7 +504,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
57 changes: 42 additions & 15 deletions include/onnxruntime/core/framework/data_types.h
@@ -15,6 +15,7 @@
#include "core/framework/endian.h"
#include "core/framework/float8.h"
#include "core/framework/float16.h"
#include "core/framework/int4.h"
#include "core/graph/onnx_protobuf.h"
#include "core/framework/to_tensor_proto_element_type.h"

@@ -280,7 +281,8 @@ struct IsAnyOf<T, H, Tail...> {
template <typename T>
struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
int32_t, int64_t, std::string, bool, MLFloat16,
double, uint32_t, uint64_t, BFloat16
double, uint32_t, uint64_t, BFloat16,
Int4x2, UInt4x2
#if !defined(DISABLE_FLOAT8_TYPES)
,
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
@@ -917,7 +919,8 @@ class OpaqueType : public NonTensorType<T> {
* Base class for primitive Tensor contained types
*
* \details This class contains an integer constant that can be
* used for input data type dispatching
* used for input data type dispatching. This class also stores the number of subelements per size unit.
* Example: For int4, the size unit is 1 byte and the number of subelements is 2.
*
*/
class PrimitiveDataTypeBase : public DataTypeImpl {
@@ -934,12 +937,21 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
return data_type_;
}

int32_t GetNumSubElems() const {
return num_sub_elems_;
}

bool HasSubElems() const {
return num_sub_elems_ > 1;
}

protected:
PrimitiveDataTypeBase(size_t size, int32_t data_type)
: DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems)
: DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type}, num_sub_elems_{num_sub_elems} {}

private:
const int32_t data_type_;
const int32_t num_sub_elems_; // > 1 for subbyte primitives, 1 for normal primitives.
};

/**
@@ -965,9 +977,9 @@ class PrimitiveDataType : public PrimitiveDataTypeBase {
}

private:
PrimitiveDataType()
explicit PrimitiveDataType(int32_t num_sub_elems)
: PrimitiveDataTypeBase{sizeof(T),
utils::ToTensorProtoElementType<T>()} {
utils::ToTensorProtoElementType<T>(), num_sub_elems} {
}
};

@@ -1074,15 +1086,30 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
return SequenceTensorType<ELEM_TYPE>::Type(); \
}

#define ORT_REGISTER_PRIM_TYPE(TYPE) \
template <> \
MLDataType PrimitiveDataType<TYPE>::Type() { \
static PrimitiveDataType<TYPE> prim_data_type; \
return &prim_data_type; \
} \
template <> \
MLDataType DataTypeImpl::GetType<TYPE>() { \
return PrimitiveDataType<TYPE>::Type(); \
#define ORT_REGISTER_PRIM_TYPE(TYPE) \
template <> \
MLDataType PrimitiveDataType<TYPE>::Type() { \
static PrimitiveDataType<TYPE> prim_data_type(1); \
return &prim_data_type; \
} \
template <> \
MLDataType DataTypeImpl::GetType<TYPE>() { \
return PrimitiveDataType<TYPE>::Type(); \
}

// Registers a subbyte primitive.
// Examples:
// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \
template <> \
MLDataType PrimitiveDataType<TYPE>::Type() { \
static PrimitiveDataType<TYPE> prim_data_type(NUM_SUB_ELEMS); \
return &prim_data_type; \
} \
template <> \
MLDataType DataTypeImpl::GetType<TYPE>() { \
return PrimitiveDataType<TYPE>::Type(); \
}

#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \
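
For orientation, the comment above implies registrations of the following shape for the packed 4-bit types (a sketch; the actual invocations live in the data-types .cc file, which is not part of this diff excerpt):

```cpp
// Hypothetical invocations, per the macro's own comment: Int4x2/UInt4x2 pack
// 2 sub-elements into each 1-byte storage unit.
ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2)
ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2)
```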
24 changes: 24 additions & 0 deletions include/onnxruntime/core/framework/data_types_internal.h
@@ -93,6 +93,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
function<Int4x2>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
function<UInt4x2>(__VA_ARGS__); \
break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -153,6 +159,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
retval = function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
retval = function<Int4x2>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
retval = function<UInt4x2>(__VA_ARGS__); \
break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -203,6 +215,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
function<BFloat16>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
function<Int4x2>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
function<UInt4x2>(__VA_ARGS__); \
break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -251,6 +269,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
retval = function<BFloat16>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
retval = function<Int4x2>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
retval = function<UInt4x2>(__VA_ARGS__); \
break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
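
To make the pattern concrete, here is a hedged sketch of the kind of switch these dispatch macros generate (the macro names are in the collapsed context and are not shown; this mirrors only the visible `case` structure, assuming the usual ORT headers):

```cpp
#include <cstddef>
#include <cstdint>

#include "core/common/common.h"        // ORT_ENFORCE
#include "core/framework/int4.h"       // Int4x2, UInt4x2
#include "core/graph/onnx_protobuf.h"  // TensorProto_DataType_*

namespace onnxruntime {

// Instantiated per element type; sizeof(Int4x2) is 1 byte, holding two
// packed int4 sub-elements.
template <typename T>
size_t StorageUnitSize() {
  return sizeof(T);
}

size_t DispatchStorageUnitSize(int32_t tensor_type) {
  switch (tensor_type) {
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
      return StorageUnitSize<int8_t>();
    case ONNX_NAMESPACE::TensorProto_DataType_INT4:
      return StorageUnitSize<Int4x2>();
    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
      return StorageUnitSize<UInt4x2>();
    default:
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);
  }
}

}  // namespace onnxruntime
```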
(The remaining changed files are not shown.)
