diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 12b772ceff282..34911cfc7972e 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -93,7 +93,10 @@ jobs:
github_token: ${{ secrets.github_token }}
reporter: github-pr-check
level: warning
- flags: --linelength=120 --exclude=java/src/main/native/*.c
+ flags: --linelength=120
+ --exclude=java/src/main/native/*.c
+ --exclude=onnxruntime/core/mlas/inc/*
+ --exclude=onnxruntime/core/mlas/lib/*
filter: "-runtime/references"
lint-js:
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index bf38dd56247d9..4d7493bd69650 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1373,7 +1373,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)
+- T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32), tensor(int4), tensor(uint4)
- Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.
- T2 : tensor(float16), tensor(float)
- Constrain 'y', 'x_scale' to float tensors.
@@ -4832,7 +4832,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- T1 : tensor(float16), tensor(float)
- Constrain 'x', 'y_scale' to float tensors.
-- T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)
+- T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int4), tensor(uint4)
- Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 4533884a51773..8092c26da651a 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -87,7 +87,7 @@ Do not modify directly.*
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
-|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
+|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
|||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
|||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)|
@@ -259,7 +259,7 @@ Do not modify directly.*
|||[7, 11]|**T** = tensor(double), tensor(float)|
|QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)|
|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**
or
*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
|||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
@@ -418,7 +418,7 @@ Do not modify directly.*
|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**
or
*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)|
|||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)|
-|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
@@ -468,7 +468,7 @@ Do not modify directly.*
|CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)|
-|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float)|
+|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float)|
@@ -504,7 +504,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)|
diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h
index f3942128077de..b197d88090432 100644
--- a/include/onnxruntime/core/framework/data_types.h
+++ b/include/onnxruntime/core/framework/data_types.h
@@ -15,6 +15,7 @@
#include "core/framework/endian.h"
#include "core/framework/float8.h"
#include "core/framework/float16.h"
+#include "core/framework/int4.h"
#include "core/graph/onnx_protobuf.h"
#include "core/framework/to_tensor_proto_element_type.h"
@@ -280,7 +281,8 @@ struct IsAnyOf {
 template <typename T>
struct IsTensorContainedType : public IsAnyOf {
* Base class for primitive Tensor contained types
*
* \details This class contains an integer constant that can be
- * used for input data type dispatching
+ *          used for input data type dispatching. This class also stores the number of subelements per size unit.
+ * Example: For int4, the size unit is 1 byte and the number of subelements is 2.
*
*/
class PrimitiveDataTypeBase : public DataTypeImpl {
@@ -934,12 +937,21 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
return data_type_;
}
+ int32_t GetNumSubElems() const {
+ return num_sub_elems_;
+ }
+
+ bool HasSubElems() const {
+ return num_sub_elems_ > 1;
+ }
+
protected:
- PrimitiveDataTypeBase(size_t size, int32_t data_type)
- : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
+ PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems)
+ : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type}, num_sub_elems_{num_sub_elems} {}
private:
const int32_t data_type_;
+ const int32_t num_sub_elems_; // > 1 for subbyte primitives, 1 for normal primitives.
};
/**
@@ -965,9 +977,9 @@ class PrimitiveDataType : public PrimitiveDataTypeBase {
}
private:
- PrimitiveDataType()
+ explicit PrimitiveDataType(int32_t num_sub_elems)
: PrimitiveDataTypeBase{sizeof(T),
-                              utils::ToTensorProtoElementType<T>()} {
+                              utils::ToTensorProtoElementType<T>(), num_sub_elems} {
}
};
@@ -1074,15 +1086,30 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
    return SequenceTensorType<ELEM_TYPE>::Type(); \
}
-#define ORT_REGISTER_PRIM_TYPE(TYPE)               \
-  template <>                                      \
-  MLDataType PrimitiveDataType<TYPE>::Type() {     \
-    static PrimitiveDataType<TYPE> prim_data_type; \
-    return &prim_data_type;                        \
-  }                                                \
-  template <>                                      \
-  MLDataType DataTypeImpl::GetType<TYPE>() {       \
-    return PrimitiveDataType<TYPE>::Type();        \
+#define ORT_REGISTER_PRIM_TYPE(TYPE)                  \
+  template <>                                         \
+  MLDataType PrimitiveDataType<TYPE>::Type() {        \
+    static PrimitiveDataType<TYPE> prim_data_type(1); \
+    return &prim_data_type;                           \
+  }                                                   \
+  template <>                                         \
+  MLDataType DataTypeImpl::GetType<TYPE>() {          \
+    return PrimitiveDataType<TYPE>::Type();           \
+  }
+
+// Registers a subbyte primitive.
+// Examples:
+//   - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
+//   - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
+#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS)       \
+  template <>                                                     \
+  MLDataType PrimitiveDataType<TYPE>::Type() {                    \
+    static PrimitiveDataType<TYPE> prim_data_type(NUM_SUB_ELEMS); \
+    return &prim_data_type;                                       \
+  }                                                               \
+  template <>                                                     \
+  MLDataType DataTypeImpl::GetType<TYPE>() {                      \
+    return PrimitiveDataType<TYPE>::Type();                       \
+  }
#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \
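For orientation, a minimal sketch (illustrative only, not part of the change) of how the new sub-element accessors can be queried once Int4x2 is registered through ORT_REGISTER_PRIM_SUBBYTE_TYPE:

    #include "core/framework/data_types.h"

    // Query the sub-element bookkeeping added above (hedged sketch).
    void InspectInt4Type() {
      using namespace onnxruntime;
      const auto* prim = DataTypeImpl::GetType<Int4x2>()->AsPrimitiveDataType();
      // For Int4x2: Size() == 1 byte per storage unit, GetNumSubElems() == 2,
      // HasSubElems() == true. For a plain float tensor type, GetNumSubElems() is 1.
      bool packed = prim->HasSubElems();
      int32_t per_byte = prim->GetNumSubElems();
      (void)packed;
      (void)per_byte;
    }
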
diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h
index 3a3b5cb6888f2..05f4c10995ef2 100644
--- a/include/onnxruntime/core/framework/data_types_internal.h
+++ b/include/onnxruntime/core/framework/data_types_internal.h
@@ -93,6 +93,12 @@ namespace utils {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
      function<Float8E5M2FNUZ>(__VA_ARGS__);                  \
      break;                                                  \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:          \
+      function<Int4x2>(__VA_ARGS__);                         \
+      break;                                                 \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:         \
+      function<UInt4x2>(__VA_ARGS__);                        \
+      break;                                                 \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -153,6 +159,12 @@ namespace utils {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
      retval = function<Float8E5M2FNUZ>(__VA_ARGS__);         \
      break;                                                  \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:          \
+      retval = function<Int4x2>(__VA_ARGS__);                \
+      break;                                                 \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:         \
+      retval = function<UInt4x2>(__VA_ARGS__);               \
+      break;                                                 \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -203,6 +215,12 @@ namespace utils {
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
      function<BFloat16>(__VA_ARGS__);                  \
      break;                                            \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:    \
+      function<Int4x2>(__VA_ARGS__);                   \
+      break;                                           \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:   \
+      function<UInt4x2>(__VA_ARGS__);                  \
+      break;                                           \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -251,6 +269,12 @@ namespace utils {
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
      retval = function<BFloat16>(__VA_ARGS__);         \
      break;                                            \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:    \
+      retval = function<Int4x2>(__VA_ARGS__);          \
+      break;                                           \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:   \
+      retval = function<UInt4x2>(__VA_ARGS__);         \
+      break;                                           \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h
new file mode 100644
index 0000000000000..228c1e4e872de
--- /dev/null
+++ b/include/onnxruntime/core/framework/int4.h
@@ -0,0 +1,125 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+#include "core/common/common.h"
+#include "core/common/gsl.h"
+
+namespace onnxruntime {
+
+template <bool Signed>
+struct Int4Traits;
+
+template <>
+struct Int4Traits<true> {
+  using UnpackedType = int8_t;
+  static constexpr int8_t min_val = -8;
+  static constexpr int8_t max_val = 7;
+};
+
+template <>
+struct Int4Traits<false> {
+  using UnpackedType = uint8_t;
+  static constexpr uint8_t min_val = 0;
+  static constexpr uint8_t max_val = 15;
+};
+
+/// <summary>
+/// Stores 2 packed 4-bit elements in 1 byte.
+/// </summary>
+/// <typeparam name="Signed">Set to true if signed int4, or false if unsigned uint4.</typeparam>
+template <bool Signed>
+struct Int4x2Base {
+  using UnpackedType = typename Int4Traits<Signed>::UnpackedType;
+  static constexpr UnpackedType min_val = Int4Traits<Signed>::min_val;
+  static constexpr UnpackedType max_val = Int4Traits<Signed>::max_val;
+
+ std::byte bits_{};
+
+ Int4x2Base() = default;
+
+ explicit Int4x2Base(std::byte bits) {
+ bits_ = bits;
+ }
+
+ Int4x2Base(UnpackedType val0, UnpackedType val1) {
+    bits_ = static_cast<std::byte>(((val1 & 0xF) << 4) | (val0 & 0xF));
+ }
+
+ static inline int8_t SignExtendLower4Bits(std::byte bits) {
+ // Sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift.
+ constexpr uint8_t shift = (sizeof(int32_t) * 8) - 4;
+    return static_cast<int8_t>((static_cast<int32_t>(bits) << shift) >> shift);
+ }
+
+ inline UnpackedType GetElem(size_t index) const {
+ assert(index <= 1);
+    const uint8_t shift = 4 * static_cast<uint8_t>(index);
+ const std::byte val = (bits_ >> shift) & std::byte{0xF};
+
+ if constexpr (Signed) {
+ return SignExtendLower4Bits(val);
+ } else {
+      return static_cast<UnpackedType>(val);
+ }
+ }
+
+ inline void SetElem(size_t index, UnpackedType val) {
+ assert(index <= 1);
+    const uint8_t shift = 4 * static_cast<uint8_t>(index);
+ const std::byte mask = std::byte{0xF0} >> shift;
+
+ bits_ &= mask; // Clear 4-bit element to 0
+    bits_ |= static_cast<std::byte>((val & 0xF) << shift);  // Set 4-bit element to val
+ }
+
+ inline std::byte ToBits() const {
+ return bits_;
+ }
+
+ static size_t CalcNumInt4Pairs(size_t num_int4_elems) {
+ return (num_int4_elems + 1) / 2;
+ }
+
+  static bool Unpack(gsl::span<UnpackedType> dst, gsl::span<const Int4x2Base<Signed>> src) {
+ if (CalcNumInt4Pairs(dst.size()) != src.size()) {
+ return false;
+ }
+
+ for (size_t i = 0; i < dst.size(); i++) {
+ size_t r = i >> 1; // i / 2;
+ size_t c = i & 0x1; // i % 2;
+ dst[i] = src[r].GetElem(c);
+ }
+
+ return true;
+ }
+
+  static bool Pack(gsl::span<Int4x2Base<Signed>> dst, gsl::span<const UnpackedType> src) {
+ if (src.empty() || (CalcNumInt4Pairs(src.size()) != dst.size())) {
+ return false;
+ }
+
+ size_t src_i = 0;
+ size_t dst_i = 0;
+
+ for (; src_i < src.size() - 1; src_i += 2) {
+ dst[dst_i++] = Int4x2Base(src[src_i], src[src_i + 1]);
+ }
+
+ if (src_i < src.size()) {
+ dst[dst_i] = Int4x2Base(src[src_i], 0);
+ }
+
+ return true;
+ }
+};
+
+using Int4x2 = Int4x2Base<true>;
+using UInt4x2 = Int4x2Base<false>;
+static_assert(sizeof(Int4x2) == sizeof(std::byte));
+static_assert(sizeof(UInt4x2) == sizeof(std::byte));
+} // namespace onnxruntime
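A small usage sketch for the packed type defined above (illustrative only; gsl comes in through the header's own includes):

    #include <cstdint>
    #include <vector>
    #include "core/framework/int4.h"

    void Int4PackingExample() {
      using onnxruntime::Int4x2;

      // Pack 3 signed 4-bit values; the odd count leaves the last high nibble as 0.
      std::vector<int8_t> unpacked = {-8, 7, 3};
      std::vector<Int4x2> packed(Int4x2::CalcNumInt4Pairs(unpacked.size()));  // 2 bytes
      Int4x2::Pack(gsl::make_span(packed), gsl::make_span(unpacked));

      // Read individual elements back; GetElem(1) of the first pair returns 7.
      int8_t second = packed[0].GetElem(1);
      (void)second;
    }
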
diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h
index a867ab6066485..96725aa103064 100644
--- a/include/onnxruntime/core/framework/tensor.h
+++ b/include/onnxruntime/core/framework/tensor.h
@@ -145,6 +145,17 @@ class Tensor final {
/// Bytes required.
static size_t CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape);
+  /// <summary>
+  /// Calculate the required storage for the tensor.
+  /// </summary>
+  /// <param name="elt_type">Data type of the tensor elements.</param>
+  /// <param name="shape">Tensor shape.</param>
+  /// <param name="alignment">Power of 2 alignment to include in calculation.
+  /// Bumps up result to the nearest multiple of alignment. Set to 0 to ignore.</param>
+  /// <param name="storage_size">The resulting storage size.</param>
+  /// <returns>Status indicating success or failure.</returns>
+  static Status CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment,
+                                           size_t& storage_size);
/**
Returns the data type.
*/
@@ -200,7 +211,7 @@ class Tensor final {
    ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ",
                "T ", "!=", dtype_);
    T* data = reinterpret_cast<T*>(static_cast<char*>(p_data_) + byte_offset_);
-    return gsl::make_span(data, static_cast<size_t>(shape_.Size()));
+    return gsl::make_span(data, static_cast<size_t>(NumStorageElements()));
}
  template <typename T>
@@ -217,7 +228,7 @@ class Tensor final {
    ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ",
                "T ", "!=", dtype_);
    const T* data = reinterpret_cast<const T*>(static_cast<const char*>(p_data_) + byte_offset_);
-    return gsl::make_span(data, static_cast<typename gsl::span<const T>::size_type>(shape_.Size()));
+    return gsl::make_span(data, static_cast<typename gsl::span<const T>::size_type>(NumStorageElements()));
}
void* MutableDataRaw(MLDataType type) {
@@ -271,6 +282,19 @@ class Tensor final {
byte_offset_ = byte_offset;
}
+  /// <summary>
+  /// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for
+  /// sub-byte data types (e.g., int4).
+  ///
+  /// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements.
+  /// Example: Tensor of shape (4,) has 2 storage elements.
+  ///
+  /// For element types >= 1 byte, this function returns the product of the shape.
+  /// Example: Tensor of shape (4,) has 4 storage elements.
+  /// </summary>
+  /// <returns>Number of tensor storage elements</returns>
+  int64_t NumStorageElements() const;
+
/**
The number of bytes of data.
*/
diff --git a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
index 21253eb4a6e83..e9e28e4864a67 100644
--- a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
+++ b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
@@ -12,6 +12,7 @@
#include "core/framework/float8.h"
#include "core/framework/float16.h"
+#include "core/framework/int4.h"
namespace onnxruntime {
namespace utils {
@@ -97,6 +98,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType
+template <>
+constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Int4x2>() {
+ return ONNX_NAMESPACE::TensorProto_DataType_INT4;
+}
+template <>
+constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<UInt4x2>() {
+ return ONNX_NAMESPACE::TensorProto_DataType_UINT4;
+}
} // namespace utils
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 67cc493f04ab9..16701f2e0d923 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -196,7 +196,10 @@ typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision
- ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ // Non-IEEE floating-point format based on IEEE754 single-precision
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision
+ // Int4 types were introduced in ONNX 1.16. See https://onnx.ai/onnx/technical/int4.html
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte)
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte)
} ONNXTensorElementDataType;
// Synced with onnx TypeProto oneof
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index b37ce2f72e721..8aa885cf1ebd6 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -64,10 +64,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid);
@@ -200,15 +204,32 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc
index 6c4aec417a033..72ab5a9e898c7 100644
--- a/onnxruntime/core/framework/data_types.cc
+++ b/onnxruntime/core/framework/data_types.cc
@@ -639,6 +639,8 @@ ORT_REGISTER_TENSOR_TYPE(Float8E4M3FNUZ);
ORT_REGISTER_TENSOR_TYPE(Float8E5M2);
ORT_REGISTER_TENSOR_TYPE(Float8E5M2FNUZ);
#endif
+ORT_REGISTER_TENSOR_TYPE(Int4x2);
+ORT_REGISTER_TENSOR_TYPE(UInt4x2);
#if !defined(DISABLE_SPARSE_TENSORS)
ORT_REGISTER_SPARSE_TENSOR_TYPE(int32_t);
@@ -700,6 +702,9 @@ ORT_REGISTER_SEQ_TENSOR_TYPE(Float8E5M2FNUZ);
#endif
+ORT_REGISTER_SEQ_TENSOR_TYPE(Int4x2);
+ORT_REGISTER_SEQ_TENSOR_TYPE(UInt4x2);
+
#if !defined(DISABLE_ML_OPS)
ORT_REGISTER_SEQ(VectorMapStringToFloat);
ORT_REGISTER_SEQ(VectorMapInt64ToFloat);
@@ -725,7 +730,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat);
ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E4M3FN); \
ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E4M3FNUZ); \
ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2); \
- ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ);
+ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ); \
+ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \
+ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2);
#else
@@ -743,7 +750,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat);
ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, uint32_t); \
ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, uint64_t); \
ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, MLFloat16); \
- ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16);
+ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16); \
+ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \
+ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2);
#endif
@@ -808,6 +817,8 @@ void RegisterAllProtos(const std::function& reg_fn) {
REGISTER_TENSOR_PROTO(Float8E5M2, reg_fn);
REGISTER_TENSOR_PROTO(Float8E5M2FNUZ, reg_fn);
#endif
+ REGISTER_TENSOR_PROTO(Int4x2, reg_fn);
+ REGISTER_TENSOR_PROTO(UInt4x2, reg_fn);
#if !defined(DISABLE_SPARSE_TENSORS)
REGISTER_SPARSE_TENSOR_PROTO(int32_t, reg_fn);
@@ -867,6 +878,9 @@ void RegisterAllProtos(const std::function& reg_fn) {
#endif
+ REGISTER_SEQ_TENSOR_PROTO(Int4x2, reg_fn);
+ REGISTER_SEQ_TENSOR_PROTO(UInt4x2, reg_fn);
+
#if !defined(DISABLE_ML_OPS)
REGISTER_ONNX_PROTO(VectorMapStringToFloat, reg_fn);
REGISTER_ONNX_PROTO(VectorMapInt64ToFloat, reg_fn);
@@ -894,7 +908,9 @@ void RegisterAllProtos(const std::function& reg_fn) {
REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E4M3FN, reg_fn); \
REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E4M3FNUZ, reg_fn); \
REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2, reg_fn); \
- REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn);
+ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn); \
+ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \
+ REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn);
#else
@@ -912,7 +928,9 @@ void RegisterAllProtos(const std::function& reg_fn) {
REGISTER_OPTIONAL_PROTO(ORT_TYPE, uint32_t, reg_fn); \
REGISTER_OPTIONAL_PROTO(ORT_TYPE, uint64_t, reg_fn); \
REGISTER_OPTIONAL_PROTO(ORT_TYPE, MLFloat16, reg_fn); \
- REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn);
+ REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn); \
+ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \
+ REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn);
#endif
@@ -973,6 +991,10 @@ const char* DataTypeImpl::ToString(MLDataType type) {
return "Float8E5M2";
case TensorProto_DataType_FLOAT8E5M2FNUZ:
return "Float8E5M2FNUZ";
+ case TensorProto_DataType_INT4:
+ return "Int4x2";
+ case TensorProto_DataType_UINT4:
+ return "UInt4x2";
default:
break;
}
@@ -1041,6 +1063,10 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) {
return DataTypeImpl::GetTensorType()->AsTensorType();
#endif
+    case TensorProto_DataType_INT4:
+      return DataTypeImpl::GetTensorType<Int4x2>()->AsTensorType();
+    case TensorProto_DataType_UINT4:
+      return DataTypeImpl::GetTensorType<UInt4x2>()->AsTensorType();
default:
ORT_NOT_IMPLEMENTED("tensor type ", type, " is not supported");
@@ -1090,6 +1116,10 @@ const SequenceTensorTypeBase* DataTypeImpl::SequenceTensorTypeFromONNXEnum(int t
return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType();
#endif
+    case TensorProto_DataType_INT4:
+      return DataTypeImpl::GetSequenceTensorType<Int4x2>()->AsSequenceTensorType();
+    case TensorProto_DataType_UINT4:
+      return DataTypeImpl::GetSequenceTensorType<UInt4x2>()->AsSequenceTensorType();
default:
ORT_NOT_IMPLEMENTED("sequence tensor type ", type, " is not supported");
@@ -1183,6 +1213,8 @@ ORT_REGISTER_PRIM_TYPE(Float8E5M2);
ORT_REGISTER_PRIM_TYPE(Float8E5M2FNUZ);
#endif
+ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2);
+ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2);
namespace {
template
diff --git a/onnxruntime/core/framework/element_type_lists.h b/onnxruntime/core/framework/element_type_lists.h
index 3d956ec26d22e..2478dc27162ac 100644
--- a/onnxruntime/core/framework/element_type_lists.h
+++ b/onnxruntime/core/framework/element_type_lists.h
@@ -11,6 +11,7 @@
#include "core/common/type_list.h"
#include "core/framework/float8.h"
#include "core/framework/float16.h"
+#include "core/framework/int4.h"
namespace onnxruntime {
@@ -82,6 +83,12 @@ using AllIRv9 =
using AllIRv9 = AllIRv4;
#endif
+using AllIRv10 =
+ boost::mp11::mp_push_back<
+ AllIRv9,
+ UInt4x2,
+ Int4x2>;
+
using All = AllIRv4;
#if !defined(DISABLE_FLOAT8_TYPES)
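The new AllIRv10 list simply appends the packed int4 types to AllIRv9. A rough compile-time membership check (a sketch; the element_type_lists namespace and TypeList alias are assumptions based on how this header is normally consumed) could look like:

    #include <boost/mp11.hpp>
    #include "core/framework/element_type_lists.h"

    // Illustrative only: AllIRv10 should report the packed int4 types as members.
    static_assert(boost::mp11::mp_contains<onnxruntime::element_type_lists::AllIRv10,
                                           onnxruntime::Int4x2>::value);
    static_assert(boost::mp11::mp_contains<onnxruntime::element_type_lists::AllIRv10,
                                           onnxruntime::UInt4x2>::value);
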
diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc
index 32a5f749af084..894e0daae94b6 100644
--- a/onnxruntime/core/framework/execution_frame.cc
+++ b/onnxruntime/core/framework/execution_frame.cc
@@ -529,17 +529,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs");
}
- size_t size;
- int64_t len = shape.Size();
- if (len < 0) {
- return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape cannot contain any negative value");
- }
-  if (static_cast<uint64_t>(len) > std::numeric_limits<size_t>::max()) {
- return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large");
- }
-  if (!IAllocator::CalcMemSizeForArrayWithAlignment<kAllocAlignment>(static_cast<size_t>(len), element_type->Size(), &size)) {
- return Status(ONNXRUNTIME, FAIL, "size overflow");
- }
+ size_t size = 0;
+ ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(element_type, shape, kAllocAlignment, size));
// Lazily get the allocator only if needed.
AllocatorPtr alloc = nullptr;
diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
index 3963504273599..1370580bad4f6 100644
--- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc
+++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
@@ -78,6 +78,12 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) {
case TensorType::TensorProto_DataType_FLOAT8E5M2FNUZ: {
return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
} // Non-IEEE floating-point format based on IEEE754 single-precision
+ case TensorType::TensorProto_DataType_INT4: {
+ return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
+ } // maps to a pair of int4 (size == 1 byte)
+ case TensorType::TensorProto_DataType_UINT4: {
+ return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
+ } // maps to a pair of uint4 (size == 1 byte)
default: {
return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}
@@ -126,4 +132,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetMapValueType,
ORT_API(void, OrtApis::ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo* ptr) {
std::unique_ptr p(ptr);
-}
\ No newline at end of file
+}
diff --git a/onnxruntime/core/framework/ort_value_tensor_slicer.cc b/onnxruntime/core/framework/ort_value_tensor_slicer.cc
index cf4020d830bff..2acc1e301d6e0 100644
--- a/onnxruntime/core/framework/ort_value_tensor_slicer.cc
+++ b/onnxruntime/core/framework/ort_value_tensor_slicer.cc
@@ -14,7 +14,13 @@ OrtValueTensorSlicer OrtValueTensorSlicer::Create(T& ort_value, int64_t sl
ORT_ENFORCE(ort_value.IsTensor(), "Can't slice a non-tensor OrtValue. Type was ", ort_value.Type());
ORT_ENFORCE(ort_value.IsAllocated(), "OrtValue has not been allocated so can't be sliced.");
-  auto& tensor_shape = ort_value.template Get<Tensor>().Shape();
+  const Tensor& tensor = ort_value.template Get<Tensor>();
+ auto* prim_type = tensor.DataType()->AsPrimitiveDataType();
+ if (prim_type != nullptr) {
+ // TODO(adrianlizarraga): Support slicing Tensors of subbyte element types (e.g., int4).
+ ORT_ENFORCE(!prim_type->HasSubElems(), "Can't slice a tensor with a subbyte element type");
+ }
+ auto& tensor_shape = tensor.Shape();
   ORT_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) >= slice_dimension,
"Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc
index 692ca08772535..059de8e3c8c4a 100644
--- a/onnxruntime/core/framework/session_state_utils.cc
+++ b/onnxruntime/core/framework/session_state_utils.cc
@@ -37,20 +37,10 @@ namespace session_state_utils {
// It can handle arena-based allocators and non-arena based allocators.
static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, const DataTypeImpl* type,
const AllocatorPtr& alloc, /*out*/ void*& p_data) {
- int64_t shape_size = tensor_shape.Size();
- if (shape_size < 0)
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "shape.Size() must >=0");
+ size_t mem_size = 0;
+ ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(type, tensor_shape, /*alignment*/ 0, mem_size));
- p_data = nullptr;
- if (shape_size > 0) {
-    SafeInt<size_t> mem_size = 0;
-
-    if (!IAllocator::CalcMemSizeForArray(SafeInt<size_t>(shape_size), type->Size(), &mem_size)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed memory size calculation");
- }
-
- p_data = alloc->Reserve(mem_size);
- }
+ p_data = alloc->Reserve(mem_size);
return Status::OK();
}
diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc
index 36f03a9b1046a..60d768cc59a5d 100644
--- a/onnxruntime/core/framework/tensor.cc
+++ b/onnxruntime/core/framework/tensor.cc
@@ -27,20 +27,54 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st
} // namespace
#endif
-size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) {
- int64_t shape_size = shape.Size();
- if (shape_size < 0)
- ORT_THROW("shape.Size() must >=0");
+/// <summary>
+/// Get the number of elements for a Tensor of the given element type and shape size.
+///
+/// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements.
+/// Example: Tensor of shape_size 4 has 2 storage elements.
+///
+/// For element types >= 1 byte, this function returns the product of the shape.
+/// Example: Tensor of shape_size 4 has 4 storage elements.
+/// </summary>
+/// <param name="elt_type">Data type of the tensor elements.</param>
+/// <param name="shape_size">The number of elements indicated by the shape (i.e., shape.Size()).</param>
+/// <returns>Number of Tensor elements. Returns -1 if shape_size is negative.</returns>
+static int64_t GetNumTensorStorageElems(MLDataType elt_type, int64_t shape_size) {
+ int64_t num_elems = shape_size;
+ auto prim_type = elt_type->AsPrimitiveDataType();
+
+ if (prim_type != nullptr && num_elems > 0 && prim_type->HasSubElems()) {
+ const int64_t num_sub_elems = prim_type->GetNumSubElems();
+ num_elems = (num_elems + (num_sub_elems - 1)) / num_sub_elems;
+ }
- if (shape_size > 0) {
-    SafeInt<size_t> len = 0;
-    if (!IAllocator::CalcMemSizeForArray(SafeInt<size_t>(shape_size), elt_type->Size(), &len))
- ORT_THROW("tensor failed memory size calculation");
+ return num_elems;
+}
+
+Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment,
+ /*out*/ size_t& storage_size) {
+ int64_t num_elems = GetNumTensorStorageElems(elt_type, shape.Size());
+ ORT_RETURN_IF(num_elems < 0, "Tensor shape.Size() must be >= 0");
- return len;
+ if (num_elems > 0) {
+    if (static_cast<uint64_t>(num_elems) > std::numeric_limits<size_t>::max()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large");
+ }
+ if (!IAllocator::CalcMemSizeForArrayWithAlignment(static_cast(num_elems), elt_type->Size(), alignment,
+ &storage_size)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Calculation for Tensor storage size overflowed");
+ }
+ } else {
+ storage_size = 0;
}
- return 0;
+ return Status::OK();
+}
+
+size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) {
+ size_t storage_size = 0;
+ ORT_THROW_IF_ERROR(CalculateTensorStorageSize(elt_type, shape, 0, storage_size));
+ return storage_size;
}
Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location,
@@ -98,14 +132,19 @@ void Tensor::InitOrtValue(Tensor&& tensor, OrtValue& ort_value) {
ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc());
}
-size_t Tensor::SizeInBytes() const {
+int64_t Tensor::NumStorageElements() const {
#ifdef ENABLE_STRIDED_TENSORS
int64_t size = IsContiguous() ? shape_.Size() : GetSizeFromStrides(shape_, strides_);
#else
int64_t size = shape_.Size();
#endif
- size_t ret;
-  if (!IAllocator::CalcMemSizeForArray(SafeInt<size_t>(size), dtype_->Size(), &ret)) {
+
+ return GetNumTensorStorageElems(dtype_, size);
+}
+
+size_t Tensor::SizeInBytes() const {
+ size_t ret = 0;
+  if (!IAllocator::CalcMemSizeForArray(SafeInt<size_t>(NumStorageElements()), dtype_->Size(), &ret)) {
ORT_THROW("tensor size overflow");
}
return ret;
@@ -138,6 +177,8 @@ void Tensor::Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_dat
ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match tensor dimension size.");
strides_.assign(strides.begin(), strides.end());
is_contiguous_ = CheckIsContiguous();
+ ORT_ENFORCE(is_contiguous_ || !dtype_->HasSubElems(),
+ "Do not support subbyte element types with non-contiguous strided tensors.");
}
#else
ORT_UNUSED_PARAMETER(strides);
@@ -254,6 +295,8 @@ void Tensor::SetShapeAndStrides(const TensorShape& new_shape, gsl::span<const int64_t> strides) {
+    ORT_ENFORCE(is_contiguous_ || !dtype_->HasSubElems(),
+ "Do not support subbyte element types with non-contiguous strided tensors.");
}
#endif
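The ceil-division above is the whole trick; a hedged, self-contained mirror of the arithmetic:

    #include <cstdint>

    // Mirror of the ceil-division in GetNumTensorStorageElems (illustrative only).
    constexpr int64_t NumStorageElems(int64_t shape_size, int64_t num_sub_elems) {
      return (shape_size + (num_sub_elems - 1)) / num_sub_elems;
    }

    static_assert(NumStorageElems(7, 2) == 4);  // 7 int4 values pack into 4 one-byte Int4x2 elements
    static_assert(NumStorageElems(7, 1) == 7);  // types of one byte or more are unchanged
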
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc
index 6e11bfe1ac8ea..418e46924fb9f 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.cc
+++ b/onnxruntime/core/framework/tensor_type_and_shape.cc
@@ -164,6 +164,12 @@ constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementData
case o::TensorProto_DataType_BOOL:
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
break;
+ case o::TensorProto_DataType_INT4:
+ type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
+ break;
+ case o::TensorProto_DataType_UINT4:
+ type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
+ break;
default:
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
break;
@@ -365,4 +371,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo,
*out = ptr.release();
return nullptr;
API_IMPL_END
-}
\ No newline at end of file
+}
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index 8a2db6d5728af..6af78f18fb82f 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -59,6 +59,20 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) {
return t; \
}
+#define TO_TENSOR_ORT_TYPE_INT4(TYPE)                                                                            \
+  template <>                                                                                                    \
+  TensorProto ToTensor<onnxruntime::TYPE>(const onnxruntime::TYPE& value) {                                      \
+    return ToScalarTensor(ToTensorProtoElementType<onnxruntime::TYPE>(), static_cast<int32_t>(value.ToBits()));  \
+  }                                                                                                              \
+  template <>                                                                                                    \
+  TensorProto ToTensor<onnxruntime::TYPE>(const std::vector<onnxruntime::TYPE>& values) {                        \
+    TensorProto t = ToTensorInitialize(ToTensorProtoElementType<onnxruntime::TYPE>());                           \
+    for (const onnxruntime::TYPE& val : values) {                                                                \
+      t.add_int32_data(static_cast<int32_t>(val.ToBits()));                                                      \
+    }                                                                                                            \
+    return t;                                                                                                    \
+  }
+
namespace ONNX_NAMESPACE {
// Provide template specializations for onnxruntime-specific types.
@@ -70,6 +84,8 @@ TO_TENSOR_ORT_TYPE(Float8E4M3FNUZ)
TO_TENSOR_ORT_TYPE(Float8E5M2)
TO_TENSOR_ORT_TYPE(Float8E5M2FNUZ)
#endif
+TO_TENSOR_ORT_TYPE_INT4(Int4x2)
+TO_TENSOR_ORT_TYPE_INT4(UInt4x2)
bool operator==(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) {
@@ -125,6 +141,29 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t
reinterpret_cast(p_data));
}
+#define DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(INT4_TYPE)                                                   \
+  template <>                                                                                                     \
+  Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements,         \
+                                 /*out*/ INT4_TYPE* p_data) {                                                     \
+    static_assert(std::is_trivially_copyable<INT4_TYPE>::value, "T must be trivially copyable");                  \
+                                                                                                                  \
+    ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data");                                                        \
+                                                                                                                  \
+    size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements);                                 \
+    ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs");                \
+                                                                                                                  \
+    gsl::span<const INT4_TYPE> src_span = gsl::make_span(reinterpret_cast<const INT4_TYPE*>(raw_data),            \
+                                                         num_packed_pairs);                                       \
+    gsl::span<INT4_TYPE> dst_span = gsl::make_span(p_data, num_packed_pairs);                                     \
+                                                                                                                  \
+    std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs);                                              \
+                                                                                                                  \
+    return Status::OK();                                                                                          \
+  }
+
+DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2)
+DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2)
+
static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
const ORTCHAR_T* tensor_proto_dir,
std::basic_string& external_file_path,
@@ -261,6 +300,32 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor,
reinterpret_cast(p_data));
}
+#define DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(INT4_TYPE)                                                      \
+  template <>                                                                                                        \
+  Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor,                                     \
+                                      const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements,               \
+                                      /*out*/ INT4_TYPE* p_data) {                                                   \
+    static_assert(std::is_trivially_copyable<INT4_TYPE>::value, "T must be trivially copyable");                     \
+                                                                                                                     \
+    ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data");                                                           \
+    std::vector<uint8_t> unpacked_tensor;                                                                            \
+    ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor));                       \
+                                                                                                                     \
+    size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements);                                    \
+    ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs");         \
+                                                                                                                     \
+    gsl::span<const INT4_TYPE> src_span = gsl::make_span(reinterpret_cast<const INT4_TYPE*>(unpacked_tensor.data()), \
+                                                         num_packed_pairs);                                          \
+    gsl::span<INT4_TYPE> dst_span = gsl::make_span(p_data, expected_num_elements);                                   \
+                                                                                                                     \
+    std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs);                                                 \
+                                                                                                                     \
+    return Status::OK();                                                                                             \
+  }
+
+DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2)
+DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2)
+
#define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \
template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const ORTCHAR_T*, size_t, type*);
@@ -602,6 +667,40 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d
#endif
+#define DEFINE_INT4_UNPACK_TENSOR_IMPL(INT4_TYPE, ONNX_INT4_TYPE) \
+ template <> \
+ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \
+ /*out*/ INT4_TYPE* p_data, size_t expected_num_elems) { \
+ if (nullptr == p_data) { \
+ const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); \
+ return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
+ } \
+ if (ONNX_NAMESPACE::ONNX_INT4_TYPE != tensor.data_type()) { \
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
+ } \
+ \
+ size_t expected_int4_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elems); \
+ \
+ if (raw_data != nullptr) { \
+ return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \
+ } \
+ \
+    ORT_RETURN_IF_NOT(static_cast<size_t>(tensor.int32_data_size()) == expected_int4_pairs,                     \
+ "UnpackTensor: the pre-allocated size does not match the size in proto"); \
+ \
+    for (int i = 0; i < static_cast<int>(tensor.int32_data_size()); i++) {                                      \
+      p_data[i] = INT4_TYPE(static_cast<std::byte>(tensor.int32_data()[i]));                                    \
+ } \
+ \
+ return Status::OK(); \
+ }
+
+// UnpackTensor<Int4x2>
+DEFINE_INT4_UNPACK_TENSOR_IMPL(Int4x2, TensorProto_DataType_INT4)
+
+// UnpackTensor<UInt4x2>
+DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4)
+
// UnpackTensor from raw data, external data or the type specific data field.
// Uses the model path to construct the full path for loading external data. In case when model_path is empty
// it uses current directory.
@@ -651,6 +750,8 @@ INSTANTIATE_UNPACK_TENSOR(Float8E4M3FNUZ)
INSTANTIATE_UNPACK_TENSOR(Float8E5M2)
INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ)
#endif
+INSTANTIATE_UNPACK_TENSOR(Int4x2)
+INSTANTIATE_UNPACK_TENSOR(UInt4x2)
#define CASE_PROTO_TRACE(X, Y) \
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
@@ -659,6 +760,13 @@ INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ)
} \
break;
+#define CASE_PROTO_TRACE_INT4(X, Y) \
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
+    if (!IAllocator::CalcMemSizeForArrayWithAlignment<alignment>(Y::CalcNumInt4Pairs(size), sizeof(Y), out)) {  \
+ return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \
+ } \
+ break;
+
template <size_t alignment>
common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) {
const auto& dims = tensor_proto.dims();
@@ -692,6 +800,8 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto&
CASE_PROTO_TRACE(FLOAT8E5M2, Float8E5M2);
CASE_PROTO_TRACE(FLOAT8E5M2FNUZ, Float8E5M2FNUZ);
#endif
+ CASE_PROTO_TRACE_INT4(UINT4, UInt4x2);
+ CASE_PROTO_TRACE_INT4(INT4, Int4x2);
default:
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
}
@@ -998,6 +1108,8 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path,
CASE_PROTO(FLOAT8E5M2, Float8E5M2);
CASE_PROTO(FLOAT8E5M2FNUZ, Float8E5M2FNUZ);
#endif
+ CASE_PROTO(INT4, Int4x2);
+ CASE_PROTO(UINT4, UInt4x2);
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_STRING:
ORT_RETURN_IF_ERROR(UnpackTensor(tensor_proto, raw_data, raw_data_len,
                                         static_cast<std::string*>(preallocated),
@@ -1053,6 +1165,8 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) {
CASE_TYPE(FLOAT8E5M2)
CASE_TYPE(FLOAT8E5M2FNUZ)
#endif
+ CASE_TYPE(UINT4)
+ CASE_TYPE(INT4)
default:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}
@@ -1570,6 +1684,20 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T
break; \
}
+#define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \
+ TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \
+    size_t element_count = static_cast<size_t>(tensor_shape.Size());                          \
+ size_t packed_element_count = ELEMENT_TYPE::CalcNumInt4Pairs(element_count); \
+ unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \
+ return onnxruntime::utils::UnpackTensor( \
+ initializer, \
+ initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \
+ initializer.has_raw_data() ? initializer.raw_data().size() : 0, \
+        reinterpret_cast<ELEMENT_TYPE*>(unpacked_tensor.data()), element_count);              \
+ break; \
+ }
+
Status UnpackInitializerData(const onnx::TensorProto& initializer,
const Path& model_path,
                             std::vector<uint8_t>& unpacked_tensor) {
@@ -1604,6 +1732,8 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer,
CASE_UNPACK(FLOAT8E5M2, onnxruntime::Float8E5M2, int32_data_size);
CASE_UNPACK(FLOAT8E5M2FNUZ, onnxruntime::Float8E5M2FNUZ, int32_data_size);
#endif
+ CASE_UNPACK_INT4(INT4, Int4x2, int32_data_size);
+ CASE_UNPACK_INT4(UINT4, UInt4x2, int32_data_size);
default:
break;
}
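To make the packed layout that these unpack helpers expect concrete, here is a hedged sketch (illustrative, not from the diff) of how an INT4 initializer with dims {3} would be laid out in raw_data:

    #include <vector>
    #include "core/framework/int4.h"

    // raw_data holds CalcNumInt4Pairs(3) == 2 bytes; logical element i lives in
    // byte i/2, nibble i%2, and the unused trailing high nibble stays 0.
    std::vector<uint8_t> MakePackedInt4RawData() {
      using onnxruntime::Int4x2;
      std::vector<Int4x2> pairs(Int4x2::CalcNumInt4Pairs(3));
      pairs[0] = Int4x2(-2, 5);  // elements 0 and 1
      pairs[1] = Int4x2(7, 0);   // element 2
      return {static_cast<uint8_t>(pairs[0].ToBits()), static_cast<uint8_t>(pairs[1].ToBits())};
    }
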
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index f0b1b9109d405..fd14eeeb33d27 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -222,6 +222,16 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType
#endif
+template <>
+constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Int4x2>() {
+ return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
+}
+
+template <>
+constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt4x2>() {
+ return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
+}
+
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
#ifdef ENABLE_TRAINING
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 47f61a43458ed..762d892c45ce8 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -164,7 +164,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"T2", OpSchema::Optional)
.Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2")
.TypeConstraint("T1", {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.")
- .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)"},
+ .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)", "tensor(int4)",
+ "tensor(uint4)"},
"Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.")
.SetDoc(QuantizeLinear_ver1_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
@@ -206,7 +207,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(DequantizeLinear, 1,
.Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.",
"T2")
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int16)",
- "tensor(uint16)", "tensor(int32)"},
+ "tensor(uint16)", "tensor(int32)", "tensor(int4)",
+ "tensor(uint4)"},
"Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, "
"16-bit integer tensors, or 32-bit signed integer tensors.")
.TypeConstraint("T2", {"tensor(float16)", "tensor(float)"},
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index ce7838556fbf0..cdfd283899c8c 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1222,6 +1222,26 @@ MlasQuantizeLinear(
OutputType ZeroPoint
);
+void
+MLASCALL
+MlasQuantizeLinearU4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ );
+
+void
+MLASCALL
+MlasQuantizeLinearS4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ );
+
/**
* @brief Requantize a block of the intermediate buffer to the output buffer,
* optionally adding the supplied bias
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 04da9ab4fd749..83200187963e1 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -680,6 +680,24 @@ void
float Scale,
int16_t ZeroPoint);
+typedef
+void
+(MLASCALL MLAS_QUANTIZE_LINEAR_U4_KERNEL)(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint);
+
+typedef
+void
+(MLASCALL MLAS_QUANTIZE_LINEAR_S4_KERNEL)(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint);
+
 template <typename OutputType>
struct MLAS_QUANT_KERNEL
{
@@ -826,6 +844,8 @@ extern "C" {
MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel;
MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel;
MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel;
+ MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel;
+ MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel;
#if defined(MLAS_TARGET_AMD64)
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3;
@@ -1083,6 +1103,8 @@ struct MLAS_PLATFORM {
MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel;
MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel;
MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel;
+ MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel;
+ MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel;
#endif
#if defined(MLAS_TARGET_AMD64)
MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine;
@@ -1112,6 +1134,8 @@ struct MLAS_PLATFORM {
MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel;
MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel;
MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel;
+ MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel;
+ MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel;
uint32_t NchwcBlockSize;
uint32_t PreferredBufferAlignment;
int32_t MaximumThreadCount;
@@ -2497,3 +2521,51 @@ MlasThreadedBufAlloc(size_t size)
ThreadedBufSize = size;
}
}
+
+//
+// Utilities for INT4 quantization.
+//
+
+template <bool Signed>
+struct Int4Traits;
+
+template<>
+struct Int4Traits<true> {
+    using UnpackedType = int8_t;
+    static constexpr int8_t Min = -8;
+    static constexpr int8_t Max = 7;
+};
+
+template<>
+struct Int4Traits<false> {
+    using UnpackedType = uint8_t;
+    static constexpr int8_t Min = 0;
+    static constexpr int8_t Max = 15;
+};
+
+template <typename UnpackedType>
+MLAS_FORCEINLINE
+void
+MlasSetInt4Element(uint8_t* Output, size_t ElemIndex, UnpackedType Value)
+{
+    static_assert(std::is_same_v<UnpackedType, int8_t> || std::is_same_v<UnpackedType, uint8_t>);
+
+    const size_t OutputIndex = ElemIndex >> 1;   // which byte
+    const size_t NibbleIndex = ElemIndex & 0x1;  // which 4-bit elem in the byte
+    const uint8_t Shift = static_cast<uint8_t>(NibbleIndex << 2);  // Either 0 or 4
+    const uint8_t Mask = static_cast<uint8_t>(0xF0 >> Shift);
+    uint8_t* Dst = &Output[OutputIndex];
+
+    *Dst &= Mask;                                           // Clear 4-bit lane
+    *Dst |= static_cast<uint8_t>((Value & 0xF) << Shift);   // Set 4-bit lane
+}
+
+template <typename UnpackedType>
+MLAS_FORCEINLINE
+void
+MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueHigh)
+{
+    static_assert(std::is_same_v<UnpackedType, int8_t> || std::is_same_v<UnpackedType, uint8_t>);
+    *Output = static_cast<uint8_t>(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF));
+}
+
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 3f86b3f7c5062..72eb35c894094 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -274,6 +274,8 @@ Return Value:
this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel;
this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel;
this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
+ this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel;
+ this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel;
this->NchwcBlockSize = 8;
this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT;
@@ -545,6 +547,8 @@ Return Value:
this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel;
this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel;
this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
+ this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel;
+ this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel;
#if defined(__linux__)
unsigned long hwcap2 = getauxval(AT_HWCAP2);
diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
index 1fed8af21b31c..0cfa56740edfb 100644
--- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
+++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp
@@ -107,6 +107,119 @@ Return Value:
}
}
+template<bool Signed>
+void
+MLASCALL
+MlasQuantizeLinearInt4Kernel(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+/*++
+
+Routine Description:
+
+ This routine quantizes the input buffer as packed 4-bit integers (signed or unsigned,
+ depending on the Signed template parameter) using the supplied quantization parameters.
+
+Arguments:
+
+ Input - Supplies the input buffer.
+
+ Output - Supplies the output buffer. Contains packed 4-bit elements.
+
+ N - Supplies the number of elements to process.
+
+ Scale - Supplies the quantization scale.
+
+ ZeroPoint - Supplies the quantization zero point value.
+
+Return Value:
+
+ None.
+
+--*/
+{
+ constexpr int32_t MinimumValue = Int4Traits<Signed>::Min;
+ constexpr int32_t MaximumValue = Int4Traits<Signed>::Max;
+ using UnpackedType = typename Int4Traits<Signed>::UnpackedType;
+
+ auto ScaleVector = vec_splats(Scale);
+ auto MinimumValueVector = vec_splats(float(MinimumValue));
+ auto MaximumValueVector = vec_splats(float(MaximumValue));
+ auto ZeroPointVector = vec_splats(float(ZeroPoint));
+
+ // Holds 16 quantized 8-bit values that will be packed into the output as packed 4-bit values.
+ UnpackedType TmpOutput[16] = {};
+
+ while (N >= 16) {
+ auto FloatVector0 = vec_xl(0, Input);
+ auto FloatVector1 = vec_xl(0, Input + 4);
+ auto FloatVector2 = vec_xl(0, Input + 8);
+ auto FloatVector3 = vec_xl(0, Input + 12);
+
+ FloatVector0 = vec_div(FloatVector0, ScaleVector);
+ FloatVector1 = vec_div(FloatVector1, ScaleVector);
+ FloatVector2 = vec_div(FloatVector2, ScaleVector);
+ FloatVector3 = vec_div(FloatVector3, ScaleVector);
+
+ FloatVector0 = vec_round(FloatVector0);
+ FloatVector1 = vec_round(FloatVector1);
+ FloatVector2 = vec_round(FloatVector2);
+ FloatVector3 = vec_round(FloatVector3);
+
+ FloatVector0 = vec_add(FloatVector0, ZeroPointVector);
+ FloatVector1 = vec_add(FloatVector1, ZeroPointVector);
+ FloatVector2 = vec_add(FloatVector2, ZeroPointVector);
+ FloatVector3 = vec_add(FloatVector3, ZeroPointVector);
+
+ FloatVector0 = vec_max(FloatVector0, MinimumValueVector);
+ FloatVector1 = vec_max(FloatVector1, MinimumValueVector);
+ FloatVector2 = vec_max(FloatVector2, MinimumValueVector);
+ FloatVector3 = vec_max(FloatVector3, MinimumValueVector);
+
+ FloatVector0 = vec_min(FloatVector0, MaximumValueVector);
+ FloatVector1 = vec_min(FloatVector1, MaximumValueVector);
+ FloatVector2 = vec_min(FloatVector2, MaximumValueVector);
+ FloatVector3 = vec_min(FloatVector3, MaximumValueVector);
+
+ auto IntegerVector0 = vec_signed(FloatVector0);
+ auto IntegerVector1 = vec_signed(FloatVector1);
+ auto IntegerVector2 = vec_signed(FloatVector2);
+ auto IntegerVector3 = vec_signed(FloatVector3);
+
+ auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1);
+ auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3);
+
+ auto CharVector = vec_pack(ShortVector0, ShortVector1);
+ vec_xst(CharVector, 0, static_cast<UnpackedType*>(&TmpOutput[0]));
+
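+ // Pack the 16 quantized values into 8 output bytes, two 4-bit elements per byte.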
+ MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]);
+ MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]);
+ MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]);
+ MlasPackInt4Elements(Output++, TmpOutput[6], TmpOutput[7]);
+ MlasPackInt4Elements(Output++, TmpOutput[8], TmpOutput[9]);
+ MlasPackInt4Elements(Output++, TmpOutput[10], TmpOutput[11]);
+ MlasPackInt4Elements(Output++, TmpOutput[12], TmpOutput[13]);
+ MlasPackInt4Elements(Output++, TmpOutput[14], TmpOutput[15]);
+
+ Input += 16;
+ N -= 16;
+ }
+
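+ // Quantize any remaining elements (fewer than 16) with scalar code, writing one nibble at a time.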
+ for (size_t n = 0; n < N; n++) {
+
+ float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast<float>(ZeroPoint);
+ FloatValue = std::max(FloatValue, static_cast<float>(MinimumValue));
+ FloatValue = std::min(FloatValue, static_cast<float>(MaximumValue));
+ UnpackedType IntValue = static_cast<UnpackedType>(FloatValue);
+
+ MlasSetInt4Element(Output, n, IntValue);
+ }
+}
+
void
MLASCALL
MlasQuantizeLinearU8Kernel(
@@ -159,3 +272,29 @@ MlasQuantizeLinearS16Kernel(
MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint);
}
+void
+MLASCALL
+MlasQuantizeLinearU4Kernel(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ MlasQuantizeLinearInt4Kernel<false>(Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasQuantizeLinearS4Kernel(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ MlasQuantizeLinearInt4Kernel<true>(Input, Output, N, Scale, ZeroPoint);
+}
+
diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp
index ffecc2dbeff9e..ae638fafee18f 100644
--- a/onnxruntime/core/mlas/lib/quantize.cpp
+++ b/onnxruntime/core/mlas/lib/quantize.cpp
@@ -519,6 +519,87 @@ Return Value:
}
}
+template<bool Signed>
+void
+MLASCALL
+MlasQuantizeLinearInt4Kernel(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ constexpr int32_t MinimumValue = Int4Traits<Signed>::Min;
+ constexpr int32_t MaximumValue = Int4Traits<Signed>::Max;
+ using UnpackedType = typename Int4Traits<Signed>::UnpackedType;
+
+ auto ScaleVector = MlasBroadcastFloat32x4(Scale);
+ auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast<float>(MinimumValue - ZeroPoint));
+ auto MaximumValueVector = MlasBroadcastFloat32x4(static_cast<float>(MaximumValue - ZeroPoint));
+ auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint);
+
+ // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values.
+ UnpackedType TmpOutput[4] = {};
+
+ while (N >= 4) {
+
+ auto FloatVector = MlasLoadFloat32x4(Input);
+ auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector,
+ MinimumValueVector, MaximumValueVector, ZeroPointVector);
+
+ IntegerVector = MlasQuantizeLinearPackBytes<UnpackedType>(IntegerVector);
+ MlasQuantizeLinearStore4PackedValues(IntegerVector, &TmpOutput[0]);
+ MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]);
+ MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]);
+
+ Input += 4;
+ N -= 4;
+ }
+
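+ // Quantize any leftover elements (fewer than 4) one lane at a time and write each nibble in place.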
+ for (size_t n = 0; n < N; n++) {
+
+#if defined(MLAS_NEON64_INTRINSICS)
+ auto FloatVector = vld1q_dup_f32(Input + n);
+#elif defined(MLAS_LSX_INTRINSICS)
+ MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0);
+#else
+ auto FloatVector = _mm_load_ss(Input + n);
+#endif
+ auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector,
+ MinimumValueVector, MaximumValueVector, ZeroPointVector);
+
+ MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]);
+ MlasSetInt4Element(Output, n, TmpOutput[0]);
+ }
+}
+
+void
+MLASCALL
+MlasQuantizeLinearS4Kernel(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ MlasQuantizeLinearInt4Kernel<true>(Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasQuantizeLinearU4Kernel(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ MlasQuantizeLinearInt4Kernel<false>(Input, Output, N, Scale, ZeroPoint);
+}
+
void
MLASCALL
MlasQuantizeLinearS8Kernel(
@@ -571,6 +652,42 @@ MlasQuantizeLinearS16Kernel(
MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint);
}
+void
+MLASCALL
+MlasQuantizeLinearS4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+#if defined(MLAS_TARGET_AMD64)
+ GetMlasPlatform().QuantizeLinearS4Kernel(
+#else
+ MlasQuantizeLinearS4Kernel(
+#endif
+ Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasQuantizeLinearU4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+#if defined(MLAS_TARGET_AMD64)
+ GetMlasPlatform().QuantizeLinearU4Kernel(
+#else
+ MlasQuantizeLinearU4Kernel(
+#endif
+ Input, Output, N, Scale, ZeroPoint);
+}
+
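+// Illustrative usage sketch (not part of this change); the buffer sizes and quantization
+// parameters below are made-up example values:
+//
+//   float src[5] = {0.1f, -0.2f, 0.3f, 0.05f, -0.4f};
+//   uint8_t dst[3] = {};  // (N + 1) / 2 bytes hold N packed 4-bit values
+//   MlasQuantizeLinearS4(src, dst, 5, /*Scale*/ 0.05f, /*ZeroPoint*/ 0);
+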
template<>
void
MLASCALL
@@ -707,6 +824,31 @@ MlasQuantizeLinear(
GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint);
}
+void
+MLASCALL
+MlasQuantizeLinearS4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ GetMlasPlatform().QuantizeLinearS4Kernel(Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasQuantizeLinearU4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ GetMlasPlatform().QuantizeLinearU4Kernel(Input, Output, N, Scale, ZeroPoint);
+}
#endif
//
@@ -805,6 +947,58 @@ MlasQuantizeLinear(
uint16_t ZeroPoint
);
+template<bool Signed>
+void
+MLASCALL
+MlasQuantizeLinearInt4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ constexpr int32_t MinimumValue = Int4Traits<Signed>::Min;
+ constexpr int32_t MaximumValue = Int4Traits<Signed>::Max;
+ using UnpackedType = typename Int4Traits<Signed>::UnpackedType;
+
+ for (size_t n = 0; n < N; n++) {
+ float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast<float>(ZeroPoint);
+ FloatValue = std::max(FloatValue, static_cast<float>(MinimumValue));
+ FloatValue = std::min(FloatValue, static_cast<float>(MaximumValue));
+ UnpackedType IntValue = static_cast<UnpackedType>(FloatValue);
+
+ MlasSetInt4Element(Output, n, IntValue);
+ }
+}
+
+// QuantizeLinear INT4 implementation using the C++ runtime.
+void
+MLASCALL
+MlasQuantizeLinearS4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ MlasQuantizeLinearInt4<true>(Input, Output, N, Scale, ZeroPoint);
+}
+
+// QuantizeLinear UINT4 implementation using the C++ runtime.
+void
+MLASCALL
+MlasQuantizeLinearU4(
+ const float* Input,
+ uint8_t* Output,
+ size_t N,
+ float Scale,
+ int8_t ZeroPoint
+ )
+{
+ MlasQuantizeLinearInt4<false>(Input, Output, N, Scale, ZeroPoint);
+}
#endif
#endif
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 6b4f62ae1343d..09705f61c82ce 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -20,6 +20,11 @@ constexpr bool Is16BitIntType(int32_t data_type) {
(data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16);
}
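+
+// Returns true if the data type is a 4-bit integer type (INT4 or UINT4).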
+constexpr bool Is4BitIntType(int32_t data_type) {
+ return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4) ||
+ (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4);
+}
+
// adjust for an optional input/output that has an entry but does not exist
int NumActualValues(const Node& node, bool input) {
const auto& defs = input ? node.InputDefs() : node.OutputDefs();
@@ -134,6 +139,10 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ if (!allow_4bit_ && Is4BitIntType(dt_input)) {
+ return false;
+ }
+
const Node& dq_node = *dq_nodes.front();
const Node& q_node = *q_nodes.front();
@@ -167,6 +176,10 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ if (!allow_4bit_ && Is4BitIntType(dt_input)) {
+ return false;
+ }
+
auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
return graph_viewer.GetConstantInitializer(initializer_name, true);
};
@@ -193,6 +206,10 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node&
return false;
}
+ if (!allow_4bit_ && Is4BitIntType(dt_input)) {
+ return false;
+ }
+
return true;
}
@@ -218,6 +235,10 @@ bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ if (!allow_4bit_ && Is4BitIntType(dt_input_1)) {
+ return false;
+ }
+
return true;
}
@@ -253,6 +274,10 @@ bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ if (!allow_4bit_ && Is4BitIntType(dt_input)) {
+ return false;
+ }
+
return true;
}
@@ -275,6 +300,10 @@ bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& dq_node = *dq_nodes.front();
int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+ if (!allow_4bit_ && Is4BitIntType(dt_input)) {
+ return false;
+ }
+
// All Q outputs should have same data type and (optionally) equal quantization parameters as the input.
for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) {
const Node& q_node = *q_nodes[q_idx];
@@ -312,6 +341,10 @@ bool ConvNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ if (!allow_4bit_weight_ && Is4BitIntType(dt_weight)) {
+ return false;
+ }
+
if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
if (!int8_allowed_ || dt_weight != dt_input) {
return false;
@@ -359,6 +392,11 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ // 4-bit int types must be explicitly allowed.
+ if (!allow_4bit_ && (Is4BitIntType(dt_input) || Is4BitIntType(dt_weight))) {
+ return false;
+ }
+
// potential match for QLinearMatMul or MatMulIntegerToFloat
bool qlinear = !q_nodes.empty();
@@ -407,6 +445,10 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return false;
}
+ if (!allow_4bit_ && (Is4BitIntType(dt_A) || Is4BitIntType(dt_B))) {
+ return false;
+ }
+
if (dq_nodes.size() < 3) { // no bias
return true;
}
@@ -445,6 +487,10 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node&
return false;
}
+ if (!allow_4bit_ && Is4BitIntType(dt_input_1)) {
+ return false;
+ }
+
return true;
}
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index 5d550669e2e86..1a2a620acb480 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -5,6 +5,7 @@
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+#include
#include "core/framework/node_unit.h"
#include "core/optimizer/selectors_actions/selector_action_transformer.h"
@@ -47,7 +48,8 @@ class NodeGroupSelector {
// Zero point and scale are constant scalars and must match
class DropQDQNodeGroupSelector : public NodeGroupSelector {
public:
- explicit DropQDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+ explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+ : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -55,12 +57,14 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
bool allow_16bit_;
+ bool allow_4bit_;
};
// Single DQ -> node.
class DropDQNodeGroupSelector : public NodeGroupSelector {
public:
- explicit DropDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+ explicit DropDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+ : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -68,12 +72,14 @@ class DropDQNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
bool allow_16bit_;
+ bool allow_4bit_;
};
// single input. default is to only support uint8.
class UnaryNodeGroupSelector : public NodeGroupSelector {
public:
- explicit UnaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+ explicit UnaryNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+ : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -81,12 +87,14 @@ class UnaryNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
bool allow_16bit_;
+ bool allow_4bit_;
};
// 2 DQ nodes providing input -> node -> Q
class BinaryNodeGroupSelector : public NodeGroupSelector {
public:
- explicit BinaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+ explicit BinaryNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+ : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -94,12 +102,14 @@ class BinaryNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
bool allow_16bit_;
+ bool allow_4bit_;
};
// Variadic DQ nodes -> node -> Q
class VariadicNodeGroupSelector : public NodeGroupSelector {
public:
- explicit VariadicNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+ explicit VariadicNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+ : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -107,6 +117,7 @@ class VariadicNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
bool allow_16bit_;
+ bool allow_4bit_;
};
// DQ node -> Split -> multiple Q nodes with equal quantization types.
@@ -114,8 +125,8 @@ class VariadicNodeGroupSelector : public NodeGroupSelector {
// equal and constant.
class SplitNodeGroupSelector : public NodeGroupSelector {
public:
- explicit SplitNodeGroupSelector(bool req_equal_quant_params = false)
- : req_equal_quant_params_(req_equal_quant_params) {}
+ explicit SplitNodeGroupSelector(bool req_equal_quant_params = false, bool allow_4bit = true)
+ : req_equal_quant_params_(req_equal_quant_params), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -125,14 +136,15 @@ class SplitNodeGroupSelector : public NodeGroupSelector {
bool req_equal_quant_params_; // If true, only selects a node group if the input and output
// quantization parameters are all equal/constant, which enables the
// optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP.
+ bool allow_4bit_;
};
// DQ nodes for X, W and optionally B -> node -> Q
class ConvNodeGroupSelector : public NodeGroupSelector {
public:
// default to 'true'
- ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true)
- : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit) {}
+ ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true, bool allow_4bit_weight = true)
+ : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit), allow_4bit_weight_(allow_4bit_weight) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -141,12 +153,13 @@ class ConvNodeGroupSelector : public NodeGroupSelector {
bool int8_allowed_;
bool allow_16bit_;
+ bool allow_4bit_weight_;
};
class WhereNodeGroupSelector : public NodeGroupSelector {
public:
- explicit WhereNodeGroupSelector(bool allow_16bit = true)
- : allow_16bit_(allow_16bit) {}
+ explicit WhereNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+ : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -154,6 +167,7 @@ class WhereNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
bool allow_16bit_;
+ bool allow_4bit_;
};
class PadNodeGroupSelector : public NodeGroupSelector {
@@ -172,10 +186,12 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
public:
MatMulNodeGroupSelector(bool int8_allowed = true,
bool matmulintegertofloat_allowed = false,
- bool allow_16bit = true)
+ bool allow_16bit = true,
+ bool allow_4bit = true)
: int8_allowed_(int8_allowed),
matmulintegertofloat_allowed_(matmulintegertofloat_allowed),
- allow_16bit_(allow_16bit) {
+ allow_16bit_(allow_16bit),
+ allow_4bit_(allow_4bit) {
}
private:
@@ -185,13 +201,15 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
bool int8_allowed_;
bool matmulintegertofloat_allowed_;
bool allow_16bit_;
+ bool allow_4bit_;
};
// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmNodeGroupSelector : public NodeGroupSelector {
public:
- explicit GemmNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {}
+ explicit GemmNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+ : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
@@ -199,6 +217,7 @@ class GemmNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
bool allow_16bit_;
+ bool allow_4bit_;
};
// Input: DQ nodes for input, scale, and B
@@ -273,33 +292,35 @@ class BaseSelector : public NodeSelector {
class DropQDQNodesSelector : public BaseSelector {
public:
- explicit DropQDQNodesSelector(bool allow_16bit = false)
- : BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit)) {}
+ explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false)
+ : BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit, allow_4bit)) {}
};
class DropDQNodesSelector : public BaseSelector {
public:
- explicit DropDQNodesSelector(bool allow_16bit = false)
- : BaseSelector(std::make_unique<DropDQNodeGroupSelector>(allow_16bit)) {}
+ explicit DropDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false)
+ : BaseSelector(std::make_unique<DropDQNodeGroupSelector>(allow_16bit, allow_4bit)) {}
};
class UnarySelector : public BaseSelector {
public:
- explicit UnarySelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false)
- : BaseSelector(std::make_unique<UnaryNodeGroupSelector>(allow_16bit), compatible_providers) {}
+ explicit UnarySelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false,
+ bool allow_4bit = false)
+ : BaseSelector(std::make_unique<UnaryNodeGroupSelector>(allow_16bit, allow_4bit), compatible_providers) {}
};
class BinarySelector : public BaseSelector {
public:
- explicit BinarySelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false)
- : BaseSelector(std::make_unique<BinaryNodeGroupSelector>(allow_16bit), compatible_providers) {}
+ explicit BinarySelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false,
+ bool allow_4bit = false)
+ : BaseSelector(std::make_unique<BinaryNodeGroupSelector>(allow_16bit, allow_4bit), compatible_providers) {}
};
// Variadic DQ nodes -> node -> Q
class InputVariadicSelector : public BaseSelector {
public:
- explicit InputVariadicSelector(bool allow_16bit = false)
- : BaseSelector(std::make_unique<VariadicNodeGroupSelector>(allow_16bit)) {}
+ explicit InputVariadicSelector(bool allow_16bit = false, bool allow_4bit = false)
+ : BaseSelector(std::make_unique<VariadicNodeGroupSelector>(allow_16bit, allow_4bit)) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
@@ -307,8 +328,8 @@ class InputVariadicSelector : public BaseSelector {
// DQ -> Split -> variadic Q nodes
class SplitSelector : public BaseSelector {
public:
- SplitSelector(bool req_equal_quant_params = false)
- : BaseSelector(std::make_unique<SplitNodeGroupSelector>(req_equal_quant_params)) {}
+ SplitSelector(bool req_equal_quant_params = false, bool allow_4bit = false)
+ : BaseSelector(std::make_unique<SplitNodeGroupSelector>(req_equal_quant_params, allow_4bit)) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
@@ -316,32 +337,34 @@ class SplitSelector : public BaseSelector {
// DQ nodes for X, W and optionally B -> node -> Q
class ConvSelector : public BaseSelector {
public:
- ConvSelector(bool int8_allowed = false, bool allow_16bit = false)
- : BaseSelector(std::make_unique<ConvNodeGroupSelector>(int8_allowed, allow_16bit)) {}
+ ConvSelector(bool int8_allowed = false, bool allow_16bit = false, bool allow_4bit_weight = false)
+ : BaseSelector(std::make_unique<ConvNodeGroupSelector>(int8_allowed, allow_16bit, allow_4bit_weight)) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
class WhereSelector : public BaseSelector {
public:
- explicit WhereSelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false)
- : BaseSelector(std::make_unique<WhereNodeGroupSelector>(allow_16bit), compatible_providers) {}
+ explicit WhereSelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false,
+ bool allow_4bit = false)
+ : BaseSelector(std::make_unique<WhereNodeGroupSelector>(allow_16bit, allow_4bit), compatible_providers) {}
};
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
class MatMulSelector : public BaseSelector {
public:
- MatMulSelector(bool int8_allowed, bool allow_16bit = false)
+ MatMulSelector(bool int8_allowed, bool allow_16bit = false, bool allow_4bit = false)
: BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true,
- allow_16bit)) {}
+ allow_16bit, allow_4bit)) {}
};
// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmSelector : public BaseSelector {
public:
- explicit GemmSelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false)
- : BaseSelector(std::make_unique<GemmNodeGroupSelector>(allow_16bit), compatible_providers) {}
+ explicit GemmSelector(gsl::span<const char*> compatible_providers = {}, bool allow_16bit = false,
+ bool allow_4bit = false)
+ : BaseSelector(std::make_unique<GemmNodeGroupSelector>(allow_16bit, allow_4bit), compatible_providers) {}
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 8a270a05d7287..b8d5a7852a968 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -5,6 +5,7 @@
#include