diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h
index d669a4e599d4e..035b52d6bb689 100644
--- a/include/onnxruntime/core/framework/data_types.h
+++ b/include/onnxruntime/core/framework/data_types.h
@@ -16,6 +16,7 @@
 #include "core/common/float8.h"
 #include "core/common/float16.h"
 #include "core/framework/int4.h"
+#include "core/framework/int2.h"
 #include "core/framework/float4.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/framework/to_tensor_proto_element_type.h"
@@ -211,6 +212,7 @@ class DataTypeImpl {
   static const std::vector<MLDataType>& AllTensorTypesIRv9();
   static const std::vector<MLDataType>& AllTensorTypesIRv10();
   static const std::vector<MLDataType>& AllTensorTypesIRv11();
+  static const std::vector<MLDataType>& AllTensorTypesIRv13();

   static const std::vector<MLDataType>& AllFixedSizeTensorTypes();  // up to IR4 (no float 8), deprecated
   static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv4();
@@ -285,7 +287,7 @@ template <typename T>
-struct IsTensorContainedType : public IsAnyOf<T, ..., Int4x2, UInt4x2> {
+struct IsTensorContainedType : public IsAnyOf<T, ..., Int4x2, UInt4x2, Int2x4, UInt2x4> {
 };

 template <typename T>
 struct IsSparseTensorContainedType : public IsAnyOf<T, ...> {
diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h
--- a/include/onnxruntime/core/framework/data_types_internal.h
+++ b/include/onnxruntime/core/framework/data_types_internal.h
@@ ... @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
@@ -171,6 +177,12 @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       retval = function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      retval = function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      retval = function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
@@ -230,6 +242,12 @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
@@ -287,6 +305,12 @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       retval = function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      retval = function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      retval = function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
@@ -355,6 +379,12 @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
@@ -421,6 +451,12 @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       retval = function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      retval = function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      retval = function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
@@ -477,6 +513,12 @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
@@ -531,6 +573,12 @@ namespace utils {
     case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
       retval = function<UInt4x2>(__VA_ARGS__); \
       break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
+      retval = function<Int2x4>(__VA_ARGS__); \
+      break; \
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
+      retval = function<UInt2x4>(__VA_ARGS__); \
+      break; \
     default: \
       ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
   }
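The eight hunks above all extend the same family of type-dispatch macros in data_types_internal.h. As a minimal standalone sketch of what each new case buys a caller — the names below are illustrative stand-ins, not the real macros or enum definitions:

```cpp
#include <cstdio>

// Stand-ins for the real proto enum values and packed types (illustrative only).
enum DataTypeTag { kUint2 = 25, kInt2 = 26, kOther = 0 };
struct Int2x4Stub { unsigned char bits; };   // 4 packed 2-bit lanes in one byte
struct UInt2x4Stub { unsigned char bits; };

template <typename T>
void ReportStorage(const char* name) {
  std::printf("%s: %zu byte(s) per 4 elements\n", name, sizeof(T));
}

// What the expanded dispatcher switch now does for the two new cases.
void Dispatch(DataTypeTag tag) {
  switch (tag) {
    case kInt2:  ReportStorage<Int2x4Stub>("int2");   break;
    case kUint2: ReportStorage<UInt2x4Stub>("uint2"); break;
    default:     std::printf("unhandled tag\n");      break;
  }
}

int main() { Dispatch(kInt2); Dispatch(kUint2); return 0; }
```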
diff --git a/include/onnxruntime/core/framework/int2.h b/include/onnxruntime/core/framework/int2.h
new file mode 100644
index 0000000000000..0d406d6fcd8d3
--- /dev/null
+++ b/include/onnxruntime/core/framework/int2.h
@@ -0,0 +1,178 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cassert>
+#include <cstddef>
+#include "core/common/common.h"
+#include <gsl/gsl>
+
+namespace onnxruntime {
+
+template <bool Signed>
+struct Int2Traits;
+
+template <>
+struct Int2Traits<true> {
+  using UnpackedType = int8_t;
+  static constexpr int8_t min_val = -2;
+  static constexpr int8_t max_val = 1;
+};
+
+template <>
+struct Int2Traits<false> {
+  using UnpackedType = uint8_t;
+  static constexpr uint8_t min_val = 0;
+  static constexpr uint8_t max_val = 3;
+};
+
+/// <summary>
+/// Stores 4 packed 2-bit elements in 1 byte.
+/// Packing follows ONNX spec: x0 | (x1 << 2) | (x2 << 4) | (x3 << 6)
+/// </summary>
+/// <typeparam name="Signed">Set to true if signed int2, or false if unsigned uint2.</typeparam>
+template <bool Signed>
+struct Int2x4Base {
+  using UnpackedType = typename Int2Traits<Signed>::UnpackedType;
+  static constexpr UnpackedType min_val = Int2Traits<Signed>::min_val;
+  static constexpr UnpackedType max_val = Int2Traits<Signed>::max_val;
+
+  std::byte bits_{};
+
+  Int2x4Base() = default;
+
+  explicit Int2x4Base(std::byte bits) {
+    bits_ = bits;
+  }
+
+  Int2x4Base(UnpackedType val0, UnpackedType val1, UnpackedType val2, UnpackedType val3) {
+    bits_ = static_cast<std::byte>(
+        (val0 & 0x3) |
+        ((val1 & 0x3) << 2) |
+        ((val2 & 0x3) << 4) |
+        ((val3 & 0x3) << 6));
+  }
+
+  static inline int8_t SignExtendLower2Bits(std::byte bits) {
+    // Sign-extend lower 2-bits by left shifting and then doing an arithmetic right shift.
+    constexpr uint8_t shift = (sizeof(int32_t) * 8) - 2;
+    return static_cast<int8_t>((static_cast<int32_t>(bits) << shift) >> shift);
+  }
+
+  inline UnpackedType GetElem(size_t index) const {
+    assert(index <= 3);
+    const uint8_t shift = 2 * static_cast<uint8_t>(index);
+    const std::byte val = (bits_ >> shift) & std::byte{0x3};
+
+    if constexpr (Signed) {
+      return SignExtendLower2Bits(val);
+    } else {
+      return static_cast<UnpackedType>(val);
+    }
+  }
+
+  inline void SetElem(size_t index, UnpackedType val) {
+    assert(index <= 3);
+    const uint8_t shift = 2 * static_cast<uint8_t>(index);
+    const std::byte clear_mask = ~(std::byte{0x3} << shift);
+
+    bits_ &= clear_mask;                                    // Clear 2-bit element to 0
+    bits_ |= static_cast<std::byte>((val & 0x3) << shift);  // Set 2-bit element to val
+  }
+
+  inline std::byte ToBits() const {
+    return bits_;
+  }
+
+  /// <summary>
+  /// Calculates the number of packed byte units needed to store the given number of 2-bit elements.
+  /// Each byte stores 4 x 2-bit elements.
+  /// </summary>
+  static size_t CalcNumInt2Quads(size_t num_int2_elems) {
+    return (num_int2_elems + 3) / 4;
+  }
+
+  /// <summary>
+  /// Copy a source buffer of 2-bit elements (packed) into a destination buffer of 8-bit elements (unpacked).
+  /// </summary>
+  /// <param name="dst">Destination buffer to store unpacked 8-bit elements</param>
+  /// <param name="src">Source buffer with 2-bit elements</param>
+  /// <returns>True on success</returns>
+  static bool Unpack(gsl::span<UnpackedType> dst, gsl::span<const Int2x4Base<Signed>> src) {
+    if (CalcNumInt2Quads(dst.size()) != src.size()) {
+      return false;
+    }
+
+    if (src.empty()) {
+      return true;
+    }
+
+    for (size_t i = 0; i < dst.size(); i++) {
+      size_t byte_idx = i >> 2;   // i / 4
+      size_t elem_idx = i & 0x3;  // i % 4
+      dst[i] = src[byte_idx].GetElem(elem_idx);
+    }
+
+    return true;
+  }
+
+  /// <summary>
+  /// Copy a source buffer of 8-bit elements (unpacked) into a destination buffer of 2-bit elements (packed).
+  /// </summary>
+  /// <param name="dst">Destination buffer to store packed 2-bit elements</param>
+  /// <param name="src">Source buffer with 8-bit elements</param>
+  /// <returns>True on success</returns>
+  static bool Pack(gsl::span<Int2x4Base<Signed>> dst, gsl::span<const UnpackedType> src) {
+    if (CalcNumInt2Quads(src.size()) != dst.size()) {
+      return false;
+    }
+
+    if (src.empty()) {
+      return true;
+    }
+
+    size_t src_i = 0;
+    size_t dst_i = 0;
+    const size_t full_quads = src.size() / 4;
+
+    // Process complete groups of 4 elements
+    for (; dst_i < full_quads; dst_i++) {
+      dst[dst_i] = Int2x4Base(src[src_i], src[src_i + 1], src[src_i + 2], src[src_i + 3]);
+      src_i += 4;
+    }
+
+    // Handle remaining elements (1-3)
+    if (src_i < src.size()) {
+      UnpackedType vals[4] = {0, 0, 0, 0};
+      size_t remaining = src.size() - src_i;
+      for (size_t j = 0; j < remaining; j++) {
+        vals[j] = src[src_i + j];
+      }
+      dst[dst_i] = Int2x4Base(vals[0], vals[1], vals[2], vals[3]);
+    }
+
+    return true;
+  }
+
+  /// <summary>
+  /// Returns hierarchical indices for a packed int2 element from the given element index.
+  ///
+  /// Usage:
+  ///   Int2x4* data = ...;
+  ///   auto indices = GetTensorElemIndices(5);  // 6th int2 element
+  ///   int8_t elem = data[indices.first].GetElem(indices.second);
+  /// </summary>
+  /// <param name="index">Index of 2-bit element</param>
+  /// <returns>Pair of (byte_index, element_index_within_byte)</returns>
+  static inline std::pair<size_t, size_t> GetTensorElemIndices(size_t index) {
+    return {index >> 2, index & 0x3};
+  }
+};
+
+using Int2x4 = Int2x4Base<true>;
+using UInt2x4 = Int2x4Base<false>;
+static_assert(sizeof(Int2x4) == sizeof(std::byte));
+static_assert(sizeof(UInt2x4) == sizeof(std::byte));
+
+}  // namespace onnxruntime
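To make the byte layout this header defines concrete — element 0 occupies the low 2 bits of the byte — here is a standalone sketch that packs {1, -1, -2, 0} and recovers the values with the same shift/sign-extend arithmetic. This is illustrative stub code, not the header itself:

```cpp
#include <cassert>
#include <cstdint>

// Pack four signed 2-bit values per the ONNX layout: x0 | (x1 << 2) | (x2 << 4) | (x3 << 6).
uint8_t PackInt2(int8_t v0, int8_t v1, int8_t v2, int8_t v3) {
  return static_cast<uint8_t>((v0 & 0x3) | ((v1 & 0x3) << 2) | ((v2 & 0x3) << 4) | ((v3 & 0x3) << 6));
}

// Recover element i: shift it to the low 2 bits, then sign-extend.
int8_t UnpackInt2(uint8_t bits, unsigned i) {
  uint32_t two_bits = (bits >> (2 * i)) & 0x3u;
  // Place the 2-bit field at the top of a 32-bit word, then arithmetic-shift back down.
  return static_cast<int8_t>(static_cast<int32_t>(two_bits << 30) >> 30);
}

int main() {
  uint8_t b = PackInt2(1, -1, -2, 0);  // 0b00'10'11'01 == 0x2D
  assert(b == 0x2D);
  assert(UnpackInt2(b, 0) == 1 && UnpackInt2(b, 1) == -1);
  assert(UnpackInt2(b, 2) == -2 && UnpackInt2(b, 3) == 0);
  return 0;
}
```

The same 0x2D byte appears again in the unit tests near the end of this diff.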
diff --git a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
index e1b5e614d095d..82aefe0165fcc 100644
--- a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
+++ b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
@@ -13,6 +13,7 @@
 #include "core/framework/float4.h"
 #include "core/common/float8.h"
 #include "core/common/float16.h"
+#include "core/framework/int2.h"
 #include "core/framework/int4.h"

 namespace onnxruntime {
@@ -116,5 +117,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<UInt4x2>() {
   return ONNX_NAMESPACE::TensorProto_DataType_UINT4;
 }

+template <>
+constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Int2x4>() {
+  return ONNX_NAMESPACE::TensorProto_DataType_INT2;
+}
+template <>
+constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<UInt2x4>() {
+  return ONNX_NAMESPACE::TensorProto_DataType_UINT2;
+}
+
 }  // namespace utils
 }  // namespace onnxruntime
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 5acac571f3f3b..695d4dfbb4631 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -209,6 +209,9 @@ typedef enum ONNXTensorElementDataType {
   ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4,        // maps to a pair of packed int4 values (size == 1 byte)
   // Float4 types were introduced in ONNX 1.18. See https://onnx.ai/onnx/technical/float4.html
   ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1,  // maps to a pair of packed float4 values (size == 1 byte)
+  // Int2 types were introduced in ONNX 1.20. See https://onnx.ai/onnx/technical/int2.html
+  ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2,       // maps to 4 packed uint2 values (size == 1 byte)
+  ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2,        // maps to 4 packed int2 values (size == 1 byte)
 } ONNXTensorElementDataType;

 // Synced with onnx TypeProto oneof
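With the enum values in place, client code can wrap packed 2-bit data in an `OrtValue`. A hedged sketch using the C++ wrapper of the C API — buffer contents and shape are illustrative, and it assumes the same ceil(N/4) byte-count convention the rest of this diff uses:

```cpp
#include <array>
#include <cstdint>
#include "onnxruntime_cxx_api.h"

Ort::Value MakeInt2Tensor() {
  static std::array<uint8_t, 3> packed{0x2D, 0x2D, 0x01};  // ceil(10 / 4) == 3 bytes backing 10 elements
  std::array<int64_t, 1> shape{10};                        // shape counts logical int2 elements, not bytes
  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  return Ort::Value::CreateTensor(mem_info, packed.data(), packed.size(), shape.data(), shape.size(),
                                  ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2);
}
```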
diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc
index 30896b37654ff..e5a8e718bd024 100644
--- a/onnxruntime/core/framework/data_types.cc
+++ b/onnxruntime/core/framework/data_types.cc
@@ -645,6 +645,9 @@ ORT_REGISTER_TENSOR_TYPE(Float4E2M1x2);
 ORT_REGISTER_TENSOR_TYPE(Int4x2);
 ORT_REGISTER_TENSOR_TYPE(UInt4x2);

+ORT_REGISTER_TENSOR_TYPE(Int2x4);
+ORT_REGISTER_TENSOR_TYPE(UInt2x4);
+
 #if !defined(DISABLE_SPARSE_TENSORS)
 ORT_REGISTER_SPARSE_TENSOR_TYPE(int32_t);
 ORT_REGISTER_SPARSE_TENSOR_TYPE(float);
@@ -708,6 +711,9 @@ ORT_REGISTER_SEQ_TENSOR_TYPE(Float8E5M2FNUZ);
 ORT_REGISTER_SEQ_TENSOR_TYPE(Int4x2);
 ORT_REGISTER_SEQ_TENSOR_TYPE(UInt4x2);

+ORT_REGISTER_SEQ_TENSOR_TYPE(Int2x4);
+ORT_REGISTER_SEQ_TENSOR_TYPE(UInt2x4);
+
 #if !defined(DISABLE_ML_OPS)
 ORT_REGISTER_SEQ(VectorMapStringToFloat);
 ORT_REGISTER_SEQ(VectorMapInt64ToFloat);
@@ -735,7 +741,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat);
   ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2);     \
   ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ); \
   ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2);         \
-  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2);
+  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2);        \
+  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int2x4);         \
+  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt2x4);

 #else

@@ -755,7 +763,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat);
   ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, MLFloat16); \
   ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16);  \
   ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2);    \
-  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2);
+  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2);   \
+  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int2x4);    \
+  ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt2x4);

 #endif

@@ -825,6 +835,8 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {
 #endif
   REGISTER_TENSOR_PROTO(Int4x2, reg_fn);
   REGISTER_TENSOR_PROTO(UInt4x2, reg_fn);
+  REGISTER_TENSOR_PROTO(Int2x4, reg_fn);
+  REGISTER_TENSOR_PROTO(UInt2x4, reg_fn);

 #if !defined(DISABLE_SPARSE_TENSORS)
   REGISTER_SPARSE_TENSOR_PROTO(int32_t, reg_fn);
@@ -886,6 +898,8 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {

   REGISTER_SEQ_TENSOR_PROTO(Int4x2, reg_fn);
   REGISTER_SEQ_TENSOR_PROTO(UInt4x2, reg_fn);
+  REGISTER_SEQ_TENSOR_PROTO(Int2x4, reg_fn);
+  REGISTER_SEQ_TENSOR_PROTO(UInt2x4, reg_fn);

 #if !defined(DISABLE_ML_OPS)
   REGISTER_ONNX_PROTO(VectorMapStringToFloat, reg_fn);
@@ -916,7 +930,9 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {
   REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2, reg_fn);     \
   REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn); \
   REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn);         \
-  REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn);
+  REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn);        \
+  REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int2x4, reg_fn);         \
+  REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt2x4, reg_fn);

 #else

@@ -936,7 +952,9 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {
   REGISTER_OPTIONAL_PROTO(ORT_TYPE, MLFloat16, reg_fn); \
   REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn);  \
   REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn);    \
-  REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn);
+  REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn);   \
+  REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int2x4, reg_fn);    \
+  REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt2x4, reg_fn);

 #endif

@@ -1003,6 +1021,10 @@ const char* DataTypeImpl::ToString(MLDataType type) {
       return "Int4x2";
     case TensorProto_DataType_UINT4:
       return "UInt4x2";
+    case TensorProto_DataType_INT2:
+      return "Int2x4";
+    case TensorProto_DataType_UINT2:
+      return "UInt2x4";
     default:
       break;
   }
@@ -1077,6 +1099,10 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) {
       return DataTypeImpl::GetTensorType<Int4x2>()->AsTensorType();
     case TensorProto_DataType_UINT4:
      return DataTypeImpl::GetTensorType<UInt4x2>()->AsTensorType();
+    case TensorProto_DataType_INT2:
+      return DataTypeImpl::GetTensorType<Int2x4>()->AsTensorType();
+    case TensorProto_DataType_UINT2:
+      return DataTypeImpl::GetTensorType<UInt2x4>()->AsTensorType();

     default:
       ORT_NOT_IMPLEMENTED("tensor type ", type, " is not supported");
@@ -1130,6 +1156,10 @@ const SequenceTensorTypeBase* DataTypeImpl::SequenceTensorTypeFromONNXEnum(int t
       return DataTypeImpl::GetSequenceTensorType<Int4x2>()->AsSequenceTensorType();
     case TensorProto_DataType_UINT4:
       return DataTypeImpl::GetSequenceTensorType<UInt4x2>()->AsSequenceTensorType();
+    case TensorProto_DataType_INT2:
+      return DataTypeImpl::GetSequenceTensorType<Int2x4>()->AsSequenceTensorType();
+    case TensorProto_DataType_UINT2:
+      return DataTypeImpl::GetSequenceTensorType<UInt2x4>()->AsSequenceTensorType();

     default:
       ORT_NOT_IMPLEMENTED("sequence tensor type ", type, " is not supported");
@@ -1232,6 +1262,8 @@ ORT_REGISTER_PRIM_SUBBYTE_TYPE(Float4E2M1x2, 2);
 ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2);
 ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2);
+ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int2x4, 4);
+ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt2x4, 4);

 namespace {
 template <typename T>
@@ -1334,6 +1366,12 @@ const std::vector<MLDataType>& DataTypeImpl::AllTensorTypesIRv11() {
   return all_tensor_types;
 }

+const std::vector<MLDataType>& DataTypeImpl::AllTensorTypesIRv13() {
+  static std::vector<MLDataType> all_tensor_types =
+      GetTensorTypesFromTypeList<element_type_lists::AllIRv13>();
+  return all_tensor_types;
+}
+
 const std::vector<MLDataType>& DataTypeImpl::AllFixedSizeSequenceTensorTypes() {
   return AllFixedSizeSequenceTensorTypesIRv4();
 }
diff --git a/onnxruntime/core/framework/element_type_lists.h b/onnxruntime/core/framework/element_type_lists.h
index ce7c243849d5e..67358d045ba68 100644
--- a/onnxruntime/core/framework/element_type_lists.h
+++ b/onnxruntime/core/framework/element_type_lists.h
@@ -12,6 +12,7 @@
 #include "core/common/float8.h"
 #include "core/common/float16.h"
 #include "core/framework/int4.h"
+#include "core/framework/int2.h"
 #include "core/framework/float4.h"

 namespace onnxruntime {
@@ -99,6 +100,13 @@ using AllIRv11 =
 using AllIRv11 = AllIRv10;
 #endif

+// IR v13 adds INT2/UINT2 (2-bit integer types)
+using AllIRv13 =
+    boost::mp11::mp_push_back<
+        AllIRv11,
+        UInt2x4,
+        Int2x4>;
+
 // TODO: This needs upgrade to some newer version ,buit it has been
 // at this version for a while and it needs changes at the use sites
 // where-in the types in the newer IR versions are not supported.
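Downstream, a list like `AllTensorTypesIRv13()` is what an op's kernel registration would consume. A hedged sketch (the constraint name "T" and the helper function are hypothetical; no kernel in this diff is actually switched to the IRv13 list):

```cpp
#include "core/framework/data_types.h"
#include "core/framework/kernel_def_builder.h"

// Opting a kernel's type constraint into the IRv13 list, which is the IRv11
// list plus the Int2x4/UInt2x4 tensor types registered above.
onnxruntime::KernelDefBuilder& ConstrainToIRv13(onnxruntime::KernelDefBuilder& builder) {
  return builder.TypeConstraint("T", onnxruntime::DataTypeImpl::AllTensorTypesIRv13());
}
```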
diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
index 461e82d72dc83..ffeb8b5b4a193 100644
--- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc
+++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
@@ -87,6 +87,12 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) {
     case TensorType::TensorProto_DataType_FLOAT4E2M1: {
       return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1;
     }  // maps to a pair of float4 (size == 1 byte)
+    case TensorType::TensorProto_DataType_INT2: {
+      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2;
+    }  // maps to 4 packed int2 values (size == 1 byte)
+    case TensorType::TensorProto_DataType_UINT2: {
+      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2;
+    }  // maps to 4 packed uint2 values (size == 1 byte)
     default: {
       return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
     }
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc
index cbf1a953819d3..0bac24a2c3aa0 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.cc
+++ b/onnxruntime/core/framework/tensor_type_and_shape.cc
@@ -229,6 +229,12 @@ constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementData
     case o::TensorProto_DataType_FLOAT4E2M1:
       type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1;
       break;
+    case o::TensorProto_DataType_INT2:
+      type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2;
+      break;
+    case o::TensorProto_DataType_UINT2:
+      type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2;
+      break;
     default:
       type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
      break;
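The tensorprotoutils.cc changes below consume initializers shaped as follows; a small sketch of producing one (the initializer name and values are illustrative, but the pack/raw_data convention matches the round-trip test later in this diff):

```cpp
#include <vector>
#include <gsl/gsl>
#include "core/framework/int2.h"
#include "core/graph/onnx_protobuf.h"

ONNX_NAMESPACE::TensorProto MakeInt2Initializer() {
  using onnxruntime::Int2x4;
  std::vector<int8_t> values = {1, -1, -2, 0, 1};                       // 5 logical elements
  std::vector<Int2x4> packed(Int2x4::CalcNumInt2Quads(values.size()));  // ceil(5/4) == 2 bytes
  Int2x4::Pack(gsl::make_span(packed), gsl::make_span(values));

  ONNX_NAMESPACE::TensorProto proto;
  proto.set_name("w_int2");  // hypothetical initializer name
  proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT2);
  proto.add_dims(static_cast<int64_t>(values.size()));  // dims count elements, not bytes
  proto.set_raw_data(packed.data(), packed.size() * sizeof(Int2x4));
  return proto;
}
```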
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index 0f5622ec2ed45..dfdcf07492903 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -73,6 +73,21 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) {
     return t; \
   }

+// 2-bit types use the same storage pattern as 4-bit types
+#define TO_TENSOR_ORT_TYPE_2BIT_TYPE(TYPE) \
+  template <> \
+  TensorProto ToTensor<onnxruntime::TYPE>(const onnxruntime::TYPE& value) { \
+    return ToScalarTensor(ToTensorProtoElementType<onnxruntime::TYPE>(), static_cast<int32_t>(value.ToBits())); \
+  } \
+  template <> \
+  TensorProto ToTensor<onnxruntime::TYPE>(const std::vector<onnxruntime::TYPE>& values) { \
+    TensorProto t = ToTensorInitialize(ToTensorProtoElementType<onnxruntime::TYPE>()); \
+    for (const onnxruntime::TYPE& val : values) { \
+      t.add_int32_data(static_cast<int32_t>(val.ToBits())); \
+    } \
+    return t; \
+  }
+
 namespace ONNX_NAMESPACE {

 // Provide template specializations for onnxruntime-specific types.
@@ -90,6 +105,9 @@ TO_TENSOR_ORT_TYPE_4BIT_TYPE(Float4E2M1x2)
 TO_TENSOR_ORT_TYPE_4BIT_TYPE(Int4x2)
 TO_TENSOR_ORT_TYPE_4BIT_TYPE(UInt4x2)

+TO_TENSOR_ORT_TYPE_2BIT_TYPE(Int2x4)
+TO_TENSOR_ORT_TYPE_2BIT_TYPE(UInt2x4)
+
 bool operator==(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l,
                 const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) {
   if (l.has_dim_value()) {
@@ -167,6 +185,10 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t
 DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2, CalcNumInt4Pairs)
 DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2, CalcNumInt4Pairs)

+// 2-bit types use the same pattern - CalcNumInt2Quads gives number of packed bytes
+DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int2x4, CalcNumInt2Quads)
+DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt2x4, CalcNumInt2Quads)
+
 #if !defined(DISABLE_FLOAT4_TYPES)
 DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs)
 #endif
@@ -515,6 +537,10 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor,
 DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2, CalcNumInt4Pairs)
 DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2, CalcNumInt4Pairs)

+// 2-bit types
+DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int2x4, CalcNumInt2Quads)
+DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt2x4, CalcNumInt2Quads)
+
 #if !defined(DISABLE_FLOAT4_TYPES)
 DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs)
 #endif
@@ -899,6 +925,41 @@ DEFINE_INT4_UNPACK_TENSOR_IMPL(Int4x2, TensorProto_DataType_INT4)
 // UnpackTensor<UInt4x2>
 DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4)

+// 2-bit type unpack implementation
+#define DEFINE_INT2_UNPACK_TENSOR_IMPL(INT2_TYPE, ONNX_INT2_TYPE) \
+  template <> \
+  Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \
+                      /*out*/ INT2_TYPE* p_data, size_t expected_num_elems) { \
+    if (nullptr == p_data) { \
+      const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); \
+      return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
+    } \
+    if (ONNX_NAMESPACE::ONNX_INT2_TYPE != tensor.data_type()) { \
+      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
+    } \
+ \
+    size_t expected_int2_quads = INT2_TYPE::CalcNumInt2Quads(expected_num_elems); \
+ \
+    if (raw_data != nullptr) { \
+      return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \
+    } \
+ \
+    ORT_RETURN_IF_NOT(static_cast<size_t>(tensor.int32_data_size()) == expected_int2_quads, \
+                      "UnpackTensor: the pre-allocated size does not match the size in proto"); \
+ \
+    for (int i = 0; i < static_cast<int>(tensor.int32_data_size()); i++) { \
+      p_data[i] = INT2_TYPE(static_cast<std::byte>(tensor.int32_data()[i])); \
+    } \
+ \
+    return Status::OK(); \
+  }
+
+// UnpackTensor<Int2x4>
+DEFINE_INT2_UNPACK_TENSOR_IMPL(Int2x4, TensorProto_DataType_INT2)
+
+// UnpackTensor<UInt2x4>
+DEFINE_INT2_UNPACK_TENSOR_IMPL(UInt2x4, TensorProto_DataType_UINT2)
+
 #if !defined(DISABLE_FLOAT4_TYPES)

 template <>
@@ -985,6 +1046,9 @@ INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ)
 INSTANTIATE_UNPACK_TENSOR(Int4x2)
 INSTANTIATE_UNPACK_TENSOR(UInt4x2)

+INSTANTIATE_UNPACK_TENSOR(Int2x4)
+INSTANTIATE_UNPACK_TENSOR(UInt2x4)
+
 #define CASE_PROTO_TRACE(X, Y) \
   case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
     if (!IAllocator::CalcMemSizeForArrayWithAlignment<alignment>(size, sizeof(Y), out)) { \
@@ -1008,6 +1072,14 @@ INSTANTIATE_UNPACK_TENSOR(UInt4x2)
     break;
 #endif

+// 2-bit types
+#define CASE_PROTO_TRACE_INT2(X, Y) \
+  case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
+    if (!IAllocator::CalcMemSizeForArrayWithAlignment<alignment>(Y::CalcNumInt2Quads(size), sizeof(Y), out)) { \
+      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \
+    } \
+    break;
+
 template <size_t alignment>
 common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, int32_t element_type, size_t* out) {
   const auto size = narrow<size_t>(shape.Size());
@@ -1034,6 +1106,8 @@ common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, in
 #endif
     CASE_PROTO_TRACE_INT4(UINT4, UInt4x2);
     CASE_PROTO_TRACE_INT4(INT4, Int4x2);
+    CASE_PROTO_TRACE_INT2(UINT2, UInt2x4);
+    CASE_PROTO_TRACE_INT2(INT2, Int2x4);

 #if !defined(DISABLE_FLOAT4_TYPES)
     CASE_PROTO_TRACE_FLOAT4(FLOAT4E2M1, Float4E2M1x2);
@@ -1428,6 +1502,8 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa
 #endif
       CASE_PROTO(INT4, Int4x2);
       CASE_PROTO(UINT4, UInt4x2);
+      CASE_PROTO(INT2, Int2x4);
+      CASE_PROTO(UINT2, UInt2x4);

 #if !defined(DISABLE_FLOAT4_TYPES)
       CASE_PROTO(FLOAT4E2M1, Float4E2M1x2);
@@ -2073,11 +2149,14 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T
       break; \
     }

-#define CASE_UNPACK_4BIT_TYPE(TYPE, ELEMENT_TYPE, DATA_SIZE, CALC_PAIR_FUN) \
+// Sub-byte types (2-bit and 4-bit) are stored in a packed format.
+// This unpacking code is shared for INT4, UINT4, FLOAT4E2M1, INT2, and UINT2.
+// CALC_PACKED_UNITS_FUN specifies the function to calculate packed byte count from element count.
+#define CASE_UNPACK_SUBBYTE_TYPE(TYPE, ELEMENT_TYPE, DATA_SIZE, CALC_PACKED_UNITS_FUN) \
   case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \
     TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \
     size_t element_count = static_cast<size_t>(tensor_shape.Size()); \
-    size_t packed_element_count = ELEMENT_TYPE::CALC_PAIR_FUN(element_count); \
+    size_t packed_element_count = ELEMENT_TYPE::CALC_PACKED_UNITS_FUN(element_count); \
     unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \
     return onnxruntime::utils::UnpackTensor(initializer, \
                                             initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \
@@ -2120,11 +2199,13 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer,
     CASE_UNPACK(FLOAT8E5M2, onnxruntime::Float8E5M2, int32_data_size);
     CASE_UNPACK(FLOAT8E5M2FNUZ, onnxruntime::Float8E5M2FNUZ, int32_data_size);
 #endif
-    CASE_UNPACK_4BIT_TYPE(INT4, Int4x2, int32_data_size, CalcNumInt4Pairs);
-    CASE_UNPACK_4BIT_TYPE(UINT4, UInt4x2, int32_data_size, CalcNumInt4Pairs);
+    CASE_UNPACK_SUBBYTE_TYPE(INT4, Int4x2, int32_data_size, CalcNumInt4Pairs);
+    CASE_UNPACK_SUBBYTE_TYPE(UINT4, UInt4x2, int32_data_size, CalcNumInt4Pairs);
+    CASE_UNPACK_SUBBYTE_TYPE(INT2, Int2x4, int32_data_size, CalcNumInt2Quads);
+    CASE_UNPACK_SUBBYTE_TYPE(UINT2, UInt2x4, int32_data_size, CalcNumInt2Quads);

 #if !defined(DISABLE_FLOAT4_TYPES)
-    CASE_UNPACK_4BIT_TYPE(FLOAT4E2M1, Float4E2M1x2, int32_data_size, CalcNumFloat4Pairs);
+    CASE_UNPACK_SUBBYTE_TYPE(FLOAT4E2M1, Float4E2M1x2, int32_data_size, CalcNumFloat4Pairs);
 #endif

     default:
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index 4b4c483ba1202..5eed13ec1073c 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -227,6 +227,16 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt4x2>() {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
 }

+template <>
+constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Int2x4>() {
+  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2;
+}
+
+template <>
+constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt2x4>() {
+  return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2;
+}
+
 #if !defined(DISABLE_FLOAT4_TYPES)
 template <>
 constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float4E2M1x2>() {
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index 2a9a8127874ee..f5421d8540db8 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -30,6 +30,7 @@
 #include "core/common/float8.h"
 #include "core/common/float16.h"
 #include "core/framework/int4.h"
+#include "core/framework/int2.h"
 #include "core/framework/float4.h"
 #include "core/framework/tensor_shape.h"
 #include "core/providers/providers.h"
@@ -80,6 +81,8 @@ enum TensorProto_DataType : int {
   TensorProto_DataType_INT4 = 22,
   TensorProto_DataType_FLOAT4E2M1 = 23,
   TensorProto_DataType_FLOAT8E8M0 = 24,
+  TensorProto_DataType_UINT2 = 25,
+  TensorProto_DataType_INT2 = 26,
 };

 enum TensorProto_DataLocation : int {
@@ -410,6 +413,15 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt4x2>() {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
 }

+template <>
+constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Int2x4>() {
+  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2;
+}
+template <>
+constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt2x4>() {
+  return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2;
+}
+
 inline std::vector<std::unique_ptr<ComputeCapability>>
 CreateSupportedPartitions(const GraphViewer& graph_viewer,
                           const std::unordered_set<const Node*>& supported_nodes,
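On the provider-bridge side, an execution provider that partitions by element type can now recognize the 2-bit enum values defined above. A small hypothetical helper (the accepted-type set is illustrative):

```cpp
#include "core/providers/shared_library/provider_api.h"

// Hypothetical EP-side filter over the bridge's TensorProto_DataType values.
bool IsPackedSubByteType(int32_t elem_type) {
  namespace o = ONNX_NAMESPACE;
  return elem_type == o::TensorProto_DataType_INT4 || elem_type == o::TensorProto_DataType_UINT4 ||
         elem_type == o::TensorProto_DataType_INT2 || elem_type == o::TensorProto_DataType_UINT2;
}
```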
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index 5732984af29b4..ee00a06751d0a 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -172,6 +172,10 @@ template <>
 MLDataType DataTypeImpl::GetType<Int4x2>() { return Provider_GetHost()->DataTypeImpl__GetType_Int4x2(); }
 template <>
 MLDataType DataTypeImpl::GetType<UInt4x2>() { return Provider_GetHost()->DataTypeImpl__GetType_UInt4x2(); }
+template <>
+MLDataType DataTypeImpl::GetType<Int2x4>() { return Provider_GetHost()->DataTypeImpl__GetType_Int2x4(); }
+template <>
+MLDataType DataTypeImpl::GetType<UInt2x4>() { return Provider_GetHost()->DataTypeImpl__GetType_UInt2x4(); }

 #if !defined(DISABLE_FLOAT4_TYPES)
 template <>
@@ -222,6 +226,10 @@ template <>
 MLDataType DataTypeImpl::GetTensorType<Int4x2>() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Int4x2(); }
 template <>
 MLDataType DataTypeImpl::GetTensorType<UInt4x2>() { return Provider_GetHost()->DataTypeImpl__GetTensorType_UInt4x2(); }
+template <>
+MLDataType DataTypeImpl::GetTensorType<Int2x4>() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Int2x4(); }
+template <>
+MLDataType DataTypeImpl::GetTensorType<UInt2x4>() { return Provider_GetHost()->DataTypeImpl__GetTensorType_UInt2x4(); }

 #if !defined(DISABLE_FLOAT4_TYPES)
 template <>
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index 51bd2c467acec..aeaf05cf14591 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -775,6 +775,8 @@ struct ProviderHost {
 #endif
   virtual MLDataType DataTypeImpl__GetType_Int4x2() = 0;
   virtual MLDataType DataTypeImpl__GetType_UInt4x2() = 0;
+  virtual MLDataType DataTypeImpl__GetType_Int2x4() = 0;
+  virtual MLDataType DataTypeImpl__GetType_UInt2x4() = 0;

   virtual MLDataType DataTypeImpl__GetTensorTypeFromOnnxType(int) = 0;
   virtual MLDataType DataTypeImpl__GetTensorType_bool() = 0;
@@ -802,6 +804,8 @@ struct ProviderHost {

   virtual MLDataType DataTypeImpl__GetTensorType_Int4x2() = 0;
   virtual MLDataType DataTypeImpl__GetTensorType_UInt4x2() = 0;
+  virtual MLDataType DataTypeImpl__GetTensorType_Int2x4() = 0;
+  virtual MLDataType DataTypeImpl__GetTensorType_UInt2x4() = 0;

 #if !defined(DISABLE_SPARSE_TENSORS)
   virtual MLDataType DataTypeImpl__GetSparseTensorType_bool() = 0;
@@ -1260,6 +1264,8 @@ struct ProviderHost {
 #endif
   virtual Int4x2* Tensor__MutableData_Int4x2(Tensor* p) = 0;
   virtual UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) = 0;
+  virtual Int2x4* Tensor__MutableData_Int2x4(Tensor* p) = 0;
+  virtual UInt2x4* Tensor__MutableData_UInt2x4(Tensor* p) = 0;

   virtual const bool* Tensor__Data_bool(const Tensor* p) = 0;
   virtual const int8_t* Tensor__Data_int8(const Tensor* p) = 0;
@@ -1286,6 +1292,8 @@ struct ProviderHost {
 #endif
   virtual const Int4x2* Tensor__Data_Int4x2(const Tensor* p) = 0;
   virtual const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) = 0;
+  virtual const Int2x4* Tensor__Data_Int2x4(const Tensor* p) = 0;
+  virtual const UInt2x4* Tensor__Data_UInt2x4(const Tensor* p) = 0;

   virtual gsl::span<const int64_t> Tensor__DataAsSpan_int64(const Tensor* p) = 0;

@@ -1322,6 +1330,8 @@ struct ProviderHost {
 #endif
   virtual bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept = 0;
   virtual bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept = 0;
+  virtual bool Tensor__IsDataType_Int2x4(const Tensor* p) noexcept = 0;
+  virtual bool Tensor__IsDataType_UInt2x4(const Tensor* p) noexcept = 0;

   virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0;
   virtual void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) = 0;
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index 0ab7ee0aedd1a..041cf764e7ede 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -1524,6 +1524,10 @@ inline bool Tensor::IsDataType<Int4x2>() const { return g_host->Tensor__IsDataType_Int4x2(this); }
 template <>
 inline bool Tensor::IsDataType<UInt4x2>() const { return g_host->Tensor__IsDataType_UInt4x2(this); }
 template <>
+inline bool Tensor::IsDataType<Int2x4>() const { return g_host->Tensor__IsDataType_Int2x4(this); }
+template <>
+inline bool Tensor::IsDataType<UInt2x4>() const { return g_host->Tensor__IsDataType_UInt2x4(this); }
+template <>
 inline bool Tensor::IsDataType<int8_t>() const { return g_host->Tensor__IsDataType_int8(this); }
 template <>
 inline bool Tensor::IsDataType<uint8_t>() const { return g_host->Tensor__IsDataType_uint8(this); }
@@ -1571,6 +1575,10 @@ inline Int4x2* Tensor::MutableData<Int4x2>() { return g_host->Tensor__MutableData_Int4x2(this); }
 template <>
 inline UInt4x2* Tensor::MutableData<UInt4x2>() { return g_host->Tensor__MutableData_UInt4x2(this); }
 template <>
+inline Int2x4* Tensor::MutableData<Int2x4>() { return g_host->Tensor__MutableData_Int2x4(this); }
+template <>
+inline UInt2x4* Tensor::MutableData<UInt2x4>() { return g_host->Tensor__MutableData_UInt2x4(this); }
+template <>
 inline int8_t* Tensor::MutableData<int8_t>() { return g_host->Tensor__MutableData_int8(this); }
 template <>
 inline uint8_t* Tensor::MutableData<uint8_t>() { return g_host->Tensor__MutableData_uint8(this); }
@@ -1618,6 +1626,10 @@ inline const Int4x2* Tensor::Data<Int4x2>() const { return g_host->Tensor__Data_Int4x2(this); }
 template <>
 inline const UInt4x2* Tensor::Data<UInt4x2>() const { return g_host->Tensor__Data_UInt4x2(this); }
 template <>
+inline const Int2x4* Tensor::Data<Int2x4>() const { return g_host->Tensor__Data_Int2x4(this); }
+template <>
+inline const UInt2x4* Tensor::Data<UInt2x4>() const { return g_host->Tensor__Data_UInt2x4(this); }
+template <>
 inline const int8_t* Tensor::Data<int8_t>() const { return g_host->Tensor__Data_int8(this); }
 template <>
 inline const uint8_t* Tensor::Data<uint8_t>() const { return g_host->Tensor__Data_uint8(this); }
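These wrappers are what out-of-tree EP kernels would call. A hedged sketch of shared-provider code reading the first logical element of an int2 tensor (the helper and its error handling are illustrative, not part of the diff):

```cpp
#include "core/providers/shared_library/provider_api.h"

onnxruntime::common::Status ReadFirstInt2(const onnxruntime::Tensor& input, int8_t& out) {
  if (!input.IsDataType<onnxruntime::Int2x4>()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "expected int2 tensor");
  }
  const onnxruntime::Int2x4* packed = input.Data<onnxruntime::Int2x4>();
  auto [byte_idx, elem_idx] = onnxruntime::Int2x4::GetTensorElemIndices(0);
  out = packed[byte_idx].GetElem(elem_idx);
  return onnxruntime::common::Status::OK();
}
```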
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index a55ab38113a0f..5700a32cf5ca1 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -1002,6 +1002,8 @@ struct ProviderHostImpl : ProviderHost {
   MLDataType DataTypeImpl__GetType_Int4x2() override { return DataTypeImpl::GetType<Int4x2>(); }
   MLDataType DataTypeImpl__GetType_UInt4x2() override { return DataTypeImpl::GetType<UInt4x2>(); }
+  MLDataType DataTypeImpl__GetType_Int2x4() override { return DataTypeImpl::GetType<Int2x4>(); }
+  MLDataType DataTypeImpl__GetType_UInt2x4() override { return DataTypeImpl::GetType<UInt2x4>(); }
   MLDataType DataTypeImpl__GetTensorTypeFromOnnxType(int onnx_type) override { return DataTypeImpl::TensorTypeFromONNXEnum(onnx_type)->AsTensorType(); }
   MLDataType DataTypeImpl__GetTensorType_bool() override { return DataTypeImpl::GetTensorType<bool>(); }
@@ -1031,6 +1033,8 @@ struct ProviderHostImpl : ProviderHost {
   MLDataType DataTypeImpl__GetTensorType_Int4x2() override { return DataTypeImpl::GetTensorType<Int4x2>(); }
   MLDataType DataTypeImpl__GetTensorType_UInt4x2() override { return DataTypeImpl::GetTensorType<UInt4x2>(); }
+  MLDataType DataTypeImpl__GetTensorType_Int2x4() override { return DataTypeImpl::GetTensorType<Int2x4>(); }
+  MLDataType DataTypeImpl__GetTensorType_UInt2x4() override { return DataTypeImpl::GetTensorType<UInt2x4>(); }

 #if !defined(DISABLE_SPARSE_TENSORS)
   MLDataType DataTypeImpl__GetSparseTensorType_bool() override { return DataTypeImpl::GetSparseTensorType<bool>(); }
@@ -1680,6 +1684,8 @@ struct ProviderHostImpl : ProviderHost {
   Int4x2* Tensor__MutableData_Int4x2(Tensor* p) override { return p->MutableData<Int4x2>(); }
   UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) override { return p->MutableData<UInt4x2>(); }
+  Int2x4* Tensor__MutableData_Int2x4(Tensor* p) override { return p->MutableData<Int2x4>(); }
+  UInt2x4* Tensor__MutableData_UInt2x4(Tensor* p) override { return p->MutableData<UInt2x4>(); }

   const bool* Tensor__Data_bool(const Tensor* p) override { return p->Data<bool>(); }
   const int8_t* Tensor__Data_int8(const Tensor* p) override { return p->Data<int8_t>(); }
@@ -1708,6 +1714,8 @@ struct ProviderHostImpl : ProviderHost {
   const Int4x2* Tensor__Data_Int4x2(const Tensor* p) override { return p->Data<Int4x2>(); }
   const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) override { return p->Data<UInt4x2>(); }
+  const Int2x4* Tensor__Data_Int2x4(const Tensor* p) override { return p->Data<Int2x4>(); }
+  const UInt2x4* Tensor__Data_UInt2x4(const Tensor* p) override { return p->Data<UInt2x4>(); }

   gsl::span<const int64_t> Tensor__DataAsSpan_int64(const Tensor* p) override { return p->DataAsSpan<int64_t>(); }

@@ -1744,6 +1752,8 @@ struct ProviderHostImpl : ProviderHost {
   bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept override { return p->IsDataType<Int4x2>(); }
   bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept override { return p->IsDataType<UInt4x2>(); }
+  bool Tensor__IsDataType_Int2x4(const Tensor* p) noexcept override { return p->IsDataType<Int2x4>(); }
+  bool Tensor__IsDataType_UInt2x4(const Tensor* p) noexcept override { return p->IsDataType<UInt2x4>(); }

   const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); }
   void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) override { return p->Reshape(new_shape); }
diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h
index 7c95b4d10a872..1fd3f8eb76e5f 100644
--- a/onnxruntime/test/common/random_generator.h
+++ b/onnxruntime/test/common/random_generator.h
@@ -15,6 +15,7 @@
 #include "core/common/type_utils.h"
 #include "core/common/float16.h"
 #include "core/framework/int4.h"
+#include "core/framework/int2.h"
 #include "test/util/include/test_random_seed.h"

 namespace onnxruntime {
@@ -126,6 +127,22 @@ class RandomValueGenerator {
     return data;
   }

+  template <typename TInt2>
+  typename std::enable_if<
+      std::is_same_v<TInt2, Int2x4> || std::is_same_v<TInt2, UInt2x4>,
+      std::vector<TInt2>>::type
+  Uniform(gsl::span<const int64_t> dims, TInt2 min, TInt2 max) {
+    using UnpackedType = typename TInt2::UnpackedType;
+    std::vector<UnpackedType> data_int8 = Uniform<UnpackedType>(dims, min.GetElem(0), max.GetElem(0));
+    std::vector<TInt2> data(TInt2::CalcNumInt2Quads(data_int8.size()));
+    for (size_t i = 0; i < data_int8.size(); i++) {
+      size_t r = i >> 2;   // i / 4
+      size_t c = i & 0x3;  // i % 4
+      data[r].SetElem(c, data_int8[i]);
+    }
+    return data;
+  }
+
   // Gaussian distribution for float
   template <typename TFloat>
   typename std::enable_if<
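A hedged usage sketch for the new `Uniform` overload (dims and bounds are illustrative). Note that the min/max bounds travel in element 0 of the `Int2x4` arguments, per the implementation above:

```cpp
#include <vector>
#include "test/common/random_generator.h"

void FillRandomInt2(std::vector<onnxruntime::Int2x4>& out) {
  onnxruntime::test::RandomValueGenerator random{};
  const std::vector<int64_t> dims{2, 6};  // 12 logical elements -> CalcNumInt2Quads(12) == 3 packed bytes
  out = random.Uniform<onnxruntime::Int2x4>(dims, onnxruntime::Int2x4(-2, 0, 0, 0),
                                            onnxruntime::Int2x4(1, 0, 0, 0));
}
```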
diff --git a/onnxruntime/test/framework/int2_test.cc b/onnxruntime/test/framework/int2_test.cc
new file mode 100644
index 0000000000000..cc7c4c1b54f97
--- /dev/null
+++ b/onnxruntime/test/framework/int2_test.cc
@@ -0,0 +1,322 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <array>
+#include <cstddef>
+#include <filesystem>
+#include <vector>
+
+#include "core/framework/int2.h"
+#include "core/framework/data_types.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/platform/env.h"
+#include "test/test_environment.h"
+#include "gtest/gtest.h"
+
+namespace onnxruntime {
+namespace test {
+
+// ==============================================
+// Int2x4 Tests (signed 2-bit integer, 4 per byte)
+// ==============================================
+
+TEST(Int2_Tests, Int2x4_DefaultConstructor) {
+  Int2x4 int2;
+  EXPECT_EQ(static_cast<uint8_t>(int2.ToBits()), 0);
+}
+
+TEST(Int2_Tests, Int2x4_BitsConstructor) {
+  // Pack 4 signed 2-bit values: val0=1, val1=-1 (0b11), val2=-2 (0b10), val3=0
+  // Binary: 0b00'10'11'01 = 0x2D
+  Int2x4 int2(std::byte{0x2D});
+  EXPECT_EQ(int2.GetElem(0), 1);
+  EXPECT_EQ(int2.GetElem(1), -1);  // 0b11 sign-extended is -1
+  EXPECT_EQ(int2.GetElem(2), -2);  // 0b10 sign-extended is -2
+  EXPECT_EQ(int2.GetElem(3), 0);
+}
+
+TEST(Int2_Tests, Int2x4_FourValueConstructor) {
+  Int2x4 int2(1, -1, -2, 0);
+  EXPECT_EQ(int2.GetElem(0), 1);
+  EXPECT_EQ(int2.GetElem(1), -1);
+  EXPECT_EQ(int2.GetElem(2), -2);
+  EXPECT_EQ(int2.GetElem(3), 0);
+}
+
+TEST(Int2_Tests, Int2x4_GetSetElem) {
+  Int2x4 int2;
+
+  // Set and get each element
+  int2.SetElem(0, 1);
+  int2.SetElem(1, -1);
+  int2.SetElem(2, -2);
+  int2.SetElem(3, 0);
+
+  EXPECT_EQ(int2.GetElem(0), 1);
+  EXPECT_EQ(int2.GetElem(1), -1);
+  EXPECT_EQ(int2.GetElem(2), -2);
+  EXPECT_EQ(int2.GetElem(3), 0);
+}
+
+TEST(Int2_Tests, Int2x4_ValueRange) {
+  // Verify min/max values
+  EXPECT_EQ(Int2x4::min_val, -2);
+  EXPECT_EQ(Int2x4::max_val, 1);
+
+  // Test all valid signed 2-bit values: -2, -1, 0, 1
+  Int2x4 int2(-2, -1, 0, 1);
+  EXPECT_EQ(int2.GetElem(0), -2);
+  EXPECT_EQ(int2.GetElem(1), -1);
+  EXPECT_EQ(int2.GetElem(2), 0);
+  EXPECT_EQ(int2.GetElem(3), 1);
+}
+
+TEST(Int2_Tests, Int2x4_CalcNumInt2Quads) {
+  // 0 elements -> 0 bytes
+  EXPECT_EQ(Int2x4::CalcNumInt2Quads(0), 0u);
+  // 1 element -> 1 byte
+  EXPECT_EQ(Int2x4::CalcNumInt2Quads(1), 1u);
+  // 4 elements -> 1 byte
+  EXPECT_EQ(Int2x4::CalcNumInt2Quads(4), 1u);
+  // 5 elements -> 2 bytes
+  EXPECT_EQ(Int2x4::CalcNumInt2Quads(5), 2u);
+  // 8 elements -> 2 bytes
+  EXPECT_EQ(Int2x4::CalcNumInt2Quads(8), 2u);
+}
+
+TEST(Int2_Tests, Int2x4_PackUnpack) {
+  std::vector<int8_t> src_values = {1, -1, -2, 0, 1, -1, -2, 0};
+  std::vector<Int2x4> packed(Int2x4::CalcNumInt2Quads(src_values.size()));
+
+  // Pack
+  bool pack_result = Int2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values));
+  EXPECT_TRUE(pack_result);
+
+  // Unpack
+  std::vector<int8_t> unpacked(src_values.size());
+  bool unpack_result = Int2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(packed));
+  EXPECT_TRUE(unpack_result);
+
+  // Verify
+  for (size_t i = 0; i < src_values.size(); i++) {
+    EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i;
+  }
+}
+
+TEST(Int2_Tests, Int2x4_PackUnpackOddElements) {
+  // Test with non-multiple-of-4 element count
+  std::vector<int8_t> src_values = {1, -1, -2};
+  std::vector<Int2x4> packed(Int2x4::CalcNumInt2Quads(src_values.size()));
+
+  // Pack
+  bool pack_result = Int2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values));
+  EXPECT_TRUE(pack_result);
+
+  // Unpack
+  std::vector<int8_t> unpacked(src_values.size());
+  bool unpack_result = Int2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(packed));
+  EXPECT_TRUE(unpack_result);
+
+  // Verify
+  for (size_t i = 0; i < src_values.size(); i++) {
+    EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i;
+  }
+}
+
+// ==============================================
+// UInt2x4 Tests (unsigned 2-bit integer, 4 per byte)
+// ==============================================
+
+TEST(Int2_Tests, UInt2x4_DefaultConstructor) {
+  UInt2x4 uint2;
+  EXPECT_EQ(static_cast<uint8_t>(uint2.ToBits()), 0);
+}
+
+TEST(Int2_Tests, UInt2x4_BitsConstructor) {
+  // Pack 4 unsigned 2-bit values: val0=0, val1=1, val2=2, val3=3
+  // Binary: 0b11'10'01'00 = 0xE4
+  UInt2x4 uint2(std::byte{0xE4});
+  EXPECT_EQ(uint2.GetElem(0), 0);
+  EXPECT_EQ(uint2.GetElem(1), 1);
+  EXPECT_EQ(uint2.GetElem(2), 2);
+  EXPECT_EQ(uint2.GetElem(3), 3);
+}
+
+TEST(Int2_Tests, UInt2x4_FourValueConstructor) {
+  UInt2x4 uint2(0, 1, 2, 3);
+  EXPECT_EQ(uint2.GetElem(0), 0);
+  EXPECT_EQ(uint2.GetElem(1), 1);
+  EXPECT_EQ(uint2.GetElem(2), 2);
+  EXPECT_EQ(uint2.GetElem(3), 3);
+}
+
+TEST(Int2_Tests, UInt2x4_GetSetElem) {
+  UInt2x4 uint2;
+
+  // Set and get each element
+  uint2.SetElem(0, 0);
+  uint2.SetElem(1, 1);
+  uint2.SetElem(2, 2);
+  uint2.SetElem(3, 3);
+
+  EXPECT_EQ(uint2.GetElem(0), 0);
+  EXPECT_EQ(uint2.GetElem(1), 1);
+  EXPECT_EQ(uint2.GetElem(2), 2);
+  EXPECT_EQ(uint2.GetElem(3), 3);
+}
+
+TEST(Int2_Tests, UInt2x4_ValueRange) {
+  // Verify min/max values
+  EXPECT_EQ(UInt2x4::min_val, 0);
+  EXPECT_EQ(UInt2x4::max_val, 3);
+
+  // Test all valid unsigned 2-bit values: 0, 1, 2, 3
+  UInt2x4 uint2(0, 1, 2, 3);
+  EXPECT_EQ(uint2.GetElem(0), 0);
+  EXPECT_EQ(uint2.GetElem(1), 1);
+  EXPECT_EQ(uint2.GetElem(2), 2);
+  EXPECT_EQ(uint2.GetElem(3), 3);
+}
+
+TEST(Int2_Tests, UInt2x4_CalcNumInt2Quads) {
+  // Same as Int2x4
+  EXPECT_EQ(UInt2x4::CalcNumInt2Quads(0), 0u);
+  EXPECT_EQ(UInt2x4::CalcNumInt2Quads(1), 1u);
+  EXPECT_EQ(UInt2x4::CalcNumInt2Quads(4), 1u);
+  EXPECT_EQ(UInt2x4::CalcNumInt2Quads(5), 2u);
+}
+
+TEST(Int2_Tests, UInt2x4_PackUnpack) {
+  std::vector<uint8_t> src_values = {0, 1, 2, 3, 3, 2, 1, 0};
+  std::vector<UInt2x4> packed(UInt2x4::CalcNumInt2Quads(src_values.size()));
+
+  // Pack
+  bool pack_result = UInt2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values));
+  EXPECT_TRUE(pack_result);
+
+  // Unpack
+  std::vector<uint8_t> unpacked(src_values.size());
+  bool unpack_result = UInt2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(packed));
+  EXPECT_TRUE(unpack_result);
+
+  // Verify
+  for (size_t i = 0; i < src_values.size(); i++) {
+    EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i;
+  }
+}
+
+TEST(Int2_Tests, UInt2x4_PackUnpackOddElements) {
+  // Test with non-multiple-of-4 element count
+  std::vector<uint8_t> src_values = {3, 2, 1};
+  std::vector<UInt2x4> packed(UInt2x4::CalcNumInt2Quads(src_values.size()));
+
+  // Pack
+  bool pack_result = UInt2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values));
+  EXPECT_TRUE(pack_result);
+
+  // Unpack
+  std::vector<uint8_t> unpacked(src_values.size());
+  bool unpack_result = UInt2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(packed));
+  EXPECT_TRUE(unpack_result);
+
+  // Verify
+  for (size_t i = 0; i < src_values.size(); i++) {
+    EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i;
+  }
+}
+
+// ==============================================
+// Additional edge case tests
+// ==============================================
+
+TEST(Int2_Tests, Int2x4_AllSameValue) {
+  // All values are -2 (minimum signed value)
+  Int2x4 int2_min(-2, -2, -2, -2);
+  for (size_t i = 0; i < 4; i++) {
+    EXPECT_EQ(int2_min.GetElem(i), -2);
+  }
+
+  // All values are 1 (maximum signed value)
+  Int2x4 int2_max(1, 1, 1, 1);
+  for (size_t i = 0; i < 4; i++) {
+    EXPECT_EQ(int2_max.GetElem(i), 1);
+  }
+}
+
+TEST(Int2_Tests, UInt2x4_AllSameValue) {
+  // All values are 0 (minimum unsigned value)
+  UInt2x4 uint2_min(0, 0, 0, 0);
+  for (size_t i = 0; i < 4; i++) {
+    EXPECT_EQ(uint2_min.GetElem(i), 0);
+  }
+
+  // All values are 3 (maximum unsigned value)
+  UInt2x4 uint2_max(3, 3, 3, 3);
+  for (size_t i = 0; i < 4; i++) {
+    EXPECT_EQ(uint2_max.GetElem(i), 3);
+  }
+}
+
+TEST(Int2_Tests, Int2x4_BitManipulation) {
+  // Test that ToBits returns correct packed representation
+  Int2x4 int2(0, 1, -1, -2);  // 0b00, 0b01, 0b11, 0b10
+  // Expected: 0b10'11'01'00 = 0xB4
+  EXPECT_EQ(static_cast<uint8_t>(int2.ToBits()), 0xB4);
+}
+
+// ==============================================
+// TypeProto / TypeFromProto smoke checks
+// ==============================================
+
+TEST(Int2_Tests, TensorTypeFromONNXEnum_Int2UInt2) {
+  auto* int2_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_INT2);
+  auto* uint2_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT2);
+
+  ASSERT_NE(int2_type, nullptr);
+  ASSERT_NE(uint2_type, nullptr);
+  EXPECT_EQ(int2_type->GetElementType(), DataTypeImpl::GetType<Int2x4>());
+  EXPECT_EQ(uint2_type->GetElementType(), DataTypeImpl::GetType<UInt2x4>());
+}
+
+TEST(Int2_Tests, TypeFromProto_TensorProto_Int2) {
+  ONNX_NAMESPACE::TypeProto tp;
+  tp.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT2);
+  auto mltype = DataTypeImpl::TypeFromProto(tp);
+  ASSERT_NE(mltype, nullptr);
+  const auto* tensor_type = mltype->AsTensorType();
+  ASSERT_NE(tensor_type, nullptr);
+  EXPECT_EQ(tensor_type->GetElementType(), DataTypeImpl::GetType<Int2x4>());
+}
+
+TEST(Int2_Tests, TensorProtoRoundTrip_Int2) {
+  // Build a tiny TensorProto with raw_data containing 2 bytes (8 int2 elements packed)
+  ONNX_NAMESPACE::TensorProto proto;
+  proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT2);
+  proto.add_dims(8);
+  // pack values [1, -1, -2, 0, 1, -1, -2, 0]
+  std::array<int8_t, 8> values = {1, -1, -2, 0, 1, -1, -2, 0};
+  std::vector<Int2x4> packed(Int2x4::CalcNumInt2Quads(values.size()));
+  ASSERT_TRUE(Int2x4::Pack(gsl::make_span(packed), gsl::make_span(values)));
+  proto.set_raw_data(packed.data(), packed.size() * sizeof(Int2x4));
+
+  Tensor result;
+  // Use CreateTensorFromTensorProto which pre-allocates the tensor with proper shape
+  ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(onnxruntime::Env::Default(), std::filesystem::path{}, proto, result));
+  ASSERT_TRUE(result.IsDataType<Int2x4>());
+  const auto* data = result.Data<Int2x4>();
+  std::vector<int8_t> unpacked(values.size());
+  ASSERT_TRUE(Int2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(data, packed.size())));
+  for (size_t i = 0; i < values.size(); ++i) {
+    EXPECT_EQ(unpacked[i], values[i]) << "Mismatch at index " << i;
+  }
+}
+
+TEST(Int2_Tests, UInt2x4_BitManipulation) {
+  // Test that ToBits returns correct packed representation
+  UInt2x4 uint2(3, 2, 1, 0);  // 0b11, 0b10, 0b01, 0b00
+  // Expected: 0b00'01'10'11 = 0x1B
+  EXPECT_EQ(static_cast<uint8_t>(uint2.ToBits()), 0x1B);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
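The Int2x4 and UInt2x4 suites above are mirror images of each other; a gtest typed test could fold them together. A sketch of that alternative (not part of the diff, shown only as a design option):

```cpp
#include <vector>
#include <gsl/gsl>
#include "core/framework/int2.h"
#include "gtest/gtest.h"

template <typename T>
class Int2PackRoundTrip : public ::testing::Test {};

using PackedTypes = ::testing::Types<onnxruntime::Int2x4, onnxruntime::UInt2x4>;
TYPED_TEST_SUITE(Int2PackRoundTrip, PackedTypes);

TYPED_TEST(Int2PackRoundTrip, RoundTripsMinMax) {
  using UnpackedType = typename TypeParam::UnpackedType;
  std::vector<UnpackedType> src = {TypeParam::min_val, TypeParam::max_val, 0};  // odd count on purpose
  std::vector<TypeParam> packed(TypeParam::CalcNumInt2Quads(src.size()));
  ASSERT_TRUE(TypeParam::Pack(gsl::make_span(packed), gsl::make_span(src)));
  std::vector<UnpackedType> dst(src.size());
  ASSERT_TRUE(TypeParam::Unpack(gsl::make_span(dst), gsl::make_span(packed)));
  EXPECT_EQ(dst, src);
}
```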
"core/framework/endian_utils.h" #include "core/graph/onnx_protobuf.h" @@ -122,6 +123,48 @@ void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); } +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ Int2x4* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_quads = (expected_num_elements + 3) / 4; + + if (num_packed_quads != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed int2 quads", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_quads); + gsl::span dst_span = gsl::make_span(p_data, num_packed_quads); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_quads); +} + +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ UInt2x4* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_quads = (expected_num_elements + 3) / 4; + + if (num_packed_quads != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed uint2 quads", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_quads); + gsl::span dst_span = gsl::make_span(p_data, num_packed_quads); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_quads); +} + #if !defined(DISABLE_FLOAT4_TYPES) template <> void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, @@ -373,6 +416,41 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) DEFINE_UNPACK_TENSOR_INT4(Int4x2, TensorProto_DataType_INT4) DEFINE_UNPACK_TENSOR_INT4(UInt4x2, TensorProto_DataType_UINT4) +#define DEFINE_UNPACK_TENSOR_INT2(INT2_TYPE, ONNX_TYPE) \ + template <> \ + void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ + /*out*/ INT2_TYPE* p_data, size_t expected_num_elems) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? 
raw_data_len : tensor.int32_data_size(); \ + if (size == 0) { \ + return; \ + } \ + ORT_CXX_API_THROW("p_data == nullptr, but size != 0", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + if (ONNX_NAMESPACE::ONNX_TYPE != tensor.data_type()) { \ + ORT_CXX_API_THROW("TensorProto data type is not INT2", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + \ + size_t expected_int2_quads = (expected_num_elems + 3) / 4; \ + \ + if (raw_data != nullptr) { \ + UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \ + return; \ + } \ + \ + if (static_cast(tensor.int32_data_size()) != expected_int2_quads) { \ + ORT_CXX_API_THROW("UnpackTensor: the pre-allocated size does not match the size in proto", \ + OrtErrorCode::ORT_FAIL); \ + } \ + \ + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ + p_data[i] = INT2_TYPE(static_cast(tensor.int32_data()[i])); \ + } \ + } + +DEFINE_UNPACK_TENSOR_INT2(Int2x4, TensorProto_DataType_INT2) +DEFINE_UNPACK_TENSOR_INT2(UInt2x4, TensorProto_DataType_UINT2) + #if !defined(DISABLE_FLOAT4_TYPES) template <> void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, @@ -426,6 +504,13 @@ void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_dat } \ break; +#define CASE_PROTO_TRACE_INT2(X) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!CalcMemSizeForArrayWithAlignment((size + 3) / 4, 1, alignment, out)) { \ + ORT_CXX_API_THROW("Invalid TensorProto", OrtErrorCode::ORT_FAIL); \ + } \ + break; + #if !defined(DISABLE_FLOAT4_TYPES) #define CASE_PROTO_TRACE_FLOAT4(X) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ @@ -474,6 +559,8 @@ Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_p CASE_PROTO_TRACE(STRING, std::string); CASE_PROTO_TRACE_INT4(UINT4); CASE_PROTO_TRACE_INT4(INT4); + CASE_PROTO_TRACE_INT2(UINT2); + CASE_PROTO_TRACE_INT2(INT2); default: return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } @@ -630,6 +717,8 @@ Status TensorProtoToMLValue(const onnx::TensorProto& tensor_proto, const MemBuff #endif CASE_PROTO(INT4, Int4x2); CASE_PROTO(UINT4, UInt4x2); + CASE_PROTO(INT2, Int2x4); + CASE_PROTO(UINT2, UInt2x4); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: if (preallocated != nullptr) { OrtStatus* status = OrtInitializeBufferForTensor(preallocated, preallocated_size, ele_type); diff --git a/onnxruntime/test/unittest_util/base_tester.h b/onnxruntime/test/unittest_util/base_tester.h index 58b67a0d67d3c..79a74ef1651c5 100644 --- a/onnxruntime/test/unittest_util/base_tester.h +++ b/onnxruntime/test/unittest_util/base_tester.h @@ -700,6 +700,10 @@ class BaseTester { const int64_t expected_values_count = T::CalcNumInt4Pairs(shape.Size()); ORT_ENFORCE(expected_values_count == values_count, values_count, " input values doesn't match tensor size of ", expected_values_count); + } else if constexpr (std::is_same_v || std::is_same_v) { + const int64_t expected_values_count = T::CalcNumInt2Quads(shape.Size()); + ORT_ENFORCE(expected_values_count == values_count, values_count, + " input values doesn't match tensor size of ", expected_values_count); } #if !defined(DISABLE_FLOAT4_TYPES) else if constexpr (std::is_same_v) { diff --git a/onnxruntime/test/unittest_util/checkers.cc b/onnxruntime/test/unittest_util/checkers.cc index 7b2a5a4a4ff2f..88a6241bf7ee3 100644 --- a/onnxruntime/test/unittest_util/checkers.cc +++ b/onnxruntime/test/unittest_util/checkers.cc @@ -9,6 +9,7 @@ 
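Concretely, the BaseTester check above means a test supplies packed values in quad units. A hedged fragment (assuming `AddInput` follows the same packed-values convention the Int4x2 path uses; the input name and values are illustrative):

```cpp
#include "test/unittest_util/base_tester.h"

void AddPackedInt2Input(onnxruntime::test::BaseTester& tester) {
  using onnxruntime::Int2x4;
  // shape {5} -> 5 logical elements -> CalcNumInt2Quads(5) == 2 packed values expected
  tester.AddInput<Int2x4>("X", {5}, {Int2x4(1, -1, -2, 0), Int2x4(1, 0, 0, 0)});
}
```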
#include "core/graph/constants.h" #include "core/framework/TensorSeq.h" #include "core/framework/int4.h" +#include "core/framework/int2.h" #include "core/framework/float4.h" #include "test/unittest_util/framework_test_utils.h" #include "test/unittest_util/conversion.h" @@ -259,6 +260,44 @@ struct TensorCheck { } }; +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + const Int2x4* cur_expected; + const Int2x4* cur_actual; + const auto size = narrow(actual.Shape().Size()); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < size; ++i) { + size_t r = i >> 2; + size_t c = i & 0x3; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } + } +}; + +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + const UInt2x4* cur_expected; + const UInt2x4* cur_actual; + const auto size = narrow(actual.Shape().Size()); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < size; ++i) { + size_t r = i >> 2; + size_t c = i & 0x3; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } + } +}; + template <> struct TensorCheck { void operator()(const Tensor& expected, @@ -536,7 +575,7 @@ void Check(std::string_view name, const OrtValue& expected, const Tensor utils::MLTypeCallDispatcher IsResultExactlyMatch(const Tenso return std::make_pair(COMPARE_RESULT::SUCCESS, ""); } +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const Int2x4* expected_output = expected_value.Data(); + const Int2x4* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 2; + size_t c = di & 0x3; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << static_cast(expected_output[r].GetElem(c)) + << ", got " << static_cast(real_output[r].GetElem(c)); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const UInt2x4* expected_output = expected_value.Data(); + const UInt2x4* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 2; + size_t c = di & 0x3; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << static_cast(expected_output[r].GetElem(c)) + << ", got " << static_cast(real_output[r].GetElem(c)); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + std::pair CompareFloat16Result(const Tensor& outvalue, const Tensor& expected_value, double per_sample_tolerance, double relative_per_sample_tolerance, @@ -356,6 +396,10 @@ std::pair CompareTwoTensors(const Tensor& outvalue, return IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return IsResultExactlyMatch(outvalue, expected_tensor); + } else if 
(outvalue.IsDataType()) { + return IsResultExactlyMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return CompareFloat16Result(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, post_processing);
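To tie the pieces together outside ORT, here is a standalone end-to-end sketch: pack, "transport" as raw bytes (as a TensorProto's raw_data would be), then compare element-wise using the same r = i/4, c = i%4 indexing every checker above uses. Pure illustration with no ORT dependencies:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Recover signed element i from a packed byte (shift down, then sign-extend the 2-bit field).
int8_t GetElem(uint8_t byte, size_t i) {
  return static_cast<int8_t>(static_cast<int32_t>(((byte >> (2 * i)) & 0x3u) << 30) >> 30);
}

int main() {
  const std::vector<int8_t> values = {1, -1, -2, 0, 1};     // 5 logical elements
  std::vector<uint8_t> packed((values.size() + 3) / 4, 0);  // 2 bytes, tail padded with zeros
  for (size_t i = 0; i < values.size(); ++i) {
    packed[i >> 2] |= static_cast<uint8_t>((values[i] & 0x3) << (2 * (i & 0x3)));
  }

  std::vector<uint8_t> raw(packed.size());  // raw_data stand-in
  std::memcpy(raw.data(), packed.data(), packed.size());

  for (size_t i = 0; i < values.size(); ++i) {  // element-wise check, byte r = i/4, slot c = i%4
    assert(GetElem(raw[i >> 2], i & 0x3) == values[i]);
  }
  return 0;
}
```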