diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 12b772ceff282..34911cfc7972e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -93,7 +93,10 @@ jobs: github_token: ${{ secrets.github_token }} reporter: github-pr-check level: warning - flags: --linelength=120 --exclude=java/src/main/native/*.c + flags: --linelength=120 + --exclude=java/src/main/native/*.c + --exclude=onnxruntime/core/mlas/inc/* + --exclude=onnxruntime/core/mlas/lib/* filter: "-runtime/references" lint-js: diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index bf38dd56247d9..4d7493bd69650 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1373,7 +1373,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)
+
T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32), tensor(int4), tensor(uint4)
Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.
T2 : tensor(float16), tensor(float)
Constrain 'y', 'x_scale' to float tensors.
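
For reference, a minimal sketch of what the added tensor(int4)/tensor(uint4) inputs mean for this DequantizeLinear definition, built on the Int4x2 helper this PR introduces in include/onnxruntime/core/framework/int4.h; the DequantizeInt4 wrapper itself is hypothetical, not part of the change.

    // y = (x - x_zero_point) * x_scale, with two 4-bit x values packed per byte.
    #include <cstddef>
    #include <cstdint>
    #include <vector>
    #include "core/common/gsl.h"
    #include "core/framework/int4.h"

    std::vector<float> DequantizeInt4(gsl::span<const onnxruntime::Int4x2> packed,
                                      size_t num_elems, float scale, int8_t zero_point) {
      std::vector<float> y(num_elems);
      for (size_t i = 0; i < num_elems; ++i) {
        // Unpack the low (i % 2 == 0) or high (i % 2 == 1) nibble, sign-extended to int8_t.
        int8_t x = packed[i >> 1].GetElem(i & 0x1);
        y[i] = static_cast<float>(x - zero_point) * scale;
      }
      return y;
    }
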
@@ -4832,7 +4832,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 T1 : tensor(float16), tensor(float)
 Constrain 'x', 'y_scale' to float tensors.
-T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)
+T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int4), tensor(uint4)
 Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.
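
For reference, a sketch of the saturating 4-bit quantization these new output types imply, mirroring the scalar path of the MLAS int4 kernels added further down in this PR; QuantizeToInt4 is a hypothetical wrapper, while Int4x2, SetElem, CalcNumInt4Pairs, min_val, and max_val come from the new int4.h header.

    // y = clamp(round(x / y_scale) + y_zero_point, -8, 7), packed two results per byte.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>
    #include "core/common/gsl.h"
    #include "core/framework/int4.h"

    std::vector<onnxruntime::Int4x2> QuantizeToInt4(gsl::span<const float> x,
                                                    float scale, int8_t zero_point) {
      using onnxruntime::Int4x2;
      std::vector<Int4x2> y(Int4x2::CalcNumInt4Pairs(x.size()));
      for (size_t i = 0; i < x.size(); ++i) {
        float v = std::nearbyintf(x[i] / scale) + static_cast<float>(zero_point);
        v = std::clamp(v, static_cast<float>(Int4x2::min_val), static_cast<float>(Int4x2::max_val));
        y[i >> 1].SetElem(i & 0x1, static_cast<int8_t>(v));  // write one 4-bit lane
      }
      return y;
    }
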
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 4533884a51773..8092c26da651a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -87,7 +87,7 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)| @@ -259,7 +259,7 @@ Do not modify directly.* |||[7, 11]|**T** = tensor(double), tensor(float)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| @@ -418,7 +418,7 @@ Do not modify directly.* |TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| |||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| -|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)| @@ -468,7 +468,7 @@ Do not modify directly.* |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| -|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float)| +|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float)| @@ -504,7 +504,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index f3942128077de..b197d88090432 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -15,6 +15,7 @@ #include "core/framework/endian.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" #include "core/graph/onnx_protobuf.h" #include "core/framework/to_tensor_proto_element_type.h" @@ -280,7 +281,8 @@ struct IsAnyOf { template struct IsTensorContainedType : public IsAnyOf { * Base class for primitive Tensor contained types * * \details This class contains an integer constant that can be - * used for input data type dispatching + * used for input data type dispatching. This class also stores the number of subelements per size units. + * Example: For int4, the size unit is 1 byte and the number of subelements is 2. * */ class PrimitiveDataTypeBase : public DataTypeImpl { @@ -934,12 +937,21 @@ class PrimitiveDataTypeBase : public DataTypeImpl { return data_type_; } + int32_t GetNumSubElems() const { + return num_sub_elems_; + } + + bool HasSubElems() const { + return num_sub_elems_ > 1; + } + protected: - PrimitiveDataTypeBase(size_t size, int32_t data_type) - : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {} + PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems) + : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type}, num_sub_elems_{num_sub_elems} {} private: const int32_t data_type_; + const int32_t num_sub_elems_; // > 1 for subbyte primitives, 1 for normal primitives. }; /** @@ -965,9 +977,9 @@ class PrimitiveDataType : public PrimitiveDataTypeBase { } private: - PrimitiveDataType() + explicit PrimitiveDataType(int32_t num_sub_elems) : PrimitiveDataTypeBase{sizeof(T), - utils::ToTensorProtoElementType()} { + utils::ToTensorProtoElementType(), num_sub_elems} { } }; @@ -1074,15 +1086,30 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const { return SequenceTensorType::Type(); \ } -#define ORT_REGISTER_PRIM_TYPE(TYPE) \ - template <> \ - MLDataType PrimitiveDataType::Type() { \ - static PrimitiveDataType prim_data_type; \ - return &prim_data_type; \ - } \ - template <> \ - MLDataType DataTypeImpl::GetType() { \ - return PrimitiveDataType::Type(); \ +#define ORT_REGISTER_PRIM_TYPE(TYPE) \ + template <> \ + MLDataType PrimitiveDataType::Type() { \ + static PrimitiveDataType prim_data_type(1); \ + return &prim_data_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return PrimitiveDataType::Type(); \ + } + +// Registers a subbyte primitive. 
+// Examples: +// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2) +// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8) +#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \ + template <> \ + MLDataType PrimitiveDataType::Type() { \ + static PrimitiveDataType prim_data_type(NUM_SUB_ELEMS); \ + return &prim_data_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return PrimitiveDataType::Type(); \ } #define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \ diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index 3a3b5cb6888f2..05f4c10995ef2 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -93,6 +93,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \ function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -153,6 +159,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -203,6 +215,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -251,6 +269,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h new file mode 100644 index 0000000000000..228c1e4e872de --- /dev/null +++ b/include/onnxruntime/core/framework/int4.h @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/common/common.h" +#include "core/common/gsl.h" + +namespace onnxruntime { + +template +struct Int4Traits; + +template <> +struct Int4Traits { + using UnpackedType = int8_t; + static constexpr int8_t min_val = -8; + static constexpr int8_t max_val = 7; +}; + +template <> +struct Int4Traits { + using UnpackedType = uint8_t; + static constexpr uint8_t min_val = 0; + static constexpr uint8_t max_val = 15; +}; + +/// +/// Stores 2 packed 4-bit elements in 1 byte. +/// +/// Set to true if signed int4, or false if unsigned uint4. 
+template +struct Int4x2Base { + using UnpackedType = typename Int4Traits::UnpackedType; + static constexpr UnpackedType min_val = Int4Traits::min_val; + static constexpr UnpackedType max_val = Int4Traits::max_val; + + std::byte bits_{}; + + Int4x2Base() = default; + + explicit Int4x2Base(std::byte bits) { + bits_ = bits; + } + + Int4x2Base(UnpackedType val0, UnpackedType val1) { + bits_ = static_cast(((val1 & 0xF) << 4) | (val0 & 0xF)); + } + + static inline int8_t SignExtendLower4Bits(std::byte bits) { + // Sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift. + constexpr uint8_t shift = (sizeof(int32_t) * 8) - 4; + return static_cast((static_cast(bits) << shift) >> shift); + } + + inline UnpackedType GetElem(size_t index) const { + assert(index <= 1); + const uint8_t shift = 4 * static_cast(index); + const std::byte val = (bits_ >> shift) & std::byte{0xF}; + + if constexpr (Signed) { + return SignExtendLower4Bits(val); + } else { + return static_cast(val); + } + } + + inline void SetElem(size_t index, UnpackedType val) { + assert(index <= 1); + const uint8_t shift = 4 * static_cast(index); + const std::byte mask = std::byte{0xF0} >> shift; + + bits_ &= mask; // Clear 4-bit element to 0 + bits_ |= static_cast((val & 0xF) << shift); // Set 4-bit element to val + } + + inline std::byte ToBits() const { + return bits_; + } + + static size_t CalcNumInt4Pairs(size_t num_int4_elems) { + return (num_int4_elems + 1) / 2; + } + + static bool Unpack(gsl::span dst, gsl::span> src) { + if (CalcNumInt4Pairs(dst.size()) != src.size()) { + return false; + } + + for (size_t i = 0; i < dst.size(); i++) { + size_t r = i >> 1; // i / 2; + size_t c = i & 0x1; // i % 2; + dst[i] = src[r].GetElem(c); + } + + return true; + } + + static bool Pack(gsl::span> dst, gsl::span src) { + if (src.empty() || (CalcNumInt4Pairs(src.size()) != dst.size())) { + return false; + } + + size_t src_i = 0; + size_t dst_i = 0; + + for (; src_i < src.size() - 1; src_i += 2) { + dst[dst_i++] = Int4x2Base(src[src_i], src[src_i + 1]); + } + + if (src_i < src.size()) { + dst[dst_i] = Int4x2Base(src[src_i], 0); + } + + return true; + } +}; + +using Int4x2 = Int4x2Base; +using UInt4x2 = Int4x2Base; +static_assert(sizeof(Int4x2) == sizeof(std::byte)); +static_assert(sizeof(UInt4x2) == sizeof(std::byte)); +} // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index a867ab6066485..96725aa103064 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -145,6 +145,17 @@ class Tensor final { /// Bytes required. static size_t CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape); + /// + /// Calculate the required storage for the tensor. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Power of 2 alignment to include in calculation. + /// Bumps up result to the nearest multiple of alignment. Set to 0 to ignore. + /// The resulting storage size. + /// Status indicating success or failure. + static Status CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment, + size_t& storage_size); /** Returns the data type. */ @@ -200,7 +211,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. 
", "T ", "!=", dtype_); T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast(shape_.Size())); + return gsl::make_span(data, static_cast(NumStorageElements())); } template @@ -217,7 +228,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast::size_type>(shape_.Size())); + return gsl::make_span(data, static_cast::size_type>(NumStorageElements())); } void* MutableDataRaw(MLDataType type) { @@ -271,6 +282,19 @@ class Tensor final { byte_offset_ = byte_offset; } + /// + /// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for + /// sub-byte data types (e.g., int4). + /// + /// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements. + /// Example: Tensor of shape (4,) has 2 storage elements. + /// + /// For element types >= 1 byte, this function returns the product of the shape. + /// Example: Tensor of shape (4,) has 4 storage elements. + /// + /// Number of tensor storage elements + int64_t NumStorageElements() const; + /** The number of bytes of data. */ diff --git a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h index 21253eb4a6e83..e9e28e4864a67 100644 --- a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h +++ b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h @@ -12,6 +12,7 @@ #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" namespace onnxruntime { namespace utils { @@ -97,6 +98,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT4; +} +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT4; +} } // namespace utils } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 67cc493f04ab9..16701f2e0d923 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -196,7 +196,10 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + // Int4 types were introduced in ONNX 1.16. 
See https://onnx.ai/onnx/technical/int4.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index b37ce2f72e721..8aa885cf1ebd6 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -64,10 +64,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid); @@ -200,15 +204,32 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 6c4aec417a033..72ab5a9e898c7 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -639,6 +639,8 @@ ORT_REGISTER_TENSOR_TYPE(Float8E4M3FNUZ); ORT_REGISTER_TENSOR_TYPE(Float8E5M2); ORT_REGISTER_TENSOR_TYPE(Float8E5M2FNUZ); #endif +ORT_REGISTER_TENSOR_TYPE(Int4x2); +ORT_REGISTER_TENSOR_TYPE(UInt4x2); #if !defined(DISABLE_SPARSE_TENSORS) ORT_REGISTER_SPARSE_TENSOR_TYPE(int32_t); @@ -700,6 +702,9 @@ 
ORT_REGISTER_SEQ_TENSOR_TYPE(Float8E5M2FNUZ); #endif +ORT_REGISTER_SEQ_TENSOR_TYPE(Int4x2); +ORT_REGISTER_SEQ_TENSOR_TYPE(UInt4x2); + #if !defined(DISABLE_ML_OPS) ORT_REGISTER_SEQ(VectorMapStringToFloat); ORT_REGISTER_SEQ(VectorMapInt64ToFloat); @@ -725,7 +730,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat); ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E4M3FN); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E4M3FNUZ); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2); \ - ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ); + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); #else @@ -743,7 +750,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat); ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, uint32_t); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, uint64_t); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, MLFloat16); \ - ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16); + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); #endif @@ -808,6 +817,8 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_TENSOR_PROTO(Float8E5M2, reg_fn); REGISTER_TENSOR_PROTO(Float8E5M2FNUZ, reg_fn); #endif + REGISTER_TENSOR_PROTO(Int4x2, reg_fn); + REGISTER_TENSOR_PROTO(UInt4x2, reg_fn); #if !defined(DISABLE_SPARSE_TENSORS) REGISTER_SPARSE_TENSOR_PROTO(int32_t, reg_fn); @@ -867,6 +878,9 @@ void RegisterAllProtos(const std::function& reg_fn) { #endif + REGISTER_SEQ_TENSOR_PROTO(Int4x2, reg_fn); + REGISTER_SEQ_TENSOR_PROTO(UInt4x2, reg_fn); + #if !defined(DISABLE_ML_OPS) REGISTER_ONNX_PROTO(VectorMapStringToFloat, reg_fn); REGISTER_ONNX_PROTO(VectorMapInt64ToFloat, reg_fn); @@ -894,7 +908,9 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E4M3FN, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E4M3FNUZ, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2, reg_fn); \ - REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn); + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); #else @@ -912,7 +928,9 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_OPTIONAL_PROTO(ORT_TYPE, uint32_t, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, uint64_t, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, MLFloat16, reg_fn); \ - REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn); + REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); #endif @@ -973,6 +991,10 @@ const char* DataTypeImpl::ToString(MLDataType type) { return "Float8E5M2"; case TensorProto_DataType_FLOAT8E5M2FNUZ: return "Float8E5M2FNUZ"; + case TensorProto_DataType_INT4: + return "Int4x2"; + case TensorProto_DataType_UINT4: + return "UInt4x2"; default: break; } @@ -1041,6 +1063,10 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) { return DataTypeImpl::GetTensorType()->AsTensorType(); #endif + case TensorProto_DataType_INT4: + return DataTypeImpl::GetTensorType()->AsTensorType(); + case TensorProto_DataType_UINT4: + return DataTypeImpl::GetTensorType()->AsTensorType(); default: ORT_NOT_IMPLEMENTED("tensor type ", type, " is not supported"); @@ -1090,6 +1116,10 @@ const SequenceTensorTypeBase* DataTypeImpl::SequenceTensorTypeFromONNXEnum(int t return 
DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); #endif + case TensorProto_DataType_INT4: + return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); + case TensorProto_DataType_UINT4: + return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); default: ORT_NOT_IMPLEMENTED("sequence tensor type ", type, " is not supported"); @@ -1183,6 +1213,8 @@ ORT_REGISTER_PRIM_TYPE(Float8E5M2); ORT_REGISTER_PRIM_TYPE(Float8E5M2FNUZ); #endif +ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2); +ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2); namespace { template diff --git a/onnxruntime/core/framework/element_type_lists.h b/onnxruntime/core/framework/element_type_lists.h index 3d956ec26d22e..2478dc27162ac 100644 --- a/onnxruntime/core/framework/element_type_lists.h +++ b/onnxruntime/core/framework/element_type_lists.h @@ -11,6 +11,7 @@ #include "core/common/type_list.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" namespace onnxruntime { @@ -82,6 +83,12 @@ using AllIRv9 = using AllIRv9 = AllIRv4; #endif +using AllIRv10 = + boost::mp11::mp_push_back< + AllIRv9, + UInt4x2, + Int4x2>; + using All = AllIRv4; #if !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 32a5f749af084..894e0daae94b6 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -529,17 +529,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs"); } - size_t size; - int64_t len = shape.Size(); - if (len < 0) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape cannot contain any negative value"); - } - if (static_cast(len) > std::numeric_limits::max()) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large"); - } - if (!IAllocator::CalcMemSizeForArrayWithAlignment(static_cast(len), element_type->Size(), &size)) { - return Status(ONNXRUNTIME, FAIL, "size overflow"); - } + size_t size = 0; + ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(element_type, shape, kAllocAlignment, size)); // Lazily get the allocator only if needed. 
AllocatorPtr alloc = nullptr; diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 3963504273599..1370580bad4f6 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -78,6 +78,12 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { case TensorType::TensorProto_DataType_FLOAT8E5M2FNUZ: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; } // Non-IEEE floating-point format based on IEEE754 single-precision + case TensorType::TensorProto_DataType_INT4: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; + } // maps to a pair of int4 (size == 1 byte) + case TensorType::TensorProto_DataType_UINT4: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; + } // maps to a pair of uint4 (size == 1 byte) default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -126,4 +132,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, ORT_API(void, OrtApis::ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo* ptr) { std::unique_ptr p(ptr); -} \ No newline at end of file +} diff --git a/onnxruntime/core/framework/ort_value_tensor_slicer.cc b/onnxruntime/core/framework/ort_value_tensor_slicer.cc index cf4020d830bff..2acc1e301d6e0 100644 --- a/onnxruntime/core/framework/ort_value_tensor_slicer.cc +++ b/onnxruntime/core/framework/ort_value_tensor_slicer.cc @@ -14,7 +14,13 @@ OrtValueTensorSlicer OrtValueTensorSlicer::Create(T& ort_value, int64_t sl ORT_ENFORCE(ort_value.IsTensor(), "Can't slice a non-tensor OrtValue. Type was ", ort_value.Type()); ORT_ENFORCE(ort_value.IsAllocated(), "OrtValue has not been allocated so can't be sliced."); - auto& tensor_shape = ort_value.template Get().Shape(); + const Tensor& tensor = ort_value.template Get(); + auto* prim_type = tensor.DataType()->AsPrimitiveDataType(); + if (prim_type != nullptr) { + // TODO(adrianlizarraga): Support slicing Tensors of subbyte element types (e.g., int4). + ORT_ENFORCE(!prim_type->HasSubElems(), "Can't slice a tensor with a subbyte element type"); + } + auto& tensor_shape = tensor.Shape(); ORT_ENFORCE(gsl::narrow_cast(tensor_shape.NumDimensions()) >= slice_dimension, "Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 692ca08772535..059de8e3c8c4a 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -37,20 +37,10 @@ namespace session_state_utils { // It can handle arena-based allocators and non-arena based allocators. 
static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, const DataTypeImpl* type, const AllocatorPtr& alloc, /*out*/ void*& p_data) { - int64_t shape_size = tensor_shape.Size(); - if (shape_size < 0) - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "shape.Size() must >=0"); + size_t mem_size = 0; + ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(type, tensor_shape, /*alignment*/ 0, mem_size)); - p_data = nullptr; - if (shape_size > 0) { - SafeInt mem_size = 0; - - if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), type->Size(), &mem_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed memory size calculation"); - } - - p_data = alloc->Reserve(mem_size); - } + p_data = alloc->Reserve(mem_size); return Status::OK(); } diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index 36f03a9b1046a..60d768cc59a5d 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -27,20 +27,54 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st } // namespace #endif -size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { - int64_t shape_size = shape.Size(); - if (shape_size < 0) - ORT_THROW("shape.Size() must >=0"); +/// +/// Get the number of elements for a Tensor of the given element type and shape size. +/// +/// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements. +/// Example: Tensor of shape_size 4 has 2 storage elements. +/// +/// For element types >= 1 byte, this function returns the product of the shape. +/// Example: Tensor of shape_size 4 has 4 storage elements. +/// +/// Data type of the tensor elements. +/// The number of elements indicated by the shape (i.e., shape.Size()). +/// Number of Tensor elements. Returns -1 if shape_size is negative. 
+static int64_t GetNumTensorStorageElems(MLDataType elt_type, int64_t shape_size) { + int64_t num_elems = shape_size; + auto prim_type = elt_type->AsPrimitiveDataType(); + + if (prim_type != nullptr && num_elems > 0 && prim_type->HasSubElems()) { + const int64_t num_sub_elems = prim_type->GetNumSubElems(); + num_elems = (num_elems + (num_sub_elems - 1)) / num_sub_elems; + } - if (shape_size > 0) { - SafeInt len = 0; - if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), elt_type->Size(), &len)) - ORT_THROW("tensor failed memory size calculation"); + return num_elems; +} + +Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment, + /*out*/ size_t& storage_size) { + int64_t num_elems = GetNumTensorStorageElems(elt_type, shape.Size()); + ORT_RETURN_IF(num_elems < 0, "Tensor shape.Size() must be >= 0"); - return len; + if (num_elems > 0) { + if (static_cast(num_elems) > std::numeric_limits::max()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large"); + } + if (!IAllocator::CalcMemSizeForArrayWithAlignment(static_cast(num_elems), elt_type->Size(), alignment, + &storage_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Calculation for Tensor storage size overflowed"); + } + } else { + storage_size = 0; } - return 0; + return Status::OK(); +} + +size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { + size_t storage_size = 0; + ORT_THROW_IF_ERROR(CalculateTensorStorageSize(elt_type, shape, 0, storage_size)); + return storage_size; } Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, @@ -98,14 +132,19 @@ void Tensor::InitOrtValue(Tensor&& tensor, OrtValue& ort_value) { ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -size_t Tensor::SizeInBytes() const { +int64_t Tensor::NumStorageElements() const { #ifdef ENABLE_STRIDED_TENSORS int64_t size = IsContiguous() ? 
shape_.Size() : GetSizeFromStrides(shape_, strides_); #else int64_t size = shape_.Size(); #endif - size_t ret; - if (!IAllocator::CalcMemSizeForArray(SafeInt(size), dtype_->Size(), &ret)) { + + return GetNumTensorStorageElems(dtype_, size); +} + +size_t Tensor::SizeInBytes() const { + size_t ret = 0; + if (!IAllocator::CalcMemSizeForArray(SafeInt(NumStorageElements()), dtype_->Size(), &ret)) { ORT_THROW("tensor size overflow"); } return ret; @@ -138,6 +177,8 @@ void Tensor::Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_dat ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match tensor dimension size."); strides_.assign(strides.begin(), strides.end()); is_contiguous_ = CheckIsContiguous(); + ORT_ENFORCE(is_contiguous_ || !dtype_->HasSubElems(), + "Do not support subbyte element types with non-contiguous strided tensors."); } #else ORT_UNUSED_PARAMETER(strides); @@ -254,6 +295,8 @@ void Tensor::SetShapeAndStrides(const TensorShape& new_shape, gsl::spanHasSubElems(), + "Do not support subbyte element types with non-contiguous strided tensors."); } #endif diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 6e11bfe1ac8ea..418e46924fb9f 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -164,6 +164,12 @@ constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementData case o::TensorProto_DataType_BOOL: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; break; + case o::TensorProto_DataType_INT4: + type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; + break; + case o::TensorProto_DataType_UINT4: + type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; + break; default: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; break; @@ -365,4 +371,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, *out = ptr.release(); return nullptr; API_IMPL_END -} \ No newline at end of file +} diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 8a2db6d5728af..6af78f18fb82f 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -59,6 +59,20 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) { return t; \ } +#define TO_TENSOR_ORT_TYPE_INT4(TYPE) \ + template <> \ + TensorProto ToTensor(const onnxruntime::TYPE& value) { \ + return ToScalarTensor(ToTensorProtoElementType(), static_cast(value.ToBits())); \ + } \ + template <> \ + TensorProto ToTensor(const std::vector& values) { \ + TensorProto t = ToTensorInitialize(ToTensorProtoElementType()); \ + for (const onnxruntime::TYPE& val : values) { \ + t.add_int32_data(static_cast(val.ToBits())); \ + } \ + return t; \ + } + namespace ONNX_NAMESPACE { // Provide template specializations for onnxruntime-specific types. 
@@ -70,6 +84,8 @@ TO_TENSOR_ORT_TYPE(Float8E4M3FNUZ) TO_TENSOR_ORT_TYPE(Float8E5M2) TO_TENSOR_ORT_TYPE(Float8E5M2FNUZ) #endif +TO_TENSOR_ORT_TYPE_INT4(Int4x2) +TO_TENSOR_ORT_TYPE_INT4(UInt4x2) bool operator==(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l, const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) { @@ -125,6 +141,29 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t reinterpret_cast(p_data)); } +#define DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(INT4_TYPE) \ + template <> \ + Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, \ + /*out*/ INT4_TYPE* p_data) { \ + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ + \ + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ + \ + size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ + ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); \ + \ + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), \ + num_packed_pairs); \ + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); \ + \ + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ + \ + return Status::OK(); \ + } + +DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2) +DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) + static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const ORTCHAR_T* tensor_proto_dir, std::basic_string& external_file_path, @@ -261,6 +300,32 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, reinterpret_cast(p_data)); } +#define DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(INT4_TYPE) \ + template <> \ + Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ + const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, \ + /*out*/ INT4_TYPE* p_data) { \ + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ + \ + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ + std::vector unpacked_tensor; \ + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \ + \ + size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ + ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \ + \ + gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), \ + num_packed_pairs); \ + gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); \ + \ + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ + \ + return Status::OK(); \ + } + +DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2) +DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2) + #define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \ template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const ORTCHAR_T*, size_t, type*); @@ -602,6 +667,40 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d #endif +#define DEFINE_INT4_UNPACK_TENSOR_IMPL(INT4_TYPE, ONNX_INT4_TYPE) \ + template <> \ + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ + /*out*/ INT4_TYPE* p_data, size_t expected_num_elems) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); \ + return size == 0 ? 
Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + if (ONNX_NAMESPACE::ONNX_INT4_TYPE != tensor.data_type()) { \ + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + \ + size_t expected_int4_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elems); \ + \ + if (raw_data != nullptr) { \ + return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \ + } \ + \ + ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_int4_pairs, \ + "UnpackTensor: the pre-allocated size does not match the size in proto"); \ + \ + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ + p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ + } \ + \ + return Status::OK(); \ + } + +// UnpackTensor +DEFINE_INT4_UNPACK_TENSOR_IMPL(Int4x2, TensorProto_DataType_INT4) + +// UnpackTensor +DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4) + // UnpackTensor from raw data, external data or the type specific data field. // Uses the model path to construct the full path for loading external data. In case when model_path is empty // it uses current directory. @@ -651,6 +750,8 @@ INSTANTIATE_UNPACK_TENSOR(Float8E4M3FNUZ) INSTANTIATE_UNPACK_TENSOR(Float8E5M2) INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ) #endif +INSTANTIATE_UNPACK_TENSOR(Int4x2) +INSTANTIATE_UNPACK_TENSOR(UInt4x2) #define CASE_PROTO_TRACE(X, Y) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ @@ -659,6 +760,13 @@ INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ) } \ break; +#define CASE_PROTO_TRACE_INT4(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!IAllocator::CalcMemSizeForArrayWithAlignment(Y::CalcNumInt4Pairs(size), sizeof(Y), out)) { \ + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \ + } \ + break; + template common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) { const auto& dims = tensor_proto.dims(); @@ -692,6 +800,8 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& CASE_PROTO_TRACE(FLOAT8E5M2, Float8E5M2); CASE_PROTO_TRACE(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif + CASE_PROTO_TRACE_INT4(UINT4, UInt4x2); + CASE_PROTO_TRACE_INT4(INT4, Int4x2); default: return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } @@ -998,6 +1108,8 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, CASE_PROTO(FLOAT8E5M2, Float8E5M2); CASE_PROTO(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif + CASE_PROTO(INT4, Int4x2); + CASE_PROTO(UINT4, UInt4x2); case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_STRING: ORT_RETURN_IF_ERROR(UnpackTensor(tensor_proto, raw_data, raw_data_len, static_cast(preallocated), @@ -1053,6 +1165,8 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { CASE_TYPE(FLOAT8E5M2) CASE_TYPE(FLOAT8E5M2FNUZ) #endif + CASE_TYPE(UINT4) + CASE_TYPE(INT4) default: return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -1570,6 +1684,20 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T break; \ } +#define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ + TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ + size_t element_count = static_cast(tensor_shape.Size()); \ + size_t packed_element_count = ELEMENT_TYPE::CalcNumInt4Pairs(element_count); \ + 
unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ + return onnxruntime::utils::UnpackTensor( \ + initializer, \ + initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ + initializer.has_raw_data() ? initializer.raw_data().size() : 0, \ + reinterpret_cast(unpacked_tensor.data()), element_count); \ + break; \ + } + Status UnpackInitializerData(const onnx::TensorProto& initializer, const Path& model_path, std::vector& unpacked_tensor) { @@ -1604,6 +1732,8 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, CASE_UNPACK(FLOAT8E5M2, onnxruntime::Float8E5M2, int32_data_size); CASE_UNPACK(FLOAT8E5M2FNUZ, onnxruntime::Float8E5M2FNUZ, int32_data_size); #endif + CASE_UNPACK_INT4(INT4, Int4x2, int32_data_size); + CASE_UNPACK_INT4(UINT4, UInt4x2, int32_data_size); default: break; } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index f0b1b9109d405..fd14eeeb33d27 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -222,6 +222,16 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType #endif +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; +} + +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; +} + int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); #ifdef ENABLE_TRAINING diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 47f61a43458ed..762d892c45ce8 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -164,7 +164,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T2", OpSchema::Optional) .Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2") .TypeConstraint("T1", {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.") - .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)"}, + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)", "tensor(int4)", + "tensor(uint4)"}, "Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.") .SetDoc(QuantizeLinear_ver1_doc) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { @@ -206,7 +207,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(DequantizeLinear, 1, .Output(0, "y", "N-D full precision output tensor. 
It has same shape as input 'x'.", "T2") .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", - "tensor(uint16)", "tensor(int32)"}, + "tensor(uint16)", "tensor(int32)", "tensor(int4)", + "tensor(uint4)"}, "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, " "16-bit integer tensors, or 32-bit signed integer tensors.") .TypeConstraint("T2", {"tensor(float16)", "tensor(float)"}, diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ce7838556fbf0..cdfd283899c8c 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1222,6 +1222,26 @@ MlasQuantizeLinear( OutputType ZeroPoint ); +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 04da9ab4fd749..83200187963e1 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -680,6 +680,24 @@ void float Scale, int16_t ZeroPoint); +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_U4_KERNEL)( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_S4_KERNEL)( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -826,6 +844,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel; MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel; MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; @@ -1083,6 +1103,8 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; #endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; @@ -1112,6 +1134,8 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; @@ -2497,3 +2521,51 @@ MlasThreadedBufAlloc(size_t size) ThreadedBufSize = size; } } + +// +// Utilities for INT4 quantization. 
+// + +template +struct Int4Traits; + +template<> +struct Int4Traits { + using UnpackedType = int8_t; + static constexpr int8_t Min = -8; + static constexpr int8_t Max = 7; +}; + +template<> +struct Int4Traits { + using UnpackedType = uint8_t; + static constexpr int8_t Min = 0; + static constexpr int8_t Max = 15; +}; + +template +MLAS_FORCEINLINE +void +MlasSetInt4Element(uint8_t* Output, size_t ElemIndex, UnpackedType Value) +{ + static_assert(std::is_same_v || std::is_same_v); + + const size_t OutputIndex = ElemIndex >> 1; // which byte + const size_t NibbleIndex = ElemIndex & 0x1; // which 4-bit elem in the byte + const uint8_t Shift = static_cast(NibbleIndex << 2); // Either 0 or 4 + const uint8_t Mask = static_cast(0xF0 >> Shift); + uint8_t* Dst = &Output[OutputIndex]; + + *Dst &= Mask; // Clear 4-bit lane + *Dst |= static_cast((Value & 0xF) << Shift); // Set 4-bit lane +} + +template +MLAS_FORCEINLINE +void +MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueHigh) +{ + static_assert(std::is_same_v || std::is_same_v); + *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); +} + diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3f86b3f7c5062..72eb35c894094 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -274,6 +274,8 @@ Return Value: this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; this->NchwcBlockSize = 8; this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; @@ -545,6 +547,8 @@ Return Value: this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; #if defined(__linux__) unsigned long hwcap2 = getauxval(AT_HWCAP2); diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 1fed8af21b31c..0cfa56740edfb 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -107,6 +107,119 @@ Return Value: } } +template +void +MLASCALL +MlasQuantizeLinearInt4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer as int4 using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. Contains packed 4-bit elements. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. 
+ +--*/ +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + // Holds 16 quantized 8-bit values that will be packed into the output as packed 4-bit values. + UnpackedType TmpOutput[16] = {}; + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 = vec_div(FloatVector0, ScaleVector); + FloatVector1 = vec_div(FloatVector1, ScaleVector); + FloatVector2 = vec_div(FloatVector2, ScaleVector); + FloatVector3 = vec_div(FloatVector3, ScaleVector); + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = vec_add(FloatVector0, ZeroPointVector); + FloatVector1 = vec_add(FloatVector1, ZeroPointVector); + FloatVector2 = vec_add(FloatVector2, ZeroPointVector); + FloatVector3 = vec_add(FloatVector3, ZeroPointVector); + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, static_cast(&TmpOutput[0])); + + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); + MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]); + MlasPackInt4Elements(Output++, TmpOutput[6], TmpOutput[7]); + MlasPackInt4Elements(Output++, TmpOutput[8], TmpOutput[9]); + MlasPackInt4Elements(Output++, TmpOutput[10], TmpOutput[11]); + MlasPackInt4Elements(Output++, TmpOutput[12], TmpOutput[13]); + MlasPackInt4Elements(Output++, TmpOutput[14], TmpOutput[15]); + + Input += 16; + N -= 16; + } + + for (size_t n = 0; n < N; n++) { + + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + UnpackedType IntValue = static_cast(FloatValue); + + MlasSetInt4Element(Output, n, IntValue); + } +} + void MLASCALL MlasQuantizeLinearU8Kernel( @@ -159,3 +272,29 @@ MlasQuantizeLinearS16Kernel( MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL 
+MlasQuantizeLinearS4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index ffecc2dbeff9e..ae638fafee18f 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -519,6 +519,87 @@ Return Value: } } +template +void +MLASCALL +MlasQuantizeLinearInt4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + auto ScaleVector = MlasBroadcastFloat32x4(Scale); + auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); + auto MaximumValueVector = MlasBroadcastFloat32x4(static_cast(MaximumValue - ZeroPoint)); + auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); + + // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. + UnpackedType TmpOutput[4] = {}; + + while (N >= 4) { + + auto FloatVector = MlasLoadFloat32x4(Input); + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); + MlasQuantizeLinearStore4PackedValues(IntegerVector, &TmpOutput[0]); + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); + + Input += 4; + N -= 4; + } + + for (size_t n = 0; n < N; n++) { + +#if defined(MLAS_NEON64_INTRINSICS) + auto FloatVector = vld1q_dup_f32(Input + n); +#elif defined(MLAS_LSX_INTRINSICS) + MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); +#else + auto FloatVector = _mm_load_ss(Input + n); +#endif + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); + MlasSetInt4Element(Output, n, TmpOutput[0]); + } +} + +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + void MLASCALL MlasQuantizeLinearS8Kernel( @@ -571,6 +652,42 @@ MlasQuantizeLinearS16Kernel( MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearS4Kernel( +#else + MlasQuantizeLinearS4Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearU4Kernel( +#else + MlasQuantizeLinearU4Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + template<> void MLASCALL @@ -707,6 +824,31 @@ MlasQuantizeLinear( GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, 
ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearS4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearU4Kernel(Input, Output, N, Scale, ZeroPoint); +} #endif // @@ -805,6 +947,58 @@ MlasQuantizeLinear( uint16_t ZeroPoint ); +template +void +MLASCALL +MlasQuantizeLinearInt4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + UnpackedType IntValue = static_cast(FloatValue); + + MlasSetInt4Element(Output, n, IntValue); + } +} + +// QuantizeLinear INT4 implementation using the C++ runtime. +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); +} + +// QuantizeLinear UINT4 implementation using the C++ runtime. +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); +} #endif #endif diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 6b4f62ae1343d..09705f61c82ce 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -20,6 +20,11 @@ constexpr bool Is16BitIntType(int32_t data_type) { (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16); } +constexpr bool Is4BitIntType(int32_t data_type) { + return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4) || + (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4); +} + // adjust for an optional input/output that has an entry but does not exist int NumActualValues(const Node& node, bool input) { const auto& defs = input ? 
node.InputDefs() : node.OutputDefs(); @@ -134,6 +139,10 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + const Node& dq_node = *dq_nodes.front(); const Node& q_node = *q_nodes.front(); @@ -167,6 +176,10 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { return graph_viewer.GetConstantInitializer(initializer_name, true); }; @@ -193,6 +206,10 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + return true; } @@ -218,6 +235,10 @@ bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input_1)) { + return false; + } + return true; } @@ -253,6 +274,10 @@ bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + return true; } @@ -275,6 +300,10 @@ bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& dq_node = *dq_nodes.front(); int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + // All Q outputs should have same data type and (optionally) equal quantization parameters as the input. for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) { const Node& q_node = *q_nodes[q_idx]; @@ -312,6 +341,10 @@ bool ConvNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_weight_ && Is4BitIntType(dt_weight)) { + return false; + } + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { if (!int8_allowed_ || dt_weight != dt_input) { return false; @@ -359,6 +392,11 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + // 4-bit int types must be explicitly allowed. 
+ if (!allow_4bit_ && (Is4BitIntType(dt_input) || Is4BitIntType(dt_weight))) { + return false; + } + // potential match for QLinearMatMul or MatMulIntegerToFloat bool qlinear = !q_nodes.empty(); @@ -407,6 +445,10 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && (Is4BitIntType(dt_A) || Is4BitIntType(dt_B))) { + return false; + } + if (dq_nodes.size() < 3) { // no bias return true; } @@ -445,6 +487,10 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input_1)) { + return false; + } + return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 5d550669e2e86..1a2a620acb480 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -5,6 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#include #include "core/framework/node_unit.h" #include "core/optimizer/selectors_actions/selector_action_transformer.h" @@ -47,7 +48,8 @@ class NodeGroupSelector { // Zero point and scale are constant scalars and must match class DropQDQNodeGroupSelector : public NodeGroupSelector { public: - explicit DropQDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -55,12 +57,14 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // Single DQ -> node. class DropDQNodeGroupSelector : public NodeGroupSelector { public: - explicit DropDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit DropDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -68,12 +72,14 @@ class DropDQNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // single input. default is to only support uint8. 
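// --- Sketch (not part of the patch) ----------------------------------------------------------
// The widened constructors in this header let callers opt in to 4-bit Q/DQ node units
// explicitly: the NodeGroupSelector classes default allow_4bit to true, while the BaseSelector
// wrappers further down default it to false. A hypothetical factory, assuming these classes
// live in namespace onnxruntime::QDQ:
#include <memory>

inline std::unique_ptr<onnxruntime::QDQ::NodeGroupSelector> MakeDropQDQSelectorWith4Bit() {
  return std::make_unique<onnxruntime::QDQ::DropQDQNodeGroupSelector>(/*allow_16bit*/ true,
                                                                      /*allow_4bit*/ true);
}
// -----------------------------------------------------------------------------------------------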
class UnaryNodeGroupSelector : public NodeGroupSelector { public: - explicit UnaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit UnaryNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -81,12 +87,14 @@ class UnaryNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // 2 DQ nodes providing input -> node -> Q class BinaryNodeGroupSelector : public NodeGroupSelector { public: - explicit BinaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit BinaryNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -94,12 +102,14 @@ class BinaryNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // Variadic DQ nodes -> node -> Q class VariadicNodeGroupSelector : public NodeGroupSelector { public: - explicit VariadicNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit VariadicNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -107,6 +117,7 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // DQ node -> Split -> multiple Q nodes with equal quantization types. @@ -114,8 +125,8 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { // equal and constant. class SplitNodeGroupSelector : public NodeGroupSelector { public: - explicit SplitNodeGroupSelector(bool req_equal_quant_params = false) - : req_equal_quant_params_(req_equal_quant_params) {} + explicit SplitNodeGroupSelector(bool req_equal_quant_params = false, bool allow_4bit = true) + : req_equal_quant_params_(req_equal_quant_params), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -125,14 +136,15 @@ class SplitNodeGroupSelector : public NodeGroupSelector { bool req_equal_quant_params_; // If true, only selects a node group if the input and output // quantization parameters are all equal/constant, which enables the // optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP. 
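  // If false, node groups whose quantized data type is Int4/UInt4 are not selected.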
+ bool allow_4bit_; }; // DQ nodes for X, W and optionally B -> node -> Q class ConvNodeGroupSelector : public NodeGroupSelector { public: // default to 'true' - ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true) - : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit) {} + ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true, bool allow_4bit_weight = true) + : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit), allow_4bit_weight_(allow_4bit_weight) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -141,12 +153,13 @@ class ConvNodeGroupSelector : public NodeGroupSelector { bool int8_allowed_; bool allow_16bit_; + bool allow_4bit_weight_; }; class WhereNodeGroupSelector : public NodeGroupSelector { public: - explicit WhereNodeGroupSelector(bool allow_16bit = true) - : allow_16bit_(allow_16bit) {} + explicit WhereNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -154,6 +167,7 @@ class WhereNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; class PadNodeGroupSelector : public NodeGroupSelector { @@ -172,10 +186,12 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { public: MatMulNodeGroupSelector(bool int8_allowed = true, bool matmulintegertofloat_allowed = false, - bool allow_16bit = true) + bool allow_16bit = true, + bool allow_4bit = true) : int8_allowed_(int8_allowed), matmulintegertofloat_allowed_(matmulintegertofloat_allowed), - allow_16bit_(allow_16bit) { + allow_16bit_(allow_16bit), + allow_4bit_(allow_4bit) { } private: @@ -185,13 +201,15 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { bool int8_allowed_; bool matmulintegertofloat_allowed_; bool allow_16bit_; + bool allow_4bit_; }; // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmNodeGroupSelector : public NodeGroupSelector { public: - explicit GemmNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit GemmNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -199,6 +217,7 @@ class GemmNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // Input: DQ nodes for input, scale, and B @@ -273,33 +292,35 @@ class BaseSelector : public NodeSelector { class DropQDQNodesSelector : public BaseSelector { public: - explicit DropQDQNodesSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} }; class DropDQNodesSelector : public BaseSelector { public: - explicit DropDQNodesSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit DropDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} }; class UnarySelector : public BaseSelector { public: - explicit UnarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit UnarySelector(gsl::span 
compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} }; class BinarySelector : public BaseSelector { public: - explicit BinarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit BinarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} }; // Variadic DQ nodes -> node -> Q class InputVariadicSelector : public BaseSelector { public: - explicit InputVariadicSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit InputVariadicSelector(bool allow_16bit = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; @@ -307,8 +328,8 @@ class InputVariadicSelector : public BaseSelector { // DQ -> Split -> variadic Q nodes class SplitSelector : public BaseSelector { public: - SplitSelector(bool req_equal_quant_params = false) - : BaseSelector(std::make_unique(req_equal_quant_params)) {} + SplitSelector(bool req_equal_quant_params = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(req_equal_quant_params, allow_4bit)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; @@ -316,32 +337,34 @@ class SplitSelector : public BaseSelector { // DQ nodes for X, W and optionally B -> node -> Q class ConvSelector : public BaseSelector { public: - ConvSelector(bool int8_allowed = false, bool allow_16bit = false) - : BaseSelector(std::make_unique(int8_allowed, allow_16bit)) {} + ConvSelector(bool int8_allowed = false, bool allow_16bit = false, bool allow_4bit_weight = false) + : BaseSelector(std::make_unique(int8_allowed, allow_16bit, allow_4bit_weight)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; class WhereSelector : public BaseSelector { public: - explicit WhereSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit WhereSelector(gsl::span compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} }; // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not class MatMulSelector : public BaseSelector { public: - MatMulSelector(bool int8_allowed, bool allow_16bit = false) + MatMulSelector(bool int8_allowed, bool allow_16bit = false, bool allow_4bit = false) : BaseSelector(std::make_unique(int8_allowed, /*matmulintegertofloat_allowed*/ true, - allow_16bit)) {} + allow_16bit, allow_4bit)) {} }; // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { public: - explicit GemmSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit GemmSelector(gsl::span compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc 
b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 8a270a05d7287..b8d5a7852a968 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -5,6 +5,7 @@ #include #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" +#include "core/framework/int4.h" #include "core/mlas/inc/mlas.h" #ifndef DISABLE_CONTRIB_OPS @@ -1070,6 +1071,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear); @@ -1080,6 +1083,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint16_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, QuantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, QuantizeLinear); @@ -2655,6 +2660,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { DequantizeLinear)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, @@ -2673,6 +2682,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { QuantizeLinear)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index d8924551e5292..05dea2a05c97b 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -5,6 +5,7 @@ #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" #include "core/framework/op_kernel.h" #include "core/providers/common.h" #include "core/mlas/inc/mlas.h" @@ -19,12 +20,22 @@ class DequantizeLinear final : public OpKernel { if (!info.GetAttr("axis", &axis_).IsOK()) { axis_ = 1; } + + if (!info.GetAttr("block_size", &block_size_).IsOK()) { + 
block_size_ = 0; + } + + // TODO(adrianlizarraga): Support the block_size attribute added in opset 21. + if (block_size_ != 0) { + ORT_THROW("DequantizeLinear does not yet support the 'block_size' attribute."); + } } Status Compute(OpKernelContext* context) const override; private: int64_t axis_; + int64_t block_size_; }; template @@ -37,6 +48,15 @@ class QuantizeLinear final : public OpKernel { if (!info.GetAttr("saturate", &saturate_).IsOK()) { saturate_ = 1; } + + if (!info.GetAttr("block_size", &block_size_).IsOK()) { + block_size_ = 0; + } + + // TODO(adrianlizarraga): Support the block_size attribute added in opset 21. + if (block_size_ != 0) { + ORT_THROW("QuantizeLinear does not yet support the 'block_size' attribute."); + } } Status Compute(OpKernelContext* context) const override; @@ -44,36 +64,38 @@ class QuantizeLinear final : public OpKernel { private: int64_t axis_; int64_t saturate_; + int64_t block_size_; }; static void PrepareForQDQ(const TensorShape& input_shape, const Tensor& scale, const Tensor* zero_point_ptr, int64_t axis, - int64_t& block_count, - int64_t& broadcast_dim, - int64_t& block_size) { + int64_t& quant_block_count, // A "quant block" is a block of elems with the same scale/zp + int64_t& axis_dim_val, + int64_t& quant_block_size) { if (IsScalarOr1ElementVector(&scale)) { // per-tensor QuantizeLinear/DequantizeLinear - block_count = 1; - broadcast_dim = 1; - block_size = static_cast(input_shape.Size()); + quant_block_count = 1; + axis_dim_val = 1; + quant_block_size = static_cast(input_shape.Size()); // enforce that zero point are scalars ORT_ENFORCE(zero_point_ptr == nullptr || IsScalarOr1ElementVector(zero_point_ptr), "x_zero_point must be null or a scalar or 1D tensor or size 1."); } else { // per-channel QuantizeLinear/DequantizeLinear const int64_t axis_no_neg = HandleNegativeAxis(axis, input_shape.NumDimensions()); - block_count = input_shape.SizeToDimension(onnxruntime::narrow(axis_no_neg)); - broadcast_dim = input_shape[onnxruntime::narrow(axis_no_neg)]; - block_size = input_shape.SizeFromDimension(SafeInt(axis_no_neg) + 1); + quant_block_count = input_shape.SizeToDimension(onnxruntime::narrow(axis_no_neg)); + axis_dim_val = input_shape[onnxruntime::narrow(axis_no_neg)]; + quant_block_size = input_shape.SizeFromDimension(SafeInt(axis_no_neg) + 1); // if an axis was specified, ensure the scale and zero point are compatible - ORT_ENFORCE(scale.Shape().NumDimensions() == 1 && scale.Shape()[0] == broadcast_dim, + ORT_ENFORCE(scale.Shape().NumDimensions() == 1 && scale.Shape()[0] == axis_dim_val, "scale must be 1D tensor with size ", - broadcast_dim); - ORT_ENFORCE(zero_point_ptr == nullptr || (zero_point_ptr->Shape().NumDimensions() == 1 && zero_point_ptr->Shape()[0] == broadcast_dim), + axis_dim_val); + ORT_ENFORCE(zero_point_ptr == nullptr || + (zero_point_ptr->Shape().NumDimensions() == 1 && zero_point_ptr->Shape()[0] == axis_dim_val), "x_zero_point must be null or 1D tensor with size ", - broadcast_dim); + axis_dim_val); } } @@ -126,6 +148,8 @@ REGISTER_DEQUANTIZELINEAR(uint8_t) REGISTER_DEQUANTIZELINEAR(int16_t) REGISTER_DEQUANTIZELINEAR(uint16_t) REGISTER_DEQUANTIZELINEAR(int32_t) +REGISTER_DEQUANTIZELINEAR(Int4x2) +REGISTER_DEQUANTIZELINEAR(UInt4x2) #if !defined(DISABLE_FLOAT8_TYPES) REGISTER_DEQUANTIZELINEAR(Float8E4M3FN) REGISTER_DEQUANTIZELINEAR(Float8E4M3FNUZ) @@ -199,17 +223,36 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( .TypeConstraint("T2", DataTypeImpl::GetTensorType()), DequantizeLinear); +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 
1, + Int4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + UInt4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + } // namespace contrib #endif // !defined(DISABLE_CONTRIB_OPS) template struct DequantizeLinearApply { - void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T* zero_point) { + void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, OutT* output, + const T* zero_point) { for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { auto zp = zero_point ? static_cast(zero_point[bd]) : 0; auto sc = static_cast(scale[bd]); - for (size_t bs = 0; bs < static_cast(block_size); bs++) { + for (size_t bs = 0; bs < static_cast(quant_block_size); bs++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } } @@ -217,21 +260,50 @@ struct DequantizeLinearApply { } }; +#define DEQUANTIZE_LINEAR_APPLY_INT4(T) \ + template \ + struct DequantizeLinearApply { \ + void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, \ + OutT* output, const T* zero_point) { \ + size_t input_index = 0; \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + auto zp = zero_point ? 
static_cast(zero_point[bd_i].GetElem(bd_j)) : 0; \ + auto sc = static_cast(scale[bd]); \ + for (size_t bs = 0; bs < static_cast(quant_block_size); bs++) { \ + size_t input_i = input_index >> 1; \ + size_t input_j = input_index & 0x1; \ + int32_t val = static_cast(input[input_i].GetElem(input_j)); \ + *output++ = static_cast(static_cast(val - zp) * sc); \ + input_index += 1; \ + } \ + } \ + } \ + assert(input_index == static_cast(N * axis_dim_val * quant_block_size)); \ + } \ + }; + +DEQUANTIZE_LINEAR_APPLY_INT4(Int4x2); +DEQUANTIZE_LINEAR_APPLY_INT4(UInt4x2); + #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < static_cast(block_size); bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, \ + OutT* output, const T*) { \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < static_cast(quant_block_size); bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -252,10 +324,10 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { auto& y = *ctx->Output(0, x_shape); int64_t N; - int64_t broadcast_dim; - int64_t block_size; + int64_t axis_dim_val; + int64_t quant_block_size; - PrepareForQDQ(x.Shape(), x_scale, x_zero_point, axis_, N, broadcast_dim, block_size); + PrepareForQDQ(x.Shape(), x_scale, x_zero_point, axis_, N, axis_dim_val, quant_block_size); const T* zero_point = x_zero_point ? 
x_zero_point->Data() : nullptr; @@ -277,11 +349,11 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); float* output = y.MutableData(); - DequantizeLinearApply().op(N, broadcast_dim, block_size, input, scale, output, zero_point); + DequantizeLinearApply().op(N, axis_dim_val, quant_block_size, input, scale, output, zero_point); } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); MLFloat16* output = y.MutableData(); - DequantizeLinearApply().op(N, broadcast_dim, block_size, input, scale, output, zero_point); + DequantizeLinearApply().op(N, axis_dim_val, quant_block_size, input, scale, output, zero_point); } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); } else { @@ -341,6 +413,8 @@ REGISTER_QUANTIZELINEAR(int8_t) REGISTER_QUANTIZELINEAR(uint8_t) REGISTER_QUANTIZELINEAR(int16_t) REGISTER_QUANTIZELINEAR(uint16_t) +REGISTER_QUANTIZELINEAR(Int4x2) +REGISTER_QUANTIZELINEAR(UInt4x2) #if !defined(DISABLE_FLOAT8_TYPES) REGISTER_QUANTIZELINEAR(Float8E4M3FN) @@ -404,6 +478,24 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + Int4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + UInt4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); } // namespace contrib #endif // !defined(DISABLE_CONTRIB_OPS) @@ -423,22 +515,93 @@ void ParQuantizeLinear(const InputType* Input, ParQuantizeLinearStd(Input, Output, N, Scale, ZeroPoint != nullptr ? ZeroPoint[bd] : (OutputType)0, thread_pool); #if !defined(DISABLE_FLOAT8_TYPES) } else { - ParQuantizeLinearSat(Input, Output, N, Scale, ZeroPoint != nullptr ? ZeroPoint[bd] : OutputType(static_cast(static_cast(0)), true), saturate, thread_pool); + ParQuantizeLinearSat(Input, Output, N, Scale, + ZeroPoint != nullptr ? ZeroPoint[bd] + : OutputType(static_cast(static_cast(0)), true), + saturate, thread_pool); } #endif } template -void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { +void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t N, + int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { - ParQuantizeLinear(input, output, static_cast(block_size), scale[bd], bd, zero_point, saturate, ctx->GetOperatorThreadPool()); - input += block_size; - output += block_size; + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { + ParQuantizeLinear(input, output, static_cast(quant_block_size), scale[bd], bd, zero_point, saturate, + ctx->GetOperatorThreadPool()); + input += quant_block_size; + output += quant_block_size; } } } +// Quantizes float32 to INT4 (in-place) using MLAS kernel. 
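// --- Reference sketch (not part of the patch) ------------------------------------------------
// Scalar int4 quantize/dequantize round trip mirroring the per-element math used above:
// q = clamp(round(x / scale) + zero_point, -8, 7) for the signed type, and
// x' = (q - zero_point) * scale on the way back, with two 4-bit values packed per byte.
// All names in this block are hypothetical; the DEFINE_COMPUTE_LOOP_FP32_TO_INT4 macro that
// follows routes the same math through the MLAS kernels instead.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

static void QuantizeS4Ref(const float* x, uint8_t* packed, size_t n, float scale, int8_t zp) {
    for (size_t i = 0; i < n; ++i) {
        float v = std::nearbyintf(x[i] / scale) + static_cast<float>(zp);
        v = std::min(7.0f, std::max(-8.0f, v));
        const int8_t q = static_cast<int8_t>(v);
        const unsigned shift = static_cast<unsigned>((i & 0x1) << 2);  // even index -> low nibble
        uint8_t& dst = packed[i >> 1];
        dst = static_cast<uint8_t>((dst & (0xF0 >> shift)) | ((q & 0xF) << shift));
    }
}

static void DequantizeS4Ref(const uint8_t* packed, float* y, size_t n, float scale, int8_t zp) {
    for (size_t i = 0; i < n; ++i) {
        const unsigned shift = static_cast<unsigned>((i & 0x1) << 2);
        int q = (packed[i >> 1] >> shift) & 0xF;
        if (q >= 8) q -= 16;                                           // sign-extend the nibble
        y[i] = static_cast<float>(q - zp) * scale;
    }
}

int main() {
    const std::vector<float> x = {-1.0f, -0.25f, 0.0f, 0.3f, 0.9f};
    std::vector<uint8_t> packed((x.size() + 1) / 2, 0);
    std::vector<float> y(x.size());
    QuantizeS4Ref(x.data(), packed.data(), x.size(), /*scale*/ 0.25f, /*zp*/ 1);
    DequantizeS4Ref(packed.data(), y.data(), x.size(), 0.25f, 1);
    for (size_t i = 0; i < x.size(); ++i) {
        std::printf("%f -> %f\n", x[i], y[i]);
    }
    return 0;
}
// -----------------------------------------------------------------------------------------------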
+#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \ + template <> \ + void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \ + INT4_TYPE* output, int64_t N, int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + size_t output_index = 0; \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ + QUANT_FUNC(input, output, output_index, output_index + static_cast(quant_block_size), \ + scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ + input += quant_block_size; \ + output_index += static_cast(quant_block_size); \ + } \ + } \ + assert(output_index == static_cast(N * axis_dim_val * quant_block_size)); \ + } + +DEFINE_COMPUTE_LOOP_FP32_TO_INT4(Int4x2, ParQuantizeLinearStdS4) +DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4) + +// Defines functions to quantize MLFloat16 to INT4. +// This is not an efficient implementation: we allocate a buffer, quantize to INT8, and then copy/clamp/pack +// into output INT4 buffer. +#define DEFINE_COMPUTE_LOOP_FP16_TO_INT4(INT4_TYPE) \ + template <> \ + void ComputeLoop(OpKernelContext * ctx, const MLFloat16* input, const MLFloat16* scale, \ + const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t N, \ + int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + \ + size_t total_size = static_cast(N * axis_dim_val * quant_block_size); \ + auto tmp_buf = std::make_unique(total_size); \ + size_t tmp_buf_index = 0; \ + \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ + static_cast(quant_block_size), scale[bd], \ + zp, ctx->GetOperatorThreadPool()); \ + input += quant_block_size; \ + tmp_buf_index += static_cast(quant_block_size); \ + } \ + } \ + \ + for (size_t i = 0; i < total_size; i++) { \ + tmp_buf[i] = std::min(INT4_TYPE::max_val, \ + std::max(INT4_TYPE::min_val, \ + tmp_buf[i])); \ + } \ + \ + size_t num_int4_pairs = (total_size + 1) / 2; \ + auto dst = gsl::make_span(output, num_int4_pairs); \ + auto src = gsl::make_span(tmp_buf.get(), total_size); \ + INT4_TYPE::Pack(dst, src); \ + } + +DEFINE_COMPUTE_LOOP_FP16_TO_INT4(Int4x2) +DEFINE_COMPUTE_LOOP_FP16_TO_INT4(UInt4x2) + // formula is Y = X / Scale + ZeroPoint template Status QuantizeLinear::Compute(OpKernelContext* ctx) const { @@ -449,17 +612,19 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { auto& y = *ctx->Output(0, x_shape); int64_t N; - int64_t broadcast_dim; - int64_t block_size; - PrepareForQDQ(x.Shape(), y_scale, y_zero_point, axis_, N, broadcast_dim, block_size); + int64_t axis_dim_val; + int64_t quant_block_size; + PrepareForQDQ(x.Shape(), y_scale, y_zero_point, axis_, N, axis_dim_val, quant_block_size); const T* zero_point = y_zero_point != nullptr ? 
y_zero_point->Data() : nullptr; T* output = y.MutableData(); if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, broadcast_dim, block_size, saturate_); + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, axis_dim_val, + quant_block_size, saturate_); } else if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, broadcast_dim, block_size, saturate_); + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, + axis_dim_val, quant_block_size, saturate_); } else { ORT_THROW("Unsupported input type."); } diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index ec4624cf59ae6..5b904e85848d0 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -3,6 +3,7 @@ #include "core/providers/cpu/tensor/transpose.h" +#include #include "core/framework/element_type_lists.h" #include "core/framework/utils.h" #include "core/framework/transpose_helper.h" @@ -29,12 +30,23 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Transpose, Input, 0, DefaultDataTypes); #endif + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, Transpose, 21, Input, 0, + element_type_lists::AllIRv10); + +ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, Transpose, 21, Input, 0, + element_type_lists::AllIRv10); + } // namespace op_kernel_type_control namespace { // reduce the supported types with any global or op specific lists -using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, - Transpose, Input, 0); +using EnabledDataTypesAllOpsets = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + Transpose, Input, 0); +using EnabledDataTypesOpset21 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + Transpose, 21, Input, 0); } // namespace /* A permutation [a,b,c,...] indicates that @@ -183,7 +195,7 @@ inline void CopyPrim(uint8_t* target, const uint8_t* source) { template static bool TypedDoTransposeEltWise(int64_t num_axes, gsl::span target_dims, size_t num_blocks, const gsl::span& stride, const uint8_t* source, uint8_t* target) { - constexpr bool enabled = utils::HasTypeWithSameSize(); + constexpr bool enabled = utils::HasTypeWithSameSize(); if (enabled) { MultiIndex mindex; @@ -287,7 +299,7 @@ static Status DoUntypedTranspose(const gsl::span& permutations, co Status status = Status::OK(); if (is_string_type) { - constexpr bool string_enabled = utils::HasType(); + constexpr bool string_enabled = utils::HasType(); if (string_enabled) { const auto* input_data = input.Data(); @@ -336,38 +348,84 @@ bool IsTransposeReshape(const gsl::span& perm, gsl::span& permutations, const Tensor& input, Tensor& output, + const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { + TensorShape shape = input_shape_override ? *input_shape_override : input.Shape(); + + if (IsTransposeReshape(permutations, shape.GetDims())) { + // As long as the dims with values > 1 stay in the same order, it's a reshape. + // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). 
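    // Here only axes 2 (1024) and 3 (4096) have dims > 1, and the permutation keeps 1024 ahead
    // of 4096, so the output's memory layout is identical to the input's and a plain copy suffices.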
+ CopyCpuTensor(&input, &output); + return Status::OK(); + } + + size_t from = 0, to = 0; + bool moving_single_axis = IsTransposeMovingSingleAxis(permutations, from, to); + + if (moving_single_axis && !input.IsDataTypeString()) { + SingleAxisTranspose(permutations, input, output, from, to, input_shape_override, tp); + return Status::OK(); + } + + // fall back to default implementation + return DoUntypedTranspose(permutations, input, output, input_shape_override); +} + +template +static Status UnpackInt4Tensor(const Tensor& src, Tensor& dst, AllocatorPtr cpu_allocator) { + using UnpackedType = typename Int4Type::UnpackedType; + MLDataType int8_elem_type = DataTypeImpl::GetType(); + const TensorShape& shape = src.Shape(); + Tensor int8_tensor(int8_elem_type, shape, cpu_allocator); + + ORT_RETURN_IF_NOT(Int4Type::Unpack(int8_tensor.MutableDataAsSpan(), src.DataAsSpan()), + "Failed to unpack Int4x2 Tensor to an int8_t Tensor"); + + dst = std::move(int8_tensor); + + return Status::OK(); +} + +template +static Status DoTransposeInt4(const gsl::span& permutations, const Tensor& input, Tensor& output, + const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { + using Int8Type = typename Int4Type::UnpackedType; + + ORT_RETURN_IF_NOT(input.IsDataType() && output.IsDataType(), + "Expected to transpose int4 tensor"); + + // Convert to Tensor, transpose, and then repack back to Tensor. + AllocatorPtr cpu_allocator = std::make_shared(); + Tensor input_unpacked; + Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); + + ORT_RETURN_IF_ERROR((UnpackInt4Tensor(input, input_unpacked, cpu_allocator))); + ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp)); + ORT_RETURN_IF_NOT(Int4Type::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), + "Failed to pack 8-bit Tensor into 4-bit Tensor"); + + return Status::OK(); +} + //`input_shape_override` overrides the shape of `input` for compute purposes. Status TransposeBase::DoTranspose(const gsl::span& permutations, const Tensor& input, Tensor& output, - const TensorShape* input_shape_override) { - Status status = Status::OK(); - + const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { auto input_type = input.DataType(); auto output_type = output.DataType(); if (input_type != output_type) { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mismatched data types between input and output Tensors. ", - input_type, " != ", output_type); - } else { - TensorShape shape = input_shape_override ? *input_shape_override : input.Shape(); - if (IsTransposeReshape(permutations, shape.GetDims())) { - // As long as the dims with values > 1 stay in the same order, it's a reshape. - // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). - CopyCpuTensor(&input, &output); - return Status::OK(); - } - - size_t from = 0, to = 0; - bool moving_single_axis = IsTransposeMovingSingleAxis(permutations, from, to); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mismatched data types between input and output Tensors. 
", + input_type, " != ", output_type); + } + if (input.IsDataType()) { + return DoTransposeInt4(permutations, input, output, input_shape_override, tp); + } - if (moving_single_axis && !input.IsDataTypeString()) { - SingleAxisTranspose(permutations, input, output, from, to, input_shape_override); - } else { - // fall back to default implementation - status = DoUntypedTranspose(permutations, input, output, input_shape_override); - } + if (input.IsDataType()) { + return DoTransposeInt4(permutations, input, output, input_shape_override, tp); } - return status; + return TransposeImpl(permutations, input, output, input_shape_override, tp); } Status Transpose::Compute(OpKernelContext* ctx) const { @@ -388,49 +446,33 @@ Status Transpose::Compute(OpKernelContext* ctx) const { TensorShape output_shape{output_dims}; Tensor& Y = *ctx->Output(0, output_shape); - if (output_shape.Size() == 0) - return Status::OK(); - - if (IsTransposeReshape(*p_perm, input_dims)) { - // As long as the dims with values > 1 stay in the same order, it's a reshape. - // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). - CopyCpuTensor(&X, &Y); + if (output_shape.Size() == 0) { return Status::OK(); } - size_t from = 0, to = 0; - bool moving_single_axis = IsTransposeMovingSingleAxis(*p_perm, from, to); - - if (moving_single_axis && !X.IsDataTypeString()) { - SingleAxisTranspose(*p_perm, X, Y, from, to, nullptr, ctx->GetOperatorThreadPool()); - } else { - // fall back to default implementation - status = DoUntypedTranspose(*p_perm, X, Y); - } - - return status; + return DoTranspose(*p_perm, X, Y, nullptr, ctx->GetOperatorThreadPool()); } ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Transpose, 1, 12, - KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Transpose, 13, 20, - KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); // Opset 21 added support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. -// TODO(adrianlizarraga): Implement support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. +// TODO(adrianlizarraga): Implement support for float8e4m3fnuz, float8e5m2, and float8e5m2fnuz. ONNX_CPU_OPERATOR_KERNEL( Transpose, 21, - KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h index 133b35ac80fe5..fda41c28a2567 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.h +++ b/onnxruntime/core/providers/cpu/tensor/transpose.h @@ -33,7 +33,8 @@ class TransposeBase { Both Tensors must have the same data type. `input_shape_override` overrides the shape of `input` for compute purposes. 
*/ static Status DoTranspose(const gsl::span& permutations, const Tensor& input, Tensor& output, - const TensorShape* input_shape_override = nullptr); + const TensorShape* input_shape_override = nullptr, + concurrency::ThreadPool* tp = nullptr); protected: TransposeBase(const OpKernelInfo& info) { diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 7b36057e3bafe..7cdfb0ffc19f2 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -24,6 +24,7 @@ #include "core/framework/allocator.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" #include "core/framework/tensor_shape.h" #include "core/providers/providers.h" #include "core/common/path_string.h" @@ -68,7 +69,9 @@ enum TensorProto_DataType : int { TensorProto_DataType_FLOAT8E4M3FN = 17, TensorProto_DataType_FLOAT8E4M3FNUZ = 18, TensorProto_DataType_FLOAT8E5M2 = 19, - TensorProto_DataType_FLOAT8E5M2FNUZ = 20 + TensorProto_DataType_FLOAT8E5M2FNUZ = 20, + TensorProto_DataType_UINT4 = 21, + TensorProto_DataType_INT4 = 22, }; enum TensorProto_DataLocation : int { @@ -86,7 +89,8 @@ enum Version : int { IR_VERSION_2019_9_19 = 6, IR_VERSION_2020_5_8 = 7, IR_VERSION_2021_7_31 = 8, - IR_VERSION = 9 + IR_VERSION_2023_5_5 = 9, + IR_VERSION = 10 }; enum OperatorStatus : int { @@ -347,6 +351,14 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { template <> constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; } #endif +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; +} +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; +} } // namespace utils namespace QDQ { diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 575434d19bf35..27d8a0f06f565 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -167,6 +167,10 @@ MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->Data template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Float8E5M2FNUZ(); } #endif +template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Int4x2(); } +template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_UInt4x2(); } template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_string(); } @@ -207,6 +211,10 @@ MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost() template <> MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Float8E5M2FNUZ(); } #endif +template <> +MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Int4x2(); } +template <> +MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_UInt4x2(); } #if !defined(DISABLE_SPARSE_TENSORS) template <> diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h 
index 597ebe98ba08c..cc3b13f696a96 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -582,6 +582,8 @@ struct ProviderHost { // PrimitiveDataTypeBase virtual int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) = 0; + virtual int32_t PrimitiveDataTypeBase__GetNumSubElems(const PrimitiveDataTypeBase* p) = 0; + virtual bool PrimitiveDataTypeBase__HasSubElems(const PrimitiveDataTypeBase* p) = 0; // DataTypeImpl virtual MLDataType DataTypeImpl__GetType_Tensor() = 0; @@ -610,6 +612,8 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetType_Float8E5M2FNUZ() = 0; #endif + virtual MLDataType DataTypeImpl__GetType_Int4x2() = 0; + virtual MLDataType DataTypeImpl__GetType_UInt4x2() = 0; virtual MLDataType DataTypeImpl__GetTensorType_bool() = 0; virtual MLDataType DataTypeImpl__GetTensorType_int8() = 0; @@ -630,6 +634,8 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetTensorType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetTensorType_Float8E5M2FNUZ() = 0; #endif + virtual MLDataType DataTypeImpl__GetTensorType_Int4x2() = 0; + virtual MLDataType DataTypeImpl__GetTensorType_UInt4x2() = 0; #if !defined(DISABLE_SPARSE_TENSORS) virtual MLDataType DataTypeImpl__GetSparseTensorType_bool() = 0; @@ -986,6 +992,8 @@ struct ProviderHost { virtual Float8E5M2* Tensor__MutableData_Float8E5M2(Tensor* p) = 0; virtual Float8E5M2FNUZ* Tensor__MutableData_Float8E5M2FNUZ(Tensor* p) = 0; #endif + virtual Int4x2* Tensor__MutableData_Int4x2(Tensor* p) = 0; + virtual UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) = 0; virtual const bool* Tensor__Data_bool(const Tensor* p) = 0; virtual const int8_t* Tensor__Data_int8(const Tensor* p) = 0; @@ -1007,6 +1015,8 @@ struct ProviderHost { virtual const Float8E5M2* Tensor__Data_Float8E5M2(const Tensor* p) = 0; virtual const Float8E5M2FNUZ* Tensor__Data_Float8E5M2FNUZ(const Tensor* p) = 0; #endif + virtual const Int4x2* Tensor__Data_Int4x2(const Tensor* p) = 0; + virtual const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) = 0; virtual gsl::span Tensor__DataAsSpan_int64(const Tensor* p) = 0; @@ -1038,6 +1048,8 @@ struct ProviderHost { virtual bool Tensor__IsDataType_Float8E5M2(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_Float8E5M2FNUZ(const Tensor* p) noexcept = 0; #endif + virtual bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept = 0; + virtual bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept = 0; virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0; virtual void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 125cd8fd92f6b..fd2540b42a3db 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -589,6 +589,14 @@ struct KernelRegistry final { struct PrimitiveDataTypeBase final { int32_t GetDataType() const { return g_host->PrimitiveDataTypeBase__GetDataType(this); } + int32_t GetNumSubElems() const { + return g_host->PrimitiveDataTypeBase__GetNumSubElems(this); + } + + bool HasSubElems() const { + return g_host->PrimitiveDataTypeBase__HasSubElems(this); + } + PROVIDER_DISALLOW_ALL(PrimitiveDataTypeBase) }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc 
b/onnxruntime/core/session/onnxruntime_c_api.cc index 0239c1223130b..5cf5ff9b3bd0a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -215,10 +215,10 @@ ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); } - auto elem_count = narrow(tensor_shape.Size()); - size_t size_to_allocate; - if (!IAllocator::CalcMemSizeForArray(ml_type->Size(), elem_count, &size_to_allocate)) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "size overflow"); + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + if (!status.IsOK()) { + return ToOrtStatus(status); } if (size_to_allocate > p_data_len) { std::ostringstream oss; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index a27eb880daf09..d18b3ac40d489 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -779,6 +779,8 @@ struct ProviderHostImpl : ProviderHost { // PrimitiveDataTypeBase (wrapped) int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) override { return p->GetDataType(); } + int32_t PrimitiveDataTypeBase__GetNumSubElems(const PrimitiveDataTypeBase* p) override { return p->GetNumSubElems(); } + bool PrimitiveDataTypeBase__HasSubElems(const PrimitiveDataTypeBase* p) override { return p->HasSubElems(); } // DataTypeImpl (wrapped) MLDataType DataTypeImpl__GetType_Tensor() override { return DataTypeImpl::GetType(); } @@ -808,6 +810,8 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetType_Float8E5M2() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_Float8E5M2FNUZ() override { return DataTypeImpl::GetType(); } #endif + MLDataType DataTypeImpl__GetType_Int4x2() override { return DataTypeImpl::GetType(); } + MLDataType DataTypeImpl__GetType_UInt4x2() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetTensorType_bool() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_int8() override { return DataTypeImpl::GetTensorType(); } @@ -829,6 +833,8 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetTensorType_Float8E5M2() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_Float8E5M2FNUZ() override { return DataTypeImpl::GetTensorType(); } #endif + MLDataType DataTypeImpl__GetTensorType_Int4x2() override { return DataTypeImpl::GetTensorType(); } + MLDataType DataTypeImpl__GetTensorType_UInt4x2() override { return DataTypeImpl::GetTensorType(); } #if !defined(DISABLE_SPARSE_TENSORS) MLDataType DataTypeImpl__GetSparseTensorType_bool() override { return DataTypeImpl::GetSparseTensorType(); } @@ -1281,6 +1287,8 @@ struct ProviderHostImpl : ProviderHost { Float8E5M2* Tensor__MutableData_Float8E5M2(Tensor* p) override { return p->MutableData(); } Float8E5M2FNUZ* Tensor__MutableData_Float8E5M2FNUZ(Tensor* p) override { return p->MutableData(); } #endif + Int4x2* Tensor__MutableData_Int4x2(Tensor* p) override { return p->MutableData(); } + UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) override { return p->MutableData(); } const bool* Tensor__Data_bool(const Tensor* p) override { return p->Data(); } const int8_t* Tensor__Data_int8(const Tensor* p) override { return 
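Switching CreateTensorImpl to Tensor::CalculateTensorStorageSize matters for sub-byte element types: the logical element count no longer equals the byte count, so a plain element-size-times-count computation would over-allocate (and over-validate) for int4 data. A minimal sketch of the packed-size arithmetic the new path has to produce; PackedInt4Bytes is an illustrative helper, not the real API:

#include <cassert>
#include <cstddef>

// Storage needed for N packed 4-bit elements: two per byte, rounded up.
size_t PackedInt4Bytes(size_t num_elements) {
  return (num_elements + 1) / 2;
}

int main() {
  assert(PackedInt4Bytes(7) == 4);  // e.g. a {7}-shaped int4 tensor occupies 4 bytes
  assert(PackedInt4Bytes(8) == 4);
  assert(PackedInt4Bytes(1) == 1);
}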
p->Data(); } @@ -1302,6 +1310,8 @@ struct ProviderHostImpl : ProviderHost { const Float8E5M2* Tensor__Data_Float8E5M2(const Tensor* p) override { return p->Data(); } const Float8E5M2FNUZ* Tensor__Data_Float8E5M2FNUZ(const Tensor* p) override { return p->Data(); } #endif + const Int4x2* Tensor__Data_Int4x2(const Tensor* p) override { return p->Data(); } + const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) override { return p->Data(); } gsl::span Tensor__DataAsSpan_int64(const Tensor* p) override { return p->DataAsSpan(); } @@ -1331,6 +1341,8 @@ struct ProviderHostImpl : ProviderHost { bool Tensor__IsDataType_Float8E5M2(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_Float8E5M2FNUZ(const Tensor* p) noexcept override { return p->IsDataType(); } #endif + bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept override { return p->IsDataType(); } + bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept override { return p->IsDataType(); } const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); } void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) override { return p->Reshape(new_shape); } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 173ab632d59cf..235ecfde0954a 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -8,6 +8,7 @@ #include "core/common/narrow.h" #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" +#include "core/framework/int4.h" #include namespace onnxruntime { @@ -132,6 +133,100 @@ ParQuantizeLinearStd(const float* Input, }); } +/** + * Defines a function for int4 quantization. Calls MLAS kernel in parallel with the provided thread pool. + * + * \param FUNC_NAME The name of the generated function. + * \param INT4_TYPE The int4 type (i.e., either Int4x2 or UInt4x2) + * \param MLAS_FUNC The MLAS quantization kernel to call. + * \param Input The input float values to quantize. Must contain `out_end - out_start` elements. + * \param Output The output buffer that will contain the quantized values. + * \param out_start The int4 element index at which to start writing to the output buffer. + * Divide by 2 to get index into Output buffer. + * \param out_end The int4 element index at which to stop writing to the output buffer. + * Divide by 2 to get index into Output buffer. + * \param Scale The quantization scale value. + * \param ZeroPoint The quantization zero-point value. + * \param thread_pool The thread pool to use. + */ +#define DEFINE_PAR_QUANT_LINEAR_STD_4BIT(FUNC_NAME, INT4_TYPE, MLAS_FUNC) \ + inline void FUNC_NAME(const float* Input, \ + INT4_TYPE* Output, \ + size_t out_start, \ + size_t out_end, \ + float Scale, \ + INT4_TYPE ZeroPoint, \ + concurrency::ThreadPool* thread_pool) { \ + size_t inp_start = 0; \ + size_t inp_end = out_end - out_start; \ + \ + /* If starting at an int4 element in the middle of a byte, quantize it by itself. */ \ + if (out_start & 0x1) { \ + int32_t ival = static_cast(std::nearbyintf(Input[inp_start] / Scale)) + \ + static_cast(ZeroPoint.GetElem(0)); \ + size_t output_index = out_start >> 1; \ + \ + INT4_TYPE::UnpackedType quant_val = static_cast( \ + std::min(static_cast(INT4_TYPE::max_val), \ + std::max(static_cast(INT4_TYPE::min_val), ival))); \ + Output[output_index].SetElem(1, quant_val); \ + \ + out_start += 1; \ + inp_start += 1; \ + } \ + \ + /* If ending at element that ends in the middle of a byte, quantize it by itself. 
*/ \ + if (out_end & 0x1) { \ + int32_t ival = static_cast(std::nearbyintf(Input[inp_end - 1] / Scale)) + \ + static_cast(ZeroPoint.GetElem(0)); \ + size_t output_index = (out_end - 1) >> 1; \ + \ + INT4_TYPE::UnpackedType quant_val = static_cast( \ + std::min(static_cast(INT4_TYPE::max_val), \ + std::max(static_cast(INT4_TYPE::min_val), ival))); \ + Output[output_index].SetElem(0, quant_val); \ + \ + out_end -= 1; \ + inp_end -= 1; \ + } \ + \ + if (out_start == out_end) { \ + return; \ + } \ + \ + /* At this point, should only need to quantize an *even* number of int4 elements that start and end at */ \ + /* a byte boundary. This is necessary to ensure that no two threads write to different int4 elements that */ \ + /* are stored in the same byte. */ \ + size_t N = out_end - out_start; \ + assert(N % 2 == 0); /* Should be guaranteed by previous code that quantizes boundary elements. */ \ + \ + constexpr std::ptrdiff_t block_size = 128; \ + static_assert(block_size % 2 == 0, \ + "Block size must also be even to ensure no two threads write to the same byte."); \ + \ + const std::ptrdiff_t num_blocks = (N + block_size - 1) / block_size; \ + const TensorOpCost unit_cost{static_cast(block_size * sizeof(float)), \ + static_cast(block_size * sizeof(INT4_TYPE::UnpackedType)) / 2.0, \ + static_cast(block_size) * 2.0}; \ + concurrency::ThreadPool::TryParallelFor( \ + thread_pool, num_blocks, unit_cost, \ + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { \ + auto begin_idx = begin * block_size; \ + auto end_idx = std::min(static_cast(N), end * block_size); \ + auto inp_idx = begin_idx + static_cast(inp_start); \ + auto out_idx = begin_idx + static_cast(out_start); \ + \ + MLAS_FUNC(&(Input[inp_idx]), \ + reinterpret_cast(&(Output[out_idx >> 1])), \ + end_idx - begin_idx, \ + Scale, \ + static_cast(ZeroPoint.GetElem(0))); \ + }); \ + } + +DEFINE_PAR_QUANT_LINEAR_STD_4BIT(ParQuantizeLinearStdS4, Int4x2, MlasQuantizeLinearS4) +DEFINE_PAR_QUANT_LINEAR_STD_4BIT(ParQuantizeLinearStdU4, UInt4x2, MlasQuantizeLinearU4) + // This implementation could be more efficient however the cast from float16 to other types // usually happens on GPU. template diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 78297df185d68..23744c24d1a21 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -586,9 +586,10 @@ static void CopyDataToTensor(PyArrayObject* darray, int npy_type, Tensor& tensor } } else { void* buffer = tensor.MutableDataRaw(); - size_t len; - if (!IAllocator::CalcMemSizeForArray(tensor.DataType()->Size(), tensor.Shape().Size(), &len)) { - throw std::runtime_error("length overflow"); + size_t len = 0; + Status status = Tensor::CalculateTensorStorageSize(tensor.DataType(), tensor.Shape(), /*alignment*/ 0, len); + if (!status.IsOK()) { + throw std::runtime_error(status.ErrorMessage()); } mem_cpy_to_device(buffer, PyArray_DATA(darray), len); } diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 625cab25b9c46..74e213fa61362 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -339,6 +339,14 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}" f"\nraw={str(q_weight_initializer)[:200]}." 
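The macro above serializes any odd leading and trailing 4-bit element so that the parallel region always covers an even, byte-aligned element count, which keeps two threads from ever writing into the same byte. A de-macro-ized, single-threaded sketch of that boundary handling; QuantizeRange and QuantNibble are illustrative names, and the real code hands the aligned middle stretch to MlasQuantizeLinearS4/U4 block by block through the thread pool:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Quantize one value to a signed 4-bit lane (illustrative, not the MLAS kernel).
static uint8_t QuantNibble(float x, float scale, int zp) {
  int q = static_cast<int>(std::nearbyintf(x / scale)) + zp;
  return static_cast<uint8_t>(std::clamp(q, -8, 7)) & 0x0F;
}

// Quantize the inputs into 4-bit lanes [out_start, out_end) of a packed buffer.
void QuantizeRange(const float* in, uint8_t* packed, size_t out_start, size_t out_end,
                   float scale, int zp) {
  size_t inp = 0;
  if (out_start & 1) {  // odd start: fill the high nibble of its byte by itself
    uint8_t& b = packed[out_start >> 1];
    b = static_cast<uint8_t>((b & 0x0F) | (QuantNibble(in[inp++], scale, zp) << 4));
    ++out_start;
  }
  if (out_end & 1) {  // odd end: fill the low nibble of its byte by itself
    size_t last = (out_end - out_start - 1) + inp;  // index of the final input value
    uint8_t& b = packed[(out_end - 1) >> 1];
    b = static_cast<uint8_t>((b & 0xF0) | QuantNibble(in[last], scale, zp));
    --out_end;
  }
  // [out_start, out_end) is now even-sized and byte-aligned; this is the stretch
  // the real code parallelizes, since no two blocks can share a byte.
  for (size_t n = out_start; n < out_end; ++n, ++inp) {
    uint8_t nib = QuantNibble(in[inp], scale, zp);
    uint8_t& b = packed[n >> 1];
    b = (n & 1) ? static_cast<uint8_t>((b & 0x0F) | (nib << 4))
                : static_cast<uint8_t>((b & 0xF0) | nib);
  }
}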
) + elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed + # within int32_data is fixed. + # q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data) + packed_data = onnx.helper.pack_float32_to_4bit(q_weight_data.flatten(), qType == onnx.TensorProto.INT4) + q_weight_initializer = onnx.helper.make_tensor( + q_weight_name, qType, weight.dims, packed_data.tobytes(), raw=True + ) else: q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape( weight.dims @@ -396,7 +404,10 @@ def quantize_weight_per_channel_impl( symmetric = quant_overrides_for_channels[0].get( "symmetric", - (self.is_weight_symmetric or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN)), + ( + self.is_weight_symmetric + or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4) + ), ) reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range) zero_point_list = [] @@ -447,7 +458,8 @@ def quantize_weight_per_channel_impl( quantized_per_channel_data_list.append(quantized_per_channel_data) # combine per_channel_data into one - reshape_dims = list(weights.shape) # deep copy + weights_shape = list(weights.shape) + reshape_dims = list(weights_shape) # deep copy reshape_dims[channel_axis] = 1 # only one per channel for reshape quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims) for i in range(1, len(quantized_per_channel_data_list)): @@ -470,12 +482,26 @@ def quantize_weight_per_channel_impl( self.model.initializer_extend([scale_initializer, zero_initializer]) if not keep_float_weight: - quantized_weights = np.asarray( - quantized_weights, - dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType], - ).reshape(initializer.dims) - q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name) - self.model.initializer_extend([q_weight_initializer]) + if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed + # within int32_data is fixed. 
+ # q_weight_initializer = onnx.helper.make_tensor( + # q_weight_name, weight_qType, weights_shape, quantized_weights + # ) + packed_data = onnx.helper.pack_float32_to_4bit( + quantized_weights.flatten(), weight_qType == onnx.TensorProto.INT4 + ) + q_weight_initializer = onnx.helper.make_tensor( + q_weight_name, weight_qType, weights_shape, packed_data.tobytes(), raw=True + ) + self.model.initializer_extend([q_weight_initializer]) + else: + quantized_weights = np.asarray( + quantized_weights, + dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType], + ).reshape(initializer.dims) + q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name) + self.model.initializer_extend([q_weight_initializer]) return q_weight_name, zp_name, scale_name diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 3b857c991951c..1ad56dc3ac455 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -21,6 +21,7 @@ Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16} Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8} +Q4_TYPES = {QuantType.QInt4, QuantType.QUInt4} OP_TYPES_TO_EXCLUDE = {"Cast"} MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB @@ -173,11 +174,12 @@ def get_qnn_qdq_config( } # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain - # on Q/DQ operators if using 16-bit quantization. + # on Q/DQ operators if using 16-bit or 4-bit quantization. onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") if onnx_opset.version < 21: - overrides_have_int16 = any(t in Q16_TYPES for t in overrides_helper.get_quant_types()) - if activation_type in Q16_TYPES or weight_type in Q16_TYPES or overrides_have_int16: + opset21_types = Q16_TYPES.union(Q4_TYPES) + overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types()) + if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types: extra_options["UseQDQContribOps"] = True return StaticQuantConfig( diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index c368d887fda22..ac61f4779d389 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -193,16 +193,20 @@ def __init__( # The ONNX spec did not support 16-bit Q/DQ ops before opset 21. # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types - # are 16-bit integers. + # are 16-bit or 4-bit integers. 
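The base quantizer packs its own int4 bytes and stores them as a raw-data initializer to sidestep the make_tensor issue noted in the comments, and the per-channel path quantizes each channel slice with its own scale before recombining along the channel axis. A minimal per-channel sketch, assuming the symmetric "range / 15, zero point 0" scaling documented in the int4 conv test later in this diff; QuantizeChannelSymmetricInt4 is an illustrative helper, and the real compute_scale_zp also handles asymmetric ranges, reduce_range, and overrides:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one channel slice with its own symmetric scale (zero point 0).
void QuantizeChannelSymmetricInt4(const std::vector<float>& channel,
                                  std::vector<int8_t>& out, float& scale) {
  float max_abs = 0.f;
  for (float v : channel) max_abs = std::max(max_abs, std::fabs(v));
  scale = max_abs > 0.f ? (2.f * max_abs) / 15.f : 1.f;  // (rmax - rmin) / (qmax - qmin)
  for (float v : channel) {
    int q = static_cast<int>(std::nearbyintf(v / scale));
    out.push_back(static_cast<int8_t>(std::clamp(q, -8, 7)));
  }
}
// The per-channel results are concatenated back along the channel axis and only
// then packed two-per-byte when the raw-data initializer is written.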
if self.opset_version < 21: - int16_types = (TensorProto.UINT16, TensorProto.INT16) - overrides_have_int16 = any(t.tensor_type in int16_types for t in self.tensor_quant_override_qtypes) + opset21_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4) + overrides_have_opset21_types = any( + t.tensor_type in opset21_types for t in self.tensor_quant_override_qtypes + ) if not self.qdq_op_domain and ( - self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 + self.activation_qType in opset21_types + or self.weight_qType in opset21_types + or overrides_have_opset21_types ): logging.warning( "ONNX QuantizeLinear and DequantizeLinear operators do not support " - "16-bit integer quantization types prior to opset 21. " + "16-bit/4-bit integer quantization types prior to opset 21. " f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to " "enable support." ) diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 35b5e1c8ba825..bdf6d5a355206 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -21,7 +21,7 @@ from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions try: - from onnx.reference.custom_element_types import float8e4m3fn + from onnx.reference.custom_element_types import float8e4m3fn, int4, uint4 except ImportError: float8e4m3fn = None @@ -81,6 +81,8 @@ class QuantType(Enum): QFLOAT8E4M3FN = 2 QInt16 = 3 QUInt16 = 4 + QInt4 = 5 + QUInt4 = 6 def __str__(self): return self.name @@ -104,6 +106,10 @@ def tensor_type(self): return TensorProto.INT16 if self == QuantType.QFLOAT8E4M3FN: return TensorProto.FLOAT8E4M3FN + if self == QuantType.QUInt4: + return TensorProto.UINT4 + if self == QuantType.QInt4: + return TensorProto.INT4 raise ValueError(f"Unexpected value qtype={self!r}.") @@ -128,6 +134,8 @@ def from_string(format): onnx_proto.TensorProto.INT16: numpy.dtype("int16"), onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"), onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn, + onnx_proto.TensorProto.INT4: int4, + onnx_proto.TensorProto.UINT4: uint4, } ONNX_INT_TYPE_RANGE = { @@ -135,6 +143,8 @@ def from_string(format): onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)), onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)), onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)), + onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(15, dtype=uint4)), + onnx_proto.TensorProto.INT4: (numpy.array(-8, dtype=int4), numpy.array(7, dtype=int4)), } ONNX_INT_TYPE_SYMMETRIC_RANGE = { @@ -202,6 +212,35 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): ) ref = ReferenceEvaluator(onnx_model) return _check_type(ref.run(None, {"X": arr, "scale": scale})[0]) + elif qType in ( + onnx_proto.TensorProto.INT4, + onnx_proto.TensorProto.UINT4, + ): + if arr.dtype == numpy.float32: + onnx_type = TensorProto.FLOAT + elif arr.dtype == numpy.float16: + onnx_type = TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype {arr.dtype}.") + onnx_model = make_model( + make_graph( + [ + make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]), + ], + "qu", + [ + make_tensor_value_info("X", onnx_type, None), + 
make_tensor_value_info("scale", onnx_type, None), + make_tensor_value_info("zero_point", qType, None), + ], + [make_tensor_value_info("Y", qType, None)], + ) + ) + # The reference ONNX implementation of QuantizeLinear returns "unpacked" int8 numpy values + # because numpy cannot represent 4bit values (although ONNX TensorProto has no problem with this). + # These "unpacked" int8 values are correctly re-packed when passed to onnx.make_tensor(). + ref = ReferenceEvaluator(onnx_model) + return _check_type(ref.run(None, {"X": arr, "scale": scale, "zero_point": zero_point})[0]) else: dtype = ONNX_TYPE_TO_NP_TYPE[qType] (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) @@ -372,7 +411,14 @@ def quantize_data( ) return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) - if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16): + if qType in ( + TensorProto.INT8, + TensorProto.UINT8, + TensorProto.INT16, + TensorProto.UINT16, + TensorProto.INT4, + TensorProto.UINT4, + ): if len(data): qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index 986d158d2b1b9..7c160b6696265 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -71,6 +71,116 @@ class MlasQuantizeLinearTest : public MlasTestBase { } }; +template +class MlasQuantizeLinear4BitTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + int32_t MinVal() const { + if constexpr (Signed) { + return -8; + } else { + return 0; + } + } + + int32_t MaxVal() const { + if constexpr (Signed) { + return 7; + } else { + return 15; + } + } + + void GenerateReference(const float* Input, uint8_t* OutputReference, size_t N, float Scale, + int8_t ZeroPoint) { + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinVal())); + FloatValue = std::min(FloatValue, static_cast(MaxVal())); + + int8_t IntValue = static_cast(FloatValue); + + size_t i = n >> 1; + size_t j = n & 0x1; + uint8_t Shift = 4 * static_cast(j); + uint8_t Mask = 0xF << Shift; + + OutputReference[i] &= ~Mask; // Clear 4-bit lane + OutputReference[i] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane + } + } + + void Test(size_t N) { + size_t OutBufLen = (N + 1) / 2; + float* Input = BufferInput.GetBuffer(N); + uint8_t* Output = BufferOutput.GetBuffer(OutBufLen); + uint8_t* OutputReference = BufferOutputReference.GetBuffer(OutBufLen); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 32.f; + + std::uniform_int_distribution zp_distribution(MinVal(), MaxVal()); + int8_t ZeroPoint = static_cast(zp_distribution(generator)); + + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + for (size_t n = 0; n < N; n++) { + Input[n] = distribution(generator); + } + + GenerateReference(Input, OutputReference, N, Scale, 
ZeroPoint); + + if constexpr (Signed) { + MlasQuantizeLinearS4(Input, Output, N, Scale, ZeroPoint); + } else { + MlasQuantizeLinearU4(Input, Output, N, Scale, ZeroPoint); + } + + for (size_t n = 0; n < N; n++) { + size_t i = n >> 1; + size_t j = n & 0x1; + const uint8_t Shift = 4 * static_cast(j); + + int32_t actual_val = (Output[i] >> Shift) & 0xF; + int32_t expected_val = (OutputReference[i] >> Shift) & 0xF; + + if constexpr (Signed) { + constexpr uint8_t SignExtShift = (sizeof(int32_t) * 8) - 4; + actual_val = (actual_val << SignExtShift) >> SignExtShift; + expected_val = (expected_val << SignExtShift) >> SignExtShift; + } + + ASSERT_EQ(actual_val, expected_val) << ", size=" << N + << ", index=" << n + << ", nibble=" << j; + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (Signed) { + return "QuantizeLinearS4"; + } else { + return "QuantizeLinearU4"; + } + } + + void ExecuteShort(void) override { + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { @@ -78,6 +188,8 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe count += MlasDirectShortExecuteTests>::RegisterShortExecute(); count += MlasDirectShortExecuteTests>::RegisterShortExecute(); count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; }); diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index e12e9401413be..1d54a3cfae9bf 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1021,7 +1021,12 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {}}, {"dequantizelinear_blocked", "blocked quantization (onnx 1.16.0) not supported", {}}, {"quantizelinear_blocked_asymmetric", "blocked quantization (onnx 1.16.0) not supported", {}}, - {"quantizelinear_blocked_symmetric", "blocked quantization (onnx 1.16.0) not supported", {}}}); + {"quantizelinear_blocked_symmetric", "blocked quantization (onnx 1.16.0) not supported", {}}, + // See PR that fixes int4 q/dq tests: https://github.com/onnx/onnx/pull/6122 + {"dequantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"dequantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"quantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}}); // Some EPs may fail to pass some specific testcases. // For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16. 
diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index b15b1769a69c4..5df055f862a86 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -75,6 +75,48 @@ static void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_length memcpy(p_data, raw_data, raw_data_length); } +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ Int4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + + if (num_packed_pairs != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed int4 pairs", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); +} + +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ UInt4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + + if (num_packed_pairs != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed int4 pairs", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); +} + // This macro doesn't work for Float16/bool/string tensors #define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \ template <> \ @@ -268,6 +310,41 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2, TensorProto_DataType_FLOAT8E5M2) DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) #endif +#define DEFINE_UNPACK_TENSOR_INT4(INT4_TYPE, ONNX_TYPE) \ + template <> \ + void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ + /*out*/ INT4_TYPE* p_data, size_t expected_num_elems) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? 
raw_data_len : tensor.int32_data_size(); \ + if (size == 0) { \ + return; \ + } \ + ORT_CXX_API_THROW("p_data == nullptr, but size != 0", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + if (ONNX_NAMESPACE::ONNX_TYPE != tensor.data_type()) { \ + ORT_CXX_API_THROW("TensorProto data type is not INT4", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + \ + size_t expected_int4_pairs = (expected_num_elems + 1) / 2; \ + \ + if (raw_data != nullptr) { \ + UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \ + return; \ + } \ + \ + if (static_cast(tensor.int32_data_size()) != expected_int4_pairs) { \ + ORT_CXX_API_THROW("UnpackTensor: the pre-allocated size does not match the size in proto", \ + OrtErrorCode::ORT_FAIL); \ + } \ + \ + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ + p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ + } \ + } + +DEFINE_UNPACK_TENSOR_INT4(Int4x2, TensorProto_DataType_INT4) +DEFINE_UNPACK_TENSOR_INT4(UInt4x2, TensorProto_DataType_UINT4) + #define CASE_PROTO_TRACE(X, Y) \ case onnx::TensorProto_DataType::TensorProto_DataType_##X: \ if (!CalcMemSizeForArrayWithAlignment(size, sizeof(Y), alignment, out)) { \ @@ -275,6 +352,13 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) } \ break; +#define CASE_PROTO_TRACE_INT4(X) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!CalcMemSizeForArrayWithAlignment((size + 1) / 2, 1, alignment, out)) { \ + ORT_CXX_API_THROW("Invalid TensorProto", OrtErrorCode::ORT_FAIL); \ + } \ + break; + template Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) { const auto& dims = tensor_proto.dims(); @@ -308,6 +392,8 @@ Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_p CASE_PROTO_TRACE(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif CASE_PROTO_TRACE(STRING, std::string); + CASE_PROTO_TRACE_INT4(UINT4); + CASE_PROTO_TRACE_INT4(INT4); default: return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } @@ -392,6 +478,8 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { CASE_TYPE(FLOAT8E4M3FNUZ) CASE_TYPE(FLOAT8E5M2) CASE_TYPE(FLOAT8E5M2FNUZ) + CASE_TYPE(UINT4) + CASE_TYPE(INT4) default: return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -456,6 +544,8 @@ Status TensorProtoToMLValue(const onnx::TensorProto& tensor_proto, const MemBuff CASE_PROTO(FLOAT8E5M2, Float8E5M2); CASE_PROTO(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif + CASE_PROTO(INT4, Int4x2); + CASE_PROTO(UINT4, UInt4x2); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: if (preallocated != nullptr) { OrtStatus* status = OrtInitializeBufferForTensor(preallocated, preallocated_size, ele_type); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index fd5770cb70022..1e2d34e5aefc4 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -10,6 +10,7 @@ #include "core/common/type_utils.h" #include "core/graph/graph.h" #include "core/framework/framework_common.h" +#include "core/framework/int4.h" #include "core/optimizer/graph_transformer_level.h" #include "core/graph/onnx_protobuf.h" #include "test/framework/test_utils.h" @@ -46,6 +47,12 @@ struct IsTypeQuantLinearCompatible : std::true_type {}; template <> struct IsTypeQuantLinearCompatible : std::true_type {}; +template <> +struct IsTypeQuantLinearCompatible : 
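For packed 4-bit initializers the unpack helpers above only need to verify that the raw payload holds exactly ceil(num_elements / 2) bytes and then copy it, since the in-memory Int4x2/UInt4x2 layout matches the proto's packed layout; the size-tracing macro applies the same (size + 1) / 2 rule. A reduced sketch of that contract, with UnpackInt4RawData as an illustrative name and the error handling collapsed to a bool:

#include <cstdint>
#include <cstring>

bool UnpackInt4RawData(const void* raw_data, size_t raw_data_len,
                       size_t expected_num_elements, uint8_t* dst /* packed pairs */) {
  size_t num_packed_pairs = (expected_num_elements + 1) / 2;
  if (num_packed_pairs != raw_data_len) {
    return false;  // "Unexpected number of packed int4 pairs"
  }
  std::memcpy(dst, raw_data, num_packed_pairs);
  return true;
}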
std::true_type {}; + +template <> +struct IsTypeQuantLinearCompatible : std::true_type {}; + template struct IsTypeDequantLinearCompatible : utils::IsByteType {}; @@ -58,6 +65,12 @@ struct IsTypeDequantLinearCompatible : std::true_type {}; template <> struct IsTypeDequantLinearCompatible : std::true_type {}; +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + class ModelTestBuilder { public: ModelTestBuilder(Graph& graph) : graph_(graph) { @@ -103,6 +116,22 @@ class ModelTestBuilder { return MakeInput(shape, data); } + template + typename std::enable_if< + std::is_same_v || std::is_same_v, + NodeArg*>::type + MakeInputInt4(const std::vector& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) { + using UnpackedType = typename TInt4::UnpackedType; + std::vector data_int8 = rand_gen_.Uniform(shape, min, max); + std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); + for (size_t i = 0; i < data_int8.size(); i++) { + size_t r = i >> 1; + size_t c = i & 0x1; + data[r].SetElem(c, data_int8[i]); + } + return MakeInput(shape, data); + } + template NodeArg* MakeInput(const std::optional>& shape, std::optional input_name = std::nullopt) { diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 3dab0ec248f95..862408f31f004 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -5,9 +5,11 @@ #include #include +#include #include "graph_transform_test_builder.h" +#include "core/framework/int4.h" #include "core/common/span_utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/session/inference_session.h" @@ -510,12 +512,21 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, bool use_diff_output_scale, bool use_contrib_qdq = false) { return [input_shape, axis, use_diff_output_scale, use_contrib_qdq](ModelTestBuilder& builder) { - auto* input_arg = builder.MakeInput(input_shape, - std::numeric_limits::min(), - std::numeric_limits::max()); + InputType dq_zp{}; + OutputType q_zp{}; + NodeArg* input_arg = nullptr; + + if constexpr (std::is_same_v || std::is_same_v) { + input_arg = builder.MakeInputInt4(input_shape, InputType::min_val, InputType::max_val); + dq_zp = InputType(static_cast(InputType::max_val / 2)); + q_zp = OutputType(static_cast(OutputType::max_val / 2)); + } else { + input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), + std::numeric_limits::max()); + dq_zp = std::numeric_limits::max() / 2; + q_zp = std::numeric_limits::max() / 2; + } - InputType dq_zp = std::numeric_limits::max() / 2; - OutputType q_zp = std::numeric_limits::max() / 2; auto* dq_output = builder.MakeIntermediate(); constexpr float input_scale = 0.003f; builder.AddDequantizeLinearNode(input_arg, input_scale, dq_zp, dq_output, use_contrib_qdq); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 31e2280187f76..c338d542b0b79 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -6,6 +6,7 @@ #include "core/common/span_utils.h" #include "core/framework/compute_capability.h" #include "core/framework/node_unit.h" +#include "core/framework/int4.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" @@ -1362,19 
+1363,21 @@ TEST(QDQTransformerTests, DoubleQDQPairsRemover_DuplicateLastDQs) { // Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split. template static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis, - bool all_same_quant_params, bool use_contrib_qdq = false) { - auto check_graph = [all_same_quant_params, use_contrib_qdq](InferenceSessionWrapper& session) { + bool all_same_quant_params, bool use_contrib_qdq = false, + bool should_not_drop = false) { + auto check_graph = [all_same_quant_params, use_contrib_qdq, should_not_drop](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - int expected_q_ops = all_same_quant_params ? 0 : 3; - int expected_dq_ops = all_same_quant_params ? 0 : 1; + int expected_q_ops = all_same_quant_params && !should_not_drop ? 0 : 3; + int expected_dq_ops = all_same_quant_params && !should_not_drop ? 0 : 1; EXPECT_EQ(op_to_count["Split"], 1); EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_q_ops); EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_dq_ops); }; std::vector opsets = {12, 13, 18, 19, 21}; - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { opsets = std::vector{21}; } @@ -1402,6 +1405,11 @@ TEST(QDQTransformerTests, Split) { RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS); RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS); + + // Do not yet support int4 Split, so should not drop + constexpr bool SHOULD_NOT_DROP = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS, SHOULD_NOT_DROP); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS, SHOULD_NOT_DROP); } // Test cases that DO NOT drop Q/DQ ops from DQ -> Split -> Q (many) diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index c276ae494df43..512b3402c5986 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/framework/customregistry.h" #include "core/framework/prepacked_weights_container.h" @@ -690,8 +691,14 @@ class BaseTester { if (!is_optional_type_tensor || (is_optional_type_tensor && values != nullptr)) { // In case values is nullptr for optional type tensor, it means we are creating // an optional type tensor which is None and we hence skip values count validation - ORT_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ", - shape.Size()); + if constexpr (std::is_same_v || std::is_same_v) { + const int64_t expected_values_count = T::CalcNumInt4Pairs(shape.Size()); + ORT_ENFORCE(expected_values_count == values_count, values_count, + " input values doesn't match tensor size of ", expected_values_count); + } else { + ORT_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ", + shape.Size()); + } // If it is an optional tensor type with no values (i.e.) 
None, // we won't even pass it in to Run() as part of the feeds, diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 47c18c478dd9c..d0e08448ce456 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -7,6 +7,7 @@ #include "core/graph/constants.h" #include "core/framework/TensorSeq.h" +#include "core/framework/int4.h" #include "test/framework/test_utils.h" #include "test/providers/provider_test_utils.h" @@ -162,6 +163,46 @@ struct TensorCheck { } }; +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + Tensor expected_sorted, actual_sorted; + const Int4x2* cur_expected; + const Int4x2* cur_actual; + const auto size = actual.Shape().Size(); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < static_cast(size); ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } + } +}; + +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + Tensor expected_sorted, actual_sorted; + const UInt4x2* cur_expected; + const UInt4x2* cur_actual; + const auto size = actual.Shape().Size(); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < static_cast(size); ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } + } +}; + template <> struct TensorCheck { void operator()(const Tensor& expected, @@ -437,6 +478,7 @@ void Check(std::string_view name, const OrtValue& expected, const Tensor utils::MLTypeCallDispatcher dims{5}; + constexpr int unused_val = 0; + + // Odd number of int4 values to test packing/unpacking + test.AddInput("x", dims, {Int4x2(-8, -3), Int4x2(1, 7), Int4x2(2, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {Int4x2(-1, unused_val)}); + test.AddOutput("y", dims, {-14.0f, -4.0f, 4.0f, 16.0f, 6.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// scalar zero & scale with uint4 +TEST(DequantizeLinearOpTest, UInt4) { + OpTester test("DequantizeLinear", 21); + std::vector dims{5}; + constexpr int unused_val = 0; + + // Odd number of uint4 values to test packing/unpacking + test.AddInput("x", dims, {UInt4x2(0, 1), UInt4x2(3, 15), UInt4x2(2, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {UInt4x2(1, unused_val)}); + test.AddOutput("y", dims, {-2.0f, 0.0f, 4.0f, 28.0f, 2.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // Test int16 DequantizeLinear (per tensor) TEST(DequantizeLinearOpTest, Int16) { OpTester test("DequantizeLinear", 21); @@ -349,6 +378,122 @@ TEST(QuantizeLinearOpTest, Int16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// Test int4 QuantizeLinear (per tensor) +TEST(QuantizeLinearOpTest, Int4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{7}; + constexpr int8_t unused_val = 0; + test.AddInput("x", dims, { + -20.0f, // Clamp to qmin + -16.0f, // Close to qmin + -3.0f, // round + 0.0f, // Zero-point + 2.9f, // 
round + 12.0f, // qmax + 20.0f, // Clamp to qmax + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {Int4x2(1, unused_val)}, true); + test.AddOutput("y", dims, + {Int4x2(-8, -7), Int4x2(-1, 1), Int4x2(2, 7), + Int4x2(7, unused_val)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint4 QuantizeLinear (per tensor) +TEST(QuantizeLinearOpTest, UInt4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{7}; + constexpr uint8_t unused_val = 0; + test.AddInput("x", dims, { + -20.0f, // Clamp to qmin + -8.0f, // qmin + -3.0f, // round + 0.0f, // Zero-point + 2.9f, // round + 22.0f, // qmax + 30.0f, // Clamp to qmax + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {UInt4x2(4, unused_val)}, true); + test.AddOutput("y", dims, + {UInt4x2(0, 0), UInt4x2(2, 4), UInt4x2(5, 15), + UInt4x2(15, unused_val)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +template +static void GetExpectedInt4Quant(const float* input, Int4x2Base* output, size_t num_elems, float scale, + int8_t zero_point) { + using UnpackedType = typename Int4x2Base::UnpackedType; + + for (size_t n = 0; n < num_elems; n++) { + float float_val = std::nearbyintf(input[n] / scale) + static_cast(zero_point); + float_val = std::max(float_val, static_cast(Int4x2Base::min_val)); + float_val = std::min(float_val, static_cast(Int4x2Base::max_val)); + + UnpackedType int_val = static_cast(float_val); + + size_t i = n >> 1; + size_t j = n & 0x1; + output[i].SetElem(j, int_val); + } +} + +// Test int4 QuantizeLinear (per tensor) with a "large" and odd number of input elements. +// This exercises the TryParallelFor call which splits the input into blocks of even size. +TEST(QuantizeLinearOpTest, OddLarge_Int4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{1017}; + constexpr int8_t unused_val = 0; + constexpr std::array pattern = {-20.0f, -14.0f, -4.1f, -0.0f, 3.0f, 3.3f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output(Int4x2::CalcNumInt4Pairs(input_f32s.size())); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + int8_t zp = 1; + GetExpectedInt4Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {Int4x2(zp, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint4 QuantizeLinear (per tensor) with a "large" and odd number of input elements. +// This exercises the TryParallelFor call which splits the input into blocks of even size. 
+TEST(QuantizeLinearOpTest, OddLarge_UInt4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{1017}; + constexpr uint8_t unused_val = 0; + constexpr std::array pattern = {-20.0f, -14.0f, -4.1f, -0.0f, 3.0f, 3.3f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output(UInt4x2::CalcNumInt4Pairs(input_f32s.size())); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + uint8_t zp = 1; + GetExpectedInt4Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {UInt4x2(zp, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // quantize with scalar zero point and scale TEST(QuantizeLinearOpTest, Int8_NegativeZeroPoint) { // TODO: Unskip when fixed #41968513 diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 0e7ac5ed2b2f0..01dba55ceb8ed 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -49,14 +49,17 @@ void TransposeTest(const std::vector& input_shape, const std::vector* p_perm, const std::vector& expected_shape, const std::vector& expected_vals, - const std::unordered_set& excluded_provider_types = {}) { - OpTester test("Transpose"); - if (nullptr != p_perm) - test.AddAttribute("perm", *p_perm); - test.AddInput("X", input_shape, input_vals); - test.AddOutput("Y", expected_shape, expected_vals); - - test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_provider_types); + const std::unordered_set& excluded_provider_types = {}, + const std::vector& opsets = {7}) { + for (auto opset : opsets) { + OpTester test("Transpose", opset); + if (nullptr != p_perm) + test.AddAttribute("perm", *p_perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_provider_types); + } } // Test 2 dimensional transpose, with no permutation attribute specified @@ -73,7 +76,7 @@ TEST(TransposeOpTest, TwoDimNoAttr) { 3.0f, 6.0f}; TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: SegFault error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: SegFault error } TEST(TransposeOpTest, TwoDimNoAttrStr) { @@ -88,7 +91,7 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) { "2", "5", "3", "6"}; - TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals, {}, {7, 21}); } // Test 2 dimensional transpose, with permutation attribute specified @@ -103,7 +106,47 @@ TEST(TransposeOpTest, TwoDim) { 2.0f, 5.0f, 3.0f, 6.0f}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); +} + +// Test Int4 transpose with odd inner dimension. 
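The DequantizeLinear int4 test earlier in this hunk follows directly from y = (x - zero_point) * scale with scale 2 and zero point -1; a quick arithmetic check of its expected outputs:

#include <cassert>

int main() {
  const int x[] = {-8, -3, 1, 7, 2};                         // packed test input, unpacked here
  const float expected[] = {-14.0f, -4.0f, 4.0f, 16.0f, 6.0f};
  for (int i = 0; i < 5; ++i) {
    float y = (x[i] - (-1)) * 2.0f;                          // (x - zero_point) * scale
    assert(y == expected[i]);
  }
}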
+TEST(TransposeOpTest, TwoDim_Odd_Int4) { + constexpr int8_t unused_val = 0; + std::vector input_shape({5, 3}); + std::vector input_vals = {Int4x2(1, 2), Int4x2(3, 4), Int4x2(5, 6), Int4x2(7, 8), + Int4x2(9, 10), Int4x2(11, 12), Int4x2(13, 14), Int4x2(15, unused_val)}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 5}); + std::vector expected_vals = {Int4x2(1, 4), Int4x2(7, 10), Int4x2(13, 2), Int4x2(5, 8), + Int4x2(11, 14), Int4x2(3, 6), Int4x2(9, 12), Int4x2(15, unused_val)}; + + OpTester test("Transpose", 21); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test UInt4 transpose with odd inner dimension. +TEST(TransposeOpTest, TwoDim_Odd_UInt4) { + constexpr int8_t unused_val = 0; + std::vector input_shape({5, 3}); + std::vector input_vals = {UInt4x2(1, 2), UInt4x2(3, 4), UInt4x2(5, 6), UInt4x2(7, 8), + UInt4x2(9, 10), UInt4x2(11, 12), UInt4x2(13, 14), UInt4x2(15, unused_val)}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 5}); + std::vector expected_vals = {UInt4x2(1, 4), UInt4x2(7, 10), UInt4x2(13, 2), UInt4x2(5, 8), + UInt4x2(11, 14), UInt4x2(3, 6), UInt4x2(9, 12), UInt4x2(15, unused_val)}; + + OpTester test("Transpose", 21); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(TransposeOpTest, TwoDim_double) { @@ -131,7 +174,7 @@ TEST(TransposeOpTest, TwoDim_int32) { 2, 5, 3, 6}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); } TEST(TransposeOpTest, TwoDim_int16) { @@ -147,7 +190,7 @@ TEST(TransposeOpTest, TwoDim_int16) { 2, 5, 3, 6}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kOpenVINOExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kOpenVINOExecutionProvider}, {7, 21}); } TEST(TransposeOpTest, TwoDim_mlfloat16) { @@ -163,7 +206,7 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) { MLFloat16::FromBits(static_cast(2)), MLFloat16::FromBits(static_cast(5)), MLFloat16::FromBits(static_cast(3)), MLFloat16::FromBits(static_cast(6))}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } #if defined(USE_DNNL) @@ -264,7 +307,7 @@ TEST(TransposeOpTest, TwoDim_int8) { 2, 5, 3, 6}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } TEST(TransposeOpTest, TwoDimStr) { @@ -280,7 +323,7 @@ TEST(TransposeOpTest, TwoDimStr) { "2", "5", "3", "6"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); } // Test 3 dimensional transpose, with permutation attribute specified @@ -319,7 +362,7 @@ TEST(TransposeOpTest, Transpose021) { 3.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - 
{kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } TEST(TransposeOpTest, Transpose120) { @@ -349,7 +392,7 @@ TEST(TransposeOpTest, Transpose120) { 6.0f, 6.1f, 6.2f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } // test when the suffix size is > 1 (last dimension is not moved) @@ -382,7 +425,7 @@ TEST(TransposeOpTest, Transpose102) { 4.3f, 5.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } TEST(TransposeOpTest, TransposeReshape) { @@ -416,7 +459,7 @@ TEST(TransposeOpTest, TransposeReshape) { 4.3f, 5.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } TEST(TransposeOpTest, ThreeDimStr) { @@ -453,7 +496,7 @@ TEST(TransposeOpTest, ThreeDimStr) { "2", "5", "3", "6"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); } TEST(TransposeOpTest, SixDim) { @@ -478,7 +521,7 @@ TEST(TransposeOpTest, SixDim) { }(); TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kQnnExecutionProvider}); // Error: Failed to finalize QNN graph. + {kQnnExecutionProvider}, {7, 21}); // Error: Failed to finalize QNN graph. } template @@ -522,7 +565,7 @@ TEST(TransposeOpTest, NCHW2NHWCStr) { "3", "7", "11", "4", "8", "12"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } template @@ -582,7 +625,7 @@ TEST(TransposeOpTest, NHWC2NCHW_String) { "2", "5", "8", "11", "3", "6", "9", "12"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } // test to cover memcpy from single axis moving inwards path diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 0cde8c31a561c..efe21915978ef 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -651,7 +651,9 @@ def setUpClass(cls): def tearDownClass(cls): cls._tmp_model_dir.cleanup() - def construct_model_conv_relu(self, output_model_path, input_shape, weight_shape, output_shape): + def construct_model_conv_relu( + self, output_model_path, input_shape, weight_shape, output_shape, opset=13, ir_version=7 + ): # (input) # | # Conv @@ -686,19 +688,31 @@ def construct_model_conv_relu(self, output_model_path, input_shape, weight_shape [output_tensor], initializer=initializers, ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)]) + model.ir_version = ir_version onnx.save(model, output_model_path) - def verify_qdq(self, per_channel, 
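The TwoDim_Odd_Int4/UInt4 cases above transpose data whose elements do not sit on byte boundaries. One straightforward way to express that operation, shown here only as a sketch and not necessarily how DoTranspose handles it internally, is to address elements logically and repack at the destination index; GetNibble, SetNibble, and Transpose2D are illustrative helpers:

#include <cassert>
#include <cstdint>
#include <vector>

uint8_t GetNibble(const std::vector<uint8_t>& buf, size_t i) {
  return (i & 1) ? (buf[i >> 1] >> 4) : (buf[i >> 1] & 0x0F);
}
void SetNibble(std::vector<uint8_t>& buf, size_t i, uint8_t v) {
  v &= 0x0F;
  buf[i >> 1] = (i & 1) ? static_cast<uint8_t>((buf[i >> 1] & 0x0F) | (v << 4))
                        : static_cast<uint8_t>((buf[i >> 1] & 0xF0) | v);
}

// Transpose a rows x cols matrix of 4-bit elements stored packed, row-major.
std::vector<uint8_t> Transpose2D(const std::vector<uint8_t>& src, size_t rows, size_t cols) {
  std::vector<uint8_t> dst((rows * cols + 1) / 2, 0);
  for (size_t r = 0; r < rows; ++r)
    for (size_t c = 0; c < cols; ++c)
      SetNibble(dst, c * rows + r, GetNibble(src, r * cols + c));
  return dst;
}

int main() {
  // 5x3 input holding 1..15 packed low-nibble-first, as in TwoDim_Odd_Int4 above.
  std::vector<uint8_t> src = {0x21, 0x43, 0x65, 0x87, 0xA9, 0xCB, 0xED, 0x0F};
  std::vector<uint8_t> out = Transpose2D(src, 5, 3);
  assert(GetNibble(out, 0) == 1 && GetNibble(out, 1) == 4 && GetNibble(out, 14) == 15);
}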
activation_type, weight_type, extra_options=None): + def verify_qdq( + self, + per_channel, + activation_type, + weight_type, + extra_options=None, + opset=13, + ir_version=7, + rtol=1e-2, + atol=0.05, + ): np.random.seed(1) model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx") model_qdq_path = str( Path(self._tmp_model_dir.name) / f"conv_relu_quant_qdq.{activation_type}.{weight_type}.{per_channel}.onnx" ) data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) - self.construct_model_conv_relu(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31]) + self.construct_model_conv_relu( + model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31], opset=opset, ir_version=ir_version + ) quantize_static( model_fp32_path, model_qdq_path, @@ -724,7 +738,7 @@ def verify_qdq(self, per_channel, activation_type, weight_type, extra_options=No "DequantizeLinear", ], ) - check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next(), rtol=rtol, atol=atol) # If the model uses Q/DQ ops with "com.microsoft" domain (e.g., for int16 support), # then ensure the model has the appropriate opset import. @@ -777,6 +791,16 @@ def test_quantize_conv_without_bias(self): self.verify_qdq(True, QuantType.QUInt16, QuantType.QUInt8, {"UseQDQContribOps": True}) self.verify_qdq(True, QuantType.QInt16, QuantType.QInt8, {"UseQDQContribOps": True}) + # 4-bit QDQ + self.verify_qdq(False, QuantType.QInt16, QuantType.QInt4, opset=21, ir_version=10, atol=0.4) # per-tensor + self.verify_qdq(True, QuantType.QInt16, QuantType.QInt4, opset=21, ir_version=10) # per-channel + self.verify_qdq( + False, QuantType.QInt16, QuantType.QInt4, {"UseQDQContribOps": True}, opset=21, ir_version=10, atol=0.4 + ) # per-tensor + self.verify_qdq( + True, QuantType.QInt16, QuantType.QInt4, {"UseQDQContribOps": True}, opset=21, ir_version=10 + ) # per-channel + def test_quantize_relu_conv(self): float_model_path = str(Path(self._tmp_model_dir.name) / "float_relu_convs_model.onnx") construct_relu_conv_model(float_model_path) @@ -1491,5 +1515,183 @@ def test_16bit_subgraph(self): check_model_correctness(self, f32_model_path, qdq_model_path, data_reader.get_next()) +class TestQDQ4bit(TestQDQFormat): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.4bit_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." 
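The int4 conv test below documents its own expectations: a linspace(-1.5, 1.5) weight has range 3.0, so per-tensor INT4 quantization yields scale 3/15 = 0.2 with zero point 0, and 1600 4-bit weights pack into roughly half the bytes of an int8 initializer. A quick arithmetic check of those numbers:

#include <cassert>
#include <cmath>
#include <cstddef>

int main() {
  float rmin = -1.5f, rmax = 1.5f;                 // weight = linspace(-1.5, 1.5)
  float scale = (rmax - rmin) / (7.0f - (-8.0f));  // int4: qmax - qmin = 15
  assert(std::fabs(scale - 0.2f) < 1e-6f);
  size_t num_weights = 2 * 2 * 20 * 20;            // weight_shape = [2, 2, 20, 20]
  size_t packed_bytes = (num_weights + 1) / 2;     // two int4 values per byte
  assert(num_weights == 1600 && packed_bytes == 800);
}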
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._tmp_model_dir.cleanup()
+
+    def build_conv_test_model(
+        self,
+        inp_shape: list[int],
+        weight_data: np.ndarray,
+        bias_data: np.ndarray,
+    ):
+        input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape)
+        output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None)
+        weight = onnx.numpy_helper.from_array(weight_data, "weight")
+        bias = onnx.numpy_helper.from_array(bias_data, "bias")
+
+        conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0")
+        graph = onnx.helper.make_graph(
+            [conv_node],
+            "Convf32",
+            [input_0],
+            [output_0],
+            initializer=[weight, bias],
+        )
+        opset_imports = [onnx.helper.make_opsetid("", 21)]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+
+        return onnx.shape_inference.infer_shapes(model)
+
+    def test_int4_qdq_conv(self):
+        """
+        Test quantization of int4 conv weight.
+        """
+        float_model_path = os.path.join(self._tmp_dir_path, "conv_int4.f32.onnx")
+        qdq_model_path = os.path.join(self._tmp_dir_path, "conv_int4.qdq.onnx")
+
+        inp_shape = [1, 2, 100, 100]
+        weight_shape = [2, 2, 20, 20]
+
+        # range = 3.0, scale = 3/15, zp = 0
+        weight_data = np.linspace(-1.5, 1.5, num=1600, dtype=np.float32).reshape(weight_shape)
+        bias_data = np.array([-10.0, 10.0], dtype=np.float32)
+        float_model = self.build_conv_test_model(inp_shape, weight_data, bias_data)
+
+        onnx.checker.check_model(float_model, True)
+        onnx.save_model(float_model, float_model_path)
+
+        data_reader = self.input_feeds(3, {"input_0": inp_shape}, np.float32)
+
+        tensor_quant_overrides = {
+            "weight": [{"quant_type": QuantType.QInt4}],  # Quantize weights to INT4
+        }
+        quantize_static(
+            float_model_path,
+            qdq_model_path,
+            data_reader,
+            quant_format=QuantFormat.QDQ,
+            activation_type=QuantType.QUInt16,
+            weight_type=QuantType.QUInt8,
+            op_types_to_quantize=[node.op_type for node in float_model.graph.node],
+            extra_options={
+                "TensorQuantOverrides": tensor_quant_overrides,
+            },
+        )
+
+        qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4}
+        check_op_type_count(self, qdq_model_path, **qdq_node_counts)
+
+        qdq_model = onnx.load_model(qdq_model_path)
+        onnx.checker.check_model(qdq_model, True)
+
+        initializers = {init.name: init for init in qdq_model.graph.initializer}
+
+        # Check that the weight's zero-point data type is INT4 and has the expected value
+        zp_val = 0
+        weight_zp_init = initializers["weight_zero_point"]
+        self.assertEqual(weight_zp_init.data_type, onnx.TensorProto.INT4)
+        self.assertEqual(weight_zp_init.int32_data[0], zp_val)
+
+        # Check for the expected scale value
+        weight_scale_init = initializers["weight_scale"]
+        scale_val = np.float32(3.0 / 15)
+        self.assertEqual(weight_scale_init.data_type, onnx.TensorProto.FLOAT)
+        self.assertEqual(weight_scale_init.float_data[0], scale_val)
+
+        # Check that INT4 weights take up approximately 50% of the size of INT8 weights.
+        # Using protobuf's ByteSize() is not exact because it includes other fields in the proto message.
+        unpacked_size = 1
+        for dim in weight_shape:
+            unpacked_size *= dim
+
+        weight_quant_init = initializers["weight_quantized"]
+        size_ratio = weight_quant_init.ByteSize() / unpacked_size
+        self.assertLess(size_ratio, 0.55)
+
+        # Check that the quantized weight values are correct.
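+        # Each raw byte packs two signed 4-bit values, low nibble first, so element `index`
+        # lives in byte `index >> 1` at nibble `index & 0x1`; quantized values are clipped to
+        # the signed 4-bit range [-8, 7].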
+        if weight_quant_init.HasField("raw_data"):
+            float_data = weight_data.flatten().tolist()
+            for index, float_val in enumerate(float_data):
+                expected_int4_val = np.clip(np.float32(float_val / scale_val).round() + zp_val, -8, 7)
+                int4_pair = onnx.subbyte.unpack_single_4bitx2(weight_quant_init.raw_data[index >> 1], True)
+                int4_val = int4_pair[index & 0x1]
+
+                self.assertEqual(np.float32(int4_val), expected_int4_val)
+
+    def test_int4_qdq_per_channel_conv(self):
+        """
+        Test per-channel quantization of int4 conv weight.
+        """
+        float_model_path = os.path.join(self._tmp_dir_path, "conv_int4_per_chan.f32.onnx")
+        qdq_model_path = os.path.join(self._tmp_dir_path, "conv_int4_per_chan.qdq.onnx")
+
+        inp_shape = [1, 2, 100, 100]
+        weight_shape = [2, 2, 20, 20]
+
+        weight_data = np.linspace(-1.5, 1.5, num=1600, dtype=np.float32).reshape(weight_shape)
+        bias_data = np.array([-10.0, 10.0], dtype=np.float32)
+        float_model = self.build_conv_test_model(inp_shape, weight_data, bias_data)
+
+        onnx.checker.check_model(float_model, True)
+        onnx.save_model(float_model, float_model_path)
+
+        data_reader = self.input_feeds(3, {"input_0": inp_shape}, np.float32)
+
+        per_chan_axis = 0
+        tensor_quant_overrides = {
+            "weight": [{"quant_type": QuantType.QInt4, "axis": per_chan_axis}],  # Quantize weight to INT4 (per-channel)
+        }
+        quantize_static(
+            float_model_path,
+            qdq_model_path,
+            data_reader,
+            quant_format=QuantFormat.QDQ,
+            activation_type=QuantType.QUInt16,
+            weight_type=QuantType.QUInt8,
+            op_types_to_quantize=[node.op_type for node in float_model.graph.node],
+            extra_options={
+                "TensorQuantOverrides": tensor_quant_overrides,
+            },
+        )
+
+        qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4}
+        check_op_type_count(self, qdq_model_path, **qdq_node_counts)
+
+        qdq_model = onnx.load_model(qdq_model_path)
+        onnx.checker.check_model(qdq_model, True)
+
+        initializers = {init.name: init for init in qdq_model.graph.initializer}
+
+        # Check that the weight's zero-point data type is INT4 and has 2 elems
+        weight_zp_init = initializers["weight_zero_point"]
+        self.assertEqual(weight_zp_init.data_type, onnx.TensorProto.INT4)
+        self.assertEqual(weight_zp_init.dims[0], 2)
+
+        # Check that the weight's scale data type is FLOAT and has 2 elems
+        weight_scale_init = initializers["weight_scale"]
+        self.assertEqual(weight_scale_init.data_type, onnx.TensorProto.FLOAT)
+        self.assertEqual(weight_scale_init.dims[0], 2)
+
+        # Check that INT4 weights take up approximately 50% of the size of INT8 weights.
+        # Using protobuf's ByteSize() is not exact because it includes other fields in the proto message.
+        unpacked_size = 1
+        for dim in weight_shape:
+            unpacked_size *= dim
+
+        weight_quant_init = initializers["weight_quantized"]
+        size_ratio = weight_quant_init.ByteSize() / unpacked_size
+        self.assertLess(size_ratio, 0.55)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 051a93ac8458f..58c185f818df7 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -2827,6 +2827,50 @@ TEST(CApiTest, create_tensor_with_data_float8) {
 
 #endif
 
+// Test creating an Ort::Value with INT4 data.
+TEST(CApiTest, create_tensor_with_data_int4) {
+  std::array<uint8_t, 4> values = {0x10, 0x32, 0x78, 0x06};  // {0, 1, 2, 3, -8, 7, 6, pad_0}
+  std::vector<int64_t> dims = {7};  // 7 4-bit elements take up 4 bytes.
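+  // Packed layout is low-nibble first: 0x10 -> {0, 1}, 0x32 -> {2, 3}, 0x78 -> {-8, 7} for signed
+  // INT4; the high nibble of the final byte (0x06) is padding because the element count (7) is odd.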
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  Ort::Value tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size(),
+                                               ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4);
+  const auto* new_pointer = tensor.GetTensorData<uint8_t>();
+  ASSERT_EQ(new_pointer, values.data());
+  auto type_info = tensor.GetTypeInfo();
+  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+  ASSERT_NE(tensor_info, nullptr);
+  auto query_dims = tensor_info.GetShape();
+  ASSERT_EQ(query_dims, dims);
+  ASSERT_EQ(tensor_info.GetElementType(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4);
+
+  uint8_t pair_2 = tensor.At<uint8_t>({2});
+  ASSERT_EQ(values[2], pair_2);
+}
+
+// Test creating an Ort::Value with UINT4 data.
+TEST(CApiTest, create_tensor_with_data_uint4) {
+  std::array<uint8_t, 4> values = {0x10, 0x32, 0x54, 0x0F};  // {0, 1, 2, 3, 4, 5, 15, pad_0}
+  std::vector<int64_t> dims = {7};  // 7 4-bit elements take up 4 bytes.
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+  Ort::Value tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size(),
+                                               ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4);
+  const auto* new_pointer = tensor.GetTensorData<uint8_t>();
+  ASSERT_EQ(new_pointer, values.data());
+  auto type_info = tensor.GetTypeInfo();
+  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+  ASSERT_NE(tensor_info, nullptr);
+  auto query_dims = tensor_info.GetShape();
+  ASSERT_EQ(query_dims, dims);
+  ASSERT_EQ(tensor_info.GetElementType(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4);
+
+  uint8_t pair_2 = tensor.At<uint8_t>({2});
+  ASSERT_EQ(values[2], pair_2);
+}
+
 TEST(CApiTest, access_tensor_data_elements) {
   /**
    * Create a 2x3 data blob that looks like:
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index 005128bc05d4a..1885a213bdf32 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -312,7 +312,13 @@
     // DequantizeLinear(21) blocked quantization from ONNX 1.16.0 is not implemented in ORT yet.
"^test_dequantizelinear_blocked", "^test_quantizelinear_blocked_asymmetric", - "^test_quantizelinear_blocked_symmetric" + "^test_quantizelinear_blocked_symmetric", + // Bug with test model: node's input name does not match the model's input name (x_zero_point vs zero_point) + // PR with fix: https://github.com/onnx/onnx/pull/6122 + "^test_dequantizelinear_int4", + "^test_dequantizelinear_uint4", + "^test_quantizelinear_int4", + "^test_quantizelinear_uint4" ], "current_failing_tests_x86": [ "^test_vgg19", diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index 64ebe24188762..cc4c0440d26d9 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -29,6 +29,7 @@ #pragma GCC diagnostic pop #endif +#include "core/framework/int4.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/framework/TensorSeq.h" @@ -202,6 +203,44 @@ std::pair IsResultExactlyMatch(const Tensor& outval return std::make_pair(COMPARE_RESULT::SUCCESS, ""); } +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const Int4x2* expected_output = expected_value.Data(); + const Int4x2* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 1; + size_t c = di & 0x1; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << expected_output[r].GetElem(c) << ", got " << real_output[r].GetElem(c); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const UInt4x2* expected_output = expected_value.Data(); + const UInt4x2* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 1; + size_t c = di & 0x1; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << expected_output[r].GetElem(c) << ", got " << real_output[r].GetElem(c); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + std::pair CompareFloat16Result(const Tensor& outvalue, const Tensor& expected_value, double per_sample_tolerance, double relative_per_sample_tolerance, @@ -313,6 +352,10 @@ std::pair CompareTwoTensors(const Tensor& outvalue, return IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return IsResultExactlyMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return IsResultExactlyMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return CompareFloat16Result(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, post_processing);