From 40da67918e5a2a599ba7aaa79a0397759f57ad3a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:36:59 -0700 Subject: [PATCH 01/72] Update include/framework/ with int4 --- .../onnxruntime/core/framework/data_types.h | 7 +- .../core/framework/data_types_internal.h | 26 +++- include/onnxruntime/core/framework/int4.h | 123 ++++++++++++++++++ include/onnxruntime/core/framework/tensor.h | 13 +- .../framework/to_tensor_proto_element_type.h | 9 ++ 5 files changed, 173 insertions(+), 5 deletions(-) create mode 100644 include/onnxruntime/core/framework/int4.h diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index f3942128077de..dad3f4769019e 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -15,6 +15,7 @@ #include "core/framework/endian.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" #include "core/graph/onnx_protobuf.h" #include "core/framework/to_tensor_proto_element_type.h" @@ -280,7 +281,8 @@ struct IsAnyOf { template struct IsTensorContainedType : public IsAnyOf struct IsSparseTensorContainedType : public IsAnyOf(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -153,6 +159,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -203,7 +215,13 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ function(__VA_ARGS__); \ break; \ - default: \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4 \ + function(__VA_ARGS__); \ + break; \ + default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -251,6 +269,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h new file mode 100644 index 0000000000000..26be0c02fae81 --- /dev/null +++ b/include/onnxruntime/core/framework/int4.h @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
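// Layout sketch for the packed types defined in this header: each Int4x2 / UInt4x2 element
// stores two 4-bit values in one byte, val_0 in the low nibble and val_1 in the high nibble,
// so element i of a 4-bit tensor maps to byte i / 2, nibble i % 2. Packing an odd number of
// elements zero-fills the final high nibble. For example, the signed pair (lo = -2, hi = 7)
// packs to the byte 0x7E, since -2 is 0xE in 4-bit two's complement.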
+ +#pragma once + +#include +#include "endian.h" +#include "core/common/common.h" +#include "core/common/gsl.h" + +namespace onnxruntime { +struct Int4x2 { + int8_t val_0 : 4; + int8_t val_1 : 4; + + Int4x2() : val_0{0}, val_1{0} {} + Int4x2(uint8_t bits) { + val_0 = static_cast(bits & 0xF); + val_1 = static_cast((bits >> 4) & 0xF); + } + Int4x2(int8_t lo, int8_t hi) : val_0{lo}, val_1{hi} {} + + inline int8_t operator[](size_t index) const { + assert(index <= 1); + return index == 0 ? val_0 : val_1; + } + + inline uint8_t ToBits() const { + return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); + } + + static bool Unpack(gsl::span dst, gsl::span src) { + if (((dst.size() + 1) / 2) != src.size()) { + return false; + } + + for (size_t i = 0; i < dst.size(); i++) { + size_t r = i >> 1; // i / 2; + size_t c = i & 0x1; // i % 2; + dst[i] = src[r][c]; + } + + return true; + } + + static bool Pack(gsl::span dst, gsl::span src) { + if (((src.size() + 1) / 2) != dst.size()) { + return false; + } + + size_t src_i = 0; + size_t dst_i = 0; + + for (; src_i < src.size() - 1; src_i += 2) { + dst[dst_i++] = Int4x2(src[src_i], src[src_i + 1]); + } + + if (src_i < src.size()) { + dst[dst_i] = Int4x2(src[src_i], 0); + } + + return true; + } +}; + +static_assert(sizeof(Int4x2) == sizeof(int8_t)); + +struct UInt4x2 { + uint8_t val_0 : 4; + uint8_t val_1 : 4; + + UInt4x2() : val_0{0}, val_1{0} {} + UInt4x2(uint8_t bits) { + val_0 = bits & 0xF; + val_1 = (bits >> 4) & 0xF; + } + UInt4x2(uint8_t lo, uint8_t hi) : val_0{lo}, val_1{hi} {} + + inline uint8_t operator[](size_t index) const { + assert(index <= 1); + return index == 0 ? val_0 : val_1; + } + + inline uint8_t ToBits() const { + return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); + } + + static bool Unpack(gsl::span dst, gsl::span src) { + if (((dst.size() + 1) / 2) != src.size()) { + return false; + } + + for (size_t i = 0; i < dst.size(); i++) { + size_t r = i >> 1; // i / 2; + size_t c = i & 0x1; // i % 2; + dst[i] = src[r][c]; + } + + return true; + } + + static bool Pack(gsl::span dst, gsl::span src) { + if (((src.size() + 1) / 2) != dst.size()) { + return false; + } + + size_t src_i = 0; + size_t dst_i = 0; + + for (; src_i < src.size() - 1; src_i += 2) { + dst[dst_i++] = UInt4x2(src[src_i], src[src_i + 1]); + } + + if (src_i < src.size()) { + dst[dst_i] = UInt4x2(src[src_i], 0); + } + + return true; + } +}; + +static_assert(sizeof(UInt4x2) == sizeof(uint8_t)); +} // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index a867ab6066485..3c3933024636e 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -200,7 +200,12 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast(shape_.Size())); + int64_t num_elems = shape_.Size(); + if (utils::IsPrimitiveDataType(dtype_) || utils::IsPrimitiveDataType(dtype_)) { + num_elems = (num_elems + 1) / 2; + } + + return gsl::make_span(data, static_cast(num_elems)); } template @@ -217,7 +222,11 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. 
", "T ", "!=", dtype_); const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast::size_type>(shape_.Size())); + int64_t num_elems = shape_.Size(); + if (utils::IsPrimitiveDataType(dtype_) || utils::IsPrimitiveDataType(dtype_)) { + num_elems = (num_elems + 1) / 2; + } + return gsl::make_span(data, static_cast::size_type>(num_elems)); } void* MutableDataRaw(MLDataType type) { diff --git a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h index 21253eb4a6e83..e9e28e4864a67 100644 --- a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h +++ b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h @@ -12,6 +12,7 @@ #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" namespace onnxruntime { namespace utils { @@ -97,6 +98,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT4; +} +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT4; +} } // namespace utils } // namespace onnxruntime From e3e8a6bd94c10570b0a2340abe4691acd4272b37 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:38:11 -0700 Subject: [PATCH 02/72] Update onnxruntime_c_api.h with int4 type --- include/onnxruntime/core/session/onnxruntime_c_api.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e7b8f14871122..d2b4c0c0d7ef6 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -196,7 +196,10 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of uint4 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of int4 values (size == 1 byte) } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof From 5e01e0f66e4e811d5c3364ded2048a4a4331a265 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:39:29 -0700 Subject: [PATCH 03/72] Update cpu_contrib_kernels.cc with int4 Q/DQ --- onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index f9d9b13f0fedc..f0e39779a0532 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -63,6 +63,8 @@ class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear); @@ -204,6 +206,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From 44e0e0234ab0109c1c8d5a84fc8d781574612693 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:41:06 -0700 Subject: [PATCH 04/72] Update framework/data_types.cc with int4 types --- onnxruntime/core/framework/data_types.cc | 48 ++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 6c4aec417a033..4558ed0db6f62 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -639,6 +639,8 @@ ORT_REGISTER_TENSOR_TYPE(Float8E4M3FNUZ); ORT_REGISTER_TENSOR_TYPE(Float8E5M2); ORT_REGISTER_TENSOR_TYPE(Float8E5M2FNUZ); #endif +ORT_REGISTER_TENSOR_TYPE(Int4x2); +ORT_REGISTER_TENSOR_TYPE(UInt4x2); #if !defined(DISABLE_SPARSE_TENSORS) ORT_REGISTER_SPARSE_TENSOR_TYPE(int32_t); @@ -663,6 +665,8 @@ ORT_REGISTER_SPARSE_TENSOR_TYPE(Float8E5M2); ORT_REGISTER_SPARSE_TENSOR_TYPE(Float8E5M2FNUZ); #endif +ORT_REGISTER_SPARSE_TENSOR_TYPE(Int4x2); +ORT_REGISTER_SPARSE_TENSOR_TYPE(UInt4x2); #endif #if !defined(DISABLE_ML_OPS) @@ -700,6 +704,9 @@ ORT_REGISTER_SEQ_TENSOR_TYPE(Float8E5M2FNUZ); #endif +ORT_REGISTER_SEQ_TENSOR_TYPE(Int4x2); +ORT_REGISTER_SEQ_TENSOR_TYPE(UInt4x2); + #if !defined(DISABLE_ML_OPS) ORT_REGISTER_SEQ(VectorMapStringToFloat); ORT_REGISTER_SEQ(VectorMapInt64ToFloat); @@ -725,7 +732,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat); ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E4M3FN); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E4M3FNUZ); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2); \ - ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ); + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); #else @@ -743,7 +752,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat); ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, uint32_t); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, uint64_t); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, MLFloat16); \ - ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16); + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); #endif @@ -808,6 +819,8 @@ void RegisterAllProtos(const std::function& reg_fn) { 
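// The Int4x2 / UInt4x2 entries added in this function keep the proto registration list in
// sync with the ORT_REGISTER_TENSOR_TYPE / ORT_REGISTER_SPARSE_TENSOR_TYPE /
// ORT_REGISTER_SEQ_TENSOR_TYPE / ORT_REGISTER_OPTIONAL_TYPE additions above.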
REGISTER_TENSOR_PROTO(Float8E5M2, reg_fn); REGISTER_TENSOR_PROTO(Float8E5M2FNUZ, reg_fn); #endif + REGISTER_TENSOR_PROTO(Int4x2, reg_fn); + REGISTER_TENSOR_PROTO(UInt4x2, reg_fn); #if !defined(DISABLE_SPARSE_TENSORS) REGISTER_SPARSE_TENSOR_PROTO(int32_t, reg_fn); @@ -830,6 +843,8 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_SPARSE_TENSOR_PROTO(Float8E5M2, reg_fn); REGISTER_SPARSE_TENSOR_PROTO(Float8E5M2FNUZ, reg_fn); #endif + REGISTER_SPARSE_TENSOR_PROTO(Int4x2, reg_fn); + REGISTER_SPARSE_TENSOR_PROTO(UInt4x2, reg_fn); #endif #if !defined(DISABLE_ML_OPS) @@ -867,6 +882,9 @@ void RegisterAllProtos(const std::function& reg_fn) { #endif + REGISTER_SEQ_TENSOR_PROTO(Int4x2, reg_fn); + REGISTER_SEQ_TENSOR_PROTO(UInt4x2, reg_fn); + #if !defined(DISABLE_ML_OPS) REGISTER_ONNX_PROTO(VectorMapStringToFloat, reg_fn); REGISTER_ONNX_PROTO(VectorMapInt64ToFloat, reg_fn); @@ -894,7 +912,9 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E4M3FN, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E4M3FNUZ, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2, reg_fn); \ - REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn); + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); #else @@ -912,7 +932,9 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_OPTIONAL_PROTO(ORT_TYPE, uint32_t, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, uint64_t, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, MLFloat16, reg_fn); \ - REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn); + REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); #endif @@ -973,6 +995,10 @@ const char* DataTypeImpl::ToString(MLDataType type) { return "Float8E5M2"; case TensorProto_DataType_FLOAT8E5M2FNUZ: return "Float8E5M2FNUZ"; + case TensorProto_DataType_INT4: + return "Int4x2"; + case TensorProto_DataType_UINT4: + return "UInt4x2"; default: break; } @@ -1041,6 +1067,10 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) { return DataTypeImpl::GetTensorType()->AsTensorType(); #endif + case TensorProto_DataType_INT4: + return DataTypeImpl::GetTensorType()->AsTensorType(); + case TensorProto_DataType_UINT4: + return DataTypeImpl::GetTensorType()->AsTensorType(); default: ORT_NOT_IMPLEMENTED("tensor type ", type, " is not supported"); @@ -1090,6 +1120,10 @@ const SequenceTensorTypeBase* DataTypeImpl::SequenceTensorTypeFromONNXEnum(int t return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); #endif + case TensorProto_DataType_INT4: + return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); + case TensorProto_DataType_UINT4: + return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); default: ORT_NOT_IMPLEMENTED("sequence tensor type ", type, " is not supported"); @@ -1140,6 +1174,10 @@ const SparseTensorTypeBase* DataTypeImpl::SparseTensorTypeFromONNXEnum(int type) return DataTypeImpl::GetSparseTensorType()->AsSparseTensorType(); #endif + case TensorProto_DataType_INT4: + return DataTypeImpl::GetSparseTensorType()->AsSparseTensorType(); + case TensorProto_DataType_UINT4: + return DataTypeImpl::GetSparseTensorType()->AsSparseTensorType(); default: ORT_NOT_IMPLEMENTED("sparse tensor type ", type, " is not supported"); @@ -1183,6 +1221,8 @@ ORT_REGISTER_PRIM_TYPE(Float8E5M2); 
ORT_REGISTER_PRIM_TYPE(Float8E5M2FNUZ); #endif +ORT_REGISTER_PRIM_TYPE(Int4x2); +ORT_REGISTER_PRIM_TYPE(UInt4x2); namespace { template From ce03eb220e85c1a35555daef2028afcc1a5d87ab Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:43:20 -0700 Subject: [PATCH 05/72] Update onnxruntime map type info with int4 types --- onnxruntime/core/framework/onnxruntime_map_type_info.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 3963504273599..1370580bad4f6 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -78,6 +78,12 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { case TensorType::TensorProto_DataType_FLOAT8E5M2FNUZ: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; } // Non-IEEE floating-point format based on IEEE754 single-precision + case TensorType::TensorProto_DataType_INT4: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; + } // maps to a pair of int4 (size == 1 byte) + case TensorType::TensorProto_DataType_UINT4: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; + } // maps to a pair of uint4 (size == 1 byte) default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -126,4 +132,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, ORT_API(void, OrtApis::ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo* ptr) { std::unique_ptr p(ptr); -} \ No newline at end of file +} From e1590078df3f168b4415a01c4a978b741b153d7c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:45:31 -0700 Subject: [PATCH 06/72] Update Tensor methods to calc int4 tensor data sizze --- onnxruntime/core/framework/tensor.cc | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index 36f03a9b1046a..8d83198ffe03d 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -34,6 +34,12 @@ size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape if (shape_size > 0) { SafeInt len = 0; + + // TODO: Handle more cleanly. Add virtual function to MLDataType: ByteSizeFromShape(TensorShape) ?? + if (utils::IsPrimitiveDataType(elt_type) || utils::IsPrimitiveDataType(elt_type)) { + shape_size = (shape_size + 1) / 2; + } + if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), elt_type->Size(), &len)) ORT_THROW("tensor failed memory size calculation"); @@ -104,7 +110,13 @@ size_t Tensor::SizeInBytes() const { #else int64_t size = shape_.Size(); #endif - size_t ret; + size_t ret = 0; + + // TODO: Handle more cleanly. Add virtual function to MLDataType: ByteSizeFromShape(TensorShape) ?? 
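// For the packed 4-bit element types the byte count is ceil(num_elements / 2): a tensor with
// 5 int4 elements occupies (5 + 1) / 2 = 3 bytes, matching the adjustment already applied in
// CalculateTensorStorageSize above.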
+ if (utils::IsPrimitiveDataType(dtype_) || utils::IsPrimitiveDataType(dtype_)) { + size = (size + 1) / 2; + } + if (!IAllocator::CalcMemSizeForArray(SafeInt(size), dtype_->Size(), &ret)) { ORT_THROW("tensor size overflow"); } From 46c3d0d6db1280ffe89e37f3ed98feb236418145 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:46:33 -0700 Subject: [PATCH 07/72] Update function to map tensor_proto int4 to onnxruntime enum --- onnxruntime/core/framework/tensor_type_and_shape.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 6e11bfe1ac8ea..418e46924fb9f 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -164,6 +164,12 @@ constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementData case o::TensorProto_DataType_BOOL: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; break; + case o::TensorProto_DataType_INT4: + type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; + break; + case o::TensorProto_DataType_UINT4: + type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; + break; default: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; break; @@ -365,4 +371,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, *out = ptr.release(); return nullptr; API_IMPL_END -} \ No newline at end of file +} From 583dae1e04227c5fcd52b82daaccb840fded86e8 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:48:33 -0700 Subject: [PATCH 08/72] Update tensorprotoutils to handle int4 protobufs --- .../core/framework/tensorprotoutils.cc | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 8a2db6d5728af..ce686a809438d 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -59,6 +59,20 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) { return t; \ } +#define TO_TENSOR_ORT_TYPE_INT4(TYPE) \ + template <> \ + TensorProto ToTensor(const onnxruntime::TYPE& value) { \ + return ToScalarTensor(ToTensorProtoElementType(), value.ToBits()); \ + } \ + template <> \ + TensorProto ToTensor(const std::vector& values) { \ + TensorProto t = ToTensorInitialize(ToTensorProtoElementType()); \ + for (const onnxruntime::TYPE& val : values) { \ + t.add_int32_data(val.ToBits()); \ + } \ + return t; \ + } + namespace ONNX_NAMESPACE { // Provide template specializations for onnxruntime-specific types. 
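// For the 4-bit pair types, TO_TENSOR_ORT_TYPE_INT4 above serializes one packed byte per
// Int4x2 / UInt4x2 element via ToBits() into the proto's int32_data field (one entry per
// byte pair), rather than one entry per 4-bit value.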
@@ -70,6 +84,8 @@ TO_TENSOR_ORT_TYPE(Float8E4M3FNUZ) TO_TENSOR_ORT_TYPE(Float8E5M2) TO_TENSOR_ORT_TYPE(Float8E5M2FNUZ) #endif +TO_TENSOR_ORT_TYPE_INT4(Int4x2) +TO_TENSOR_ORT_TYPE_INT4(UInt4x2) bool operator==(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l, const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) { @@ -125,6 +141,42 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t reinterpret_cast(p_data)); } +template <> +Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ Int4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); + + return Status::OK(); +} + +template <> +Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ UInt4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); + + return Status::OK(); +} + static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const ORTCHAR_T* tensor_proto_dir, std::basic_string& external_file_path, @@ -261,6 +313,48 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, reinterpret_cast(p_data)); } +template <> +Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, + const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, + /*out*/ Int4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); + std::vector unpacked_tensor; + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); + + gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); + + return Status::OK(); +} + +template <> +Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, + const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, + /*out*/ UInt4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); + std::vector unpacked_tensor; + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + 
ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); + + gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); + + return Status::OK(); +} + #define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \ template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const ORTCHAR_T*, size_t, type*); @@ -284,6 +378,8 @@ INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Float8E4M3FNUZ) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Float8E5M2) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Float8E5M2FNUZ) #endif +INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Int4x2) +INSTANTIATE_UNPACK_EXTERNAL_TENSOR(UInt4x2) template <> Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& /*tensor*/, @@ -602,6 +698,62 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d #endif +// UnpackTensor +template <> +Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, + /*out*/ Int4x2* p_data, size_t expected_num_elems) { + if (nullptr == p_data) { + const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); + return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); + } + if (ONNX_NAMESPACE::TensorProto_DataType_INT4 != tensor.data_type()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); + } + + size_t expected_int4_pairs = (expected_num_elems + 1) / 2; + + if (raw_data != nullptr) { + return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); + } + + ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_int4_pairs, + "UnpackTensor: the pre-allocated size does not match the size in proto"); + + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { + p_data[i] = Int4x2(static_cast(tensor.int32_data()[i])); + } + + return Status::OK(); +} + +// UnpackTensor +template <> +Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, + /*out*/ UInt4x2* p_data, size_t expected_num_elems) { + if (nullptr == p_data) { + const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); + return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); + } + if (ONNX_NAMESPACE::TensorProto_DataType_INT4 != tensor.data_type()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); + } + + size_t expected_int4_pairs = (expected_num_elems + 1) / 2; + + if (raw_data != nullptr) { + return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); + } + + ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_int4_pairs, + "UnpackTensor: the pre-allocated size does not match the size in proto"); + + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { + p_data[i] = UInt4x2(static_cast(tensor.int32_data()[i])); + } + + return Status::OK(); +} + // UnpackTensor from raw data, external data or the type specific data field. // Uses the model path to construct the full path for loading external data. In case when model_path is empty // it uses current directory. 
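// A minimal round-trip sketch of the packed helpers these unpack routines rely on. It is
// illustrative only (the helper name below is made up) and assumes the Int4x2 interface from
// core/framework/int4.h with int8_t spans on the unpacked side and Int4x2 spans on the
// packed side.
#include <cstdint>
#include <vector>
#include "core/common/gsl.h"
#include "core/framework/int4.h"

static bool Int4RoundTripExample() {
  std::vector<int8_t> values = {-8, 7, 0, 3, -1};                    // 5 elements in [-8, 7]
  std::vector<onnxruntime::Int4x2> packed((values.size() + 1) / 2);  // 3 bytes of storage
  bool ok = onnxruntime::Int4x2::Pack(gsl::make_span(packed), gsl::make_span(values));

  std::vector<int8_t> unpacked(values.size());
  ok = ok && onnxruntime::Int4x2::Unpack(gsl::make_span(unpacked), gsl::make_span(packed));
  return ok;  // when ok is true, unpacked holds the original five values
}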
@@ -651,6 +803,8 @@ INSTANTIATE_UNPACK_TENSOR(Float8E4M3FNUZ) INSTANTIATE_UNPACK_TENSOR(Float8E5M2) INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ) #endif +INSTANTIATE_UNPACK_TENSOR(Int4x2) +INSTANTIATE_UNPACK_TENSOR(UInt4x2) #define CASE_PROTO_TRACE(X, Y) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ @@ -659,6 +813,13 @@ INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ) } \ break; +#define CASE_PROTO_TRACE_INT4(X) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!IAllocator::CalcMemSizeForArrayWithAlignment((size + 1) / 2, 1, out)) { \ + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \ + } \ + break; + template common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) { const auto& dims = tensor_proto.dims(); @@ -692,6 +853,8 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& CASE_PROTO_TRACE(FLOAT8E5M2, Float8E5M2); CASE_PROTO_TRACE(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif + CASE_PROTO_TRACE_INT4(UINT4); + CASE_PROTO_TRACE_INT4(INT4); default: return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } @@ -998,6 +1161,8 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, CASE_PROTO(FLOAT8E5M2, Float8E5M2); CASE_PROTO(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif + CASE_PROTO(INT4, Int4x2); + CASE_PROTO(UINT4, UInt4x2); case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_STRING: ORT_RETURN_IF_ERROR(UnpackTensor(tensor_proto, raw_data, raw_data_len, static_cast(preallocated), @@ -1053,6 +1218,8 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { CASE_TYPE(FLOAT8E5M2) CASE_TYPE(FLOAT8E5M2FNUZ) #endif + CASE_TYPE(UINT4) + CASE_TYPE(INT4) default: return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -1570,6 +1737,20 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T break; \ } +#define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ + TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ + size_t element_count = tensor_shape.Size(); \ + size_t packed_element_count = (element_count + 1) / 2; \ + unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ + return onnxruntime::utils::UnpackTensor( \ + initializer, \ + initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ + initializer.has_raw_data() ? 
initializer.raw_data().size() : 0, \ + reinterpret_cast(unpacked_tensor.data()), element_count); \ + break; \ + } + Status UnpackInitializerData(const onnx::TensorProto& initializer, const Path& model_path, std::vector& unpacked_tensor) { @@ -1604,6 +1785,8 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, CASE_UNPACK(FLOAT8E5M2, onnxruntime::Float8E5M2, int32_data_size); CASE_UNPACK(FLOAT8E5M2FNUZ, onnxruntime::Float8E5M2FNUZ, int32_data_size); #endif + CASE_UNPACK_INT4(INT4, Int4x2, int32_data_size); + CASE_UNPACK_INT4(UINT4, UInt4x2, int32_data_size); default: break; } From 0009a47c3653a777b1c21cbcfcf38f540f7843f9 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:50:03 -0700 Subject: [PATCH 09/72] Add functions to map Int4x2 to an onnxruntime tensor type enum --- onnxruntime/core/framework/utils.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index f0b1b9109d405..fd14eeeb33d27 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -222,6 +222,16 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType #endif +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; +} + +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; +} + int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); #ifdef ENABLE_TRAINING From d11a3d495eb44b7d57e4b69d3db337155d8bd144 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:52:26 -0700 Subject: [PATCH 10/72] Update com.microsoft.DequantizeLinear schema to support int4 types for input --- onnxruntime/core/graph/contrib_ops/quantization_defs.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 47f61a43458ed..80dab9594dd79 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -206,7 +206,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(DequantizeLinear, 1, .Output(0, "y", "N-D full precision output tensor. 
It has same shape as input 'x'.", "T2") .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", - "tensor(uint16)", "tensor(int32)"}, + "tensor(uint16)", "tensor(int32)", "tensor(int4)", + "tensor(uint4)"}, "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, " "16-bit integer tensors, or 32-bit signed integer tensors.") .TypeConstraint("T2", {"tensor(float16)", "tensor(float)"}, From 208c4037177c74f249cbcae849e6d7ff3190fc68 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 16:55:40 -0700 Subject: [PATCH 11/72] Add option to disable int4 type in Conv and MatMul qdq node group selectors --- .../selectors_actions/qdq_selectors.cc | 14 ++++++++++++ .../selectors_actions/qdq_selectors.h | 22 ++++++++++++------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 6b4f62ae1343d..d4879376b34ad 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -20,6 +20,11 @@ constexpr bool Is16BitIntType(int32_t data_type) { (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16); } +constexpr bool Is4BitIntType(int32_t data_type) { + return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4) || + (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4); +} + // adjust for an optional input/output that has an entry but does not exist int NumActualValues(const Node& node, bool input) { const auto& defs = input ? node.InputDefs() : node.OutputDefs(); @@ -312,6 +317,10 @@ bool ConvNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_weight_ && Is4BitIntType(dt_weight)) { + return false; + } + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { if (!int8_allowed_ || dt_weight != dt_input) { return false; @@ -359,6 +368,11 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + // 4-bit int types must be explicitly allowed. 
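// In other words, a DQ -> MatMul (-> Q) group whose input or weight is INT4/UINT4 is only
// selected when the caller opts in via allow_4bit, mirroring the allow_4bit_weight check in
// ConvNodeGroupSelector::Check above.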
+ if (!allow_4bit_ && (Is4BitIntType(dt_input) || Is4BitIntType(dt_weight))) { + return false; + } + // potential match for QLinearMatMul or MatMulIntegerToFloat bool qlinear = !q_nodes.empty(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 5d550669e2e86..495f26294266c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -131,8 +131,8 @@ class SplitNodeGroupSelector : public NodeGroupSelector { class ConvNodeGroupSelector : public NodeGroupSelector { public: // default to 'true' - ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true) - : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit) {} + ConvNodeGroupSelector(bool int8_allowed = true, bool allow_16bit = true, bool allow_4bit_weight = true) + : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit), allow_4bit_weight_(allow_4bit_weight) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -141,6 +141,7 @@ class ConvNodeGroupSelector : public NodeGroupSelector { bool int8_allowed_; bool allow_16bit_; + bool allow_4bit_weight_; }; class WhereNodeGroupSelector : public NodeGroupSelector { @@ -172,10 +173,12 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { public: MatMulNodeGroupSelector(bool int8_allowed = true, bool matmulintegertofloat_allowed = false, - bool allow_16bit = true) + bool allow_16bit = true, + bool allow_4bit = true) : int8_allowed_(int8_allowed), matmulintegertofloat_allowed_(matmulintegertofloat_allowed), - allow_16bit_(allow_16bit) { + allow_16bit_(allow_16bit), + allow_4bit_(allow_4bit) { } private: @@ -185,6 +188,7 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { bool int8_allowed_; bool matmulintegertofloat_allowed_; bool allow_16bit_; + bool allow_4bit_; }; // Input: DQ nodes for A, B and optional C @@ -316,8 +320,8 @@ class SplitSelector : public BaseSelector { // DQ nodes for X, W and optionally B -> node -> Q class ConvSelector : public BaseSelector { public: - ConvSelector(bool int8_allowed = false, bool allow_16bit = false) - : BaseSelector(std::make_unique(int8_allowed, allow_16bit)) {} + ConvSelector(bool int8_allowed = false, bool allow_16bit = false, bool allow_4bit_weight = false) + : BaseSelector(std::make_unique(int8_allowed, allow_16bit, allow_4bit_weight)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; @@ -331,9 +335,11 @@ class WhereSelector : public BaseSelector { // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not class MatMulSelector : public BaseSelector { public: - MatMulSelector(bool int8_allowed, bool allow_16bit = false) + MatMulSelector(gsl::span compatible_providers, bool int8_allowed, bool allow_16bit = false, + bool allow_4bit = false) : BaseSelector(std::make_unique(int8_allowed, /*matmulintegertofloat_allowed*/ true, - allow_16bit)) {} + allow_16bit, allow_4bit), + compatible_providers) {} }; // Input: DQ nodes for A, B and optional C From f91ae6976325cbf22a5df1c649bea63c661db7c1 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 17:07:41 -0700 Subject: [PATCH 12/72] Add DequantizeLinear with int4 support (missing block quant) --- .../providers/cpu/cpu_execution_provider.cc | 7 +++ .../cpu/quantization/quantize_linear.cc | 47 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff 
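// Element-wise math performed by the int4 DequantizeLinear support added below, shown as a
// sketch for the per-tensor case with a single scale and zero point (the illustrative helper
// name is made up; the actual kernel also walks the broadcast and block dimensions, and an
// unsigned UInt4x2 variant works the same way):
#include <cstddef>
#include <cstdint>
#include "core/framework/int4.h"

static void DequantizeInt4Example(const onnxruntime::Int4x2* packed, size_t num_elements,
                                  float scale, int8_t zero_point, float* out) {
  for (size_t i = 0; i < num_elements; ++i) {
    const int8_t x = packed[i >> 1][i & 0x1];              // nibble i % 2 of byte i / 2
    out[i] = static_cast<float>(x - zero_point) * scale;   // y = (x - zp) * scale
  }
}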
--git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 8a270a05d7287..be5eea06fc7c7 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -5,6 +5,7 @@ #include #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" +#include "core/framework/int4.h" #include "core/mlas/inc/mlas.h" #ifndef DISABLE_CONTRIB_OPS @@ -1070,6 +1071,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear); @@ -2655,6 +2658,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { DequantizeLinear)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index d8924551e5292..f98959b713cdc 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -5,6 +5,7 @@ #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" #include "core/framework/op_kernel.h" #include "core/providers/common.h" #include "core/mlas/inc/mlas.h" @@ -126,6 +127,8 @@ REGISTER_DEQUANTIZELINEAR(uint8_t) REGISTER_DEQUANTIZELINEAR(int16_t) REGISTER_DEQUANTIZELINEAR(uint16_t) REGISTER_DEQUANTIZELINEAR(int32_t) +REGISTER_DEQUANTIZELINEAR(Int4x2) +REGISTER_DEQUANTIZELINEAR(UInt4x2) #if !defined(DISABLE_FLOAT8_TYPES) REGISTER_DEQUANTIZELINEAR(Float8E4M3FN) REGISTER_DEQUANTIZELINEAR(Float8E4M3FNUZ) @@ -199,6 +202,24 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( .TypeConstraint("T2", DataTypeImpl::GetTensorType()), DequantizeLinear); +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + Int4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + DequantizeLinear, + 1, + UInt4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + DequantizeLinear); + } // namespace contrib #endif // !defined(DISABLE_CONTRIB_OPS) @@ -217,6 +238,32 @@ struct DequantizeLinearApply { } }; +#define DEQUANTIZE_LINEAR_APPLY_INT4(T) \ + template \ + struct DequantizeLinearApply { \ + void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T* zero_point) { \ + size_t 
input_index = 0; \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + auto zp = zero_point ? static_cast(zero_point[bd_i][bd_j]) : 0; \ + auto sc = static_cast(scale[bd]); \ + for (size_t bs = 0; bs < static_cast(block_size); bs++) { \ + size_t input_i = input_index >> 1; \ + size_t input_j = input_index & 0x1; \ + *output++ = static_cast(static_cast(static_cast(input[input_i][input_j]) - zp) * sc); \ + input_index += 1; \ + } \ + } \ + } \ + assert(input_index == static_cast(N * broadcast_dim * block_size)); \ + } \ + }; + +DEQUANTIZE_LINEAR_APPLY_INT4(Int4x2); +DEQUANTIZE_LINEAR_APPLY_INT4(UInt4x2); + #if !defined(DISABLE_FLOAT8_TYPES) #define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ From 7323793a742de92c1f0dedeca16536c7c0f61804 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 17:24:55 -0700 Subject: [PATCH 13/72] update transpose helper to support int4 --- .../core/providers/cpu/tensor/transpose.cc | 111 +++++++++++------- .../core/providers/cpu/tensor/transpose.h | 3 +- 2 files changed, 73 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index ec4624cf59ae6..3aec1605527be 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -336,38 +336,85 @@ bool IsTransposeReshape(const gsl::span& perm, gsl::span& permutations, const Tensor& input, Tensor& output, + const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { + TensorShape shape = input_shape_override ? *input_shape_override : input.Shape(); + + if (IsTransposeReshape(permutations, shape.GetDims())) { + // As long as the dims with values > 1 stay in the same order, it's a reshape. + // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). + CopyCpuTensor(&input, &output); + return Status::OK(); + } + + size_t from = 0, to = 0; + bool moving_single_axis = IsTransposeMovingSingleAxis(permutations, from, to); + + if (moving_single_axis && !input.IsDataTypeString()) { + SingleAxisTranspose(permutations, input, output, from, to, input_shape_override, tp); + return Status::OK(); + } + + // fall back to default implementation + return DoUntypedTranspose(permutations, input, output, input_shape_override); +} + +template +static Status UnpackInt4Tensor(const Tensor& src, Tensor& dst, AllocatorPtr cpu_allocator) { + static_assert(sizeof(PackedType) == 1); + static_assert(sizeof(UnpackedType) == 1); + + MLDataType int8_elem_type = DataTypeImpl::GetType(); + const TensorShape& shape = src.Shape(); + Tensor int8_tensor(int8_elem_type, shape, cpu_allocator); + + ORT_RETURN_IF_NOT(PackedType::Unpack(int8_tensor.MutableDataAsSpan(), src.DataAsSpan()), + "Failed to unpack Int4x2 Tensor to an int8_t Tensor"); + + dst = std::move(int8_tensor); + + return Status::OK(); +} + //`input_shape_override` overrides the shape of `input` for compute purposes. Status TransposeBase::DoTranspose(const gsl::span& permutations, const Tensor& input, Tensor& output, - const TensorShape* input_shape_override) { - Status status = Status::OK(); - + const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { auto input_type = input.DataType(); auto output_type = output.DataType(); if (input_type != output_type) { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mismatched data types between input and output Tensors. 
", - input_type, " != ", output_type); - } else { - TensorShape shape = input_shape_override ? *input_shape_override : input.Shape(); - if (IsTransposeReshape(permutations, shape.GetDims())) { - // As long as the dims with values > 1 stay in the same order, it's a reshape. - // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). - CopyCpuTensor(&input, &output); - return Status::OK(); - } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mismatched data types between input and output Tensors. ", + input_type, " != ", output_type); + } + if (input.IsDataType()) { + // Convert to Tensor, transpose, and then repack back to Int4x2. + AllocatorPtr cpu_allocator = std::make_shared(); + Tensor input_unpacked; + Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); - size_t from = 0, to = 0; - bool moving_single_axis = IsTransposeMovingSingleAxis(permutations, from, to); + ORT_RETURN_IF_ERROR((UnpackInt4Tensor(input, input_unpacked, cpu_allocator))); + ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp)); + ORT_RETURN_IF_NOT(Int4x2::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), + "Failed to pack Tensor into Tensor"); - if (moving_single_axis && !input.IsDataTypeString()) { - SingleAxisTranspose(permutations, input, output, from, to, input_shape_override); - } else { - // fall back to default implementation - status = DoUntypedTranspose(permutations, input, output, input_shape_override); - } + return Status::OK(); } - return status; + if (input.IsDataType()) { + // Convert to Tensor, transpose, and then repack back to UInt4x2. + AllocatorPtr cpu_allocator = std::make_shared(); + Tensor input_unpacked; + Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); + + ORT_RETURN_IF_ERROR((UnpackInt4Tensor(input, input_unpacked, cpu_allocator))); + ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp)); + ORT_RETURN_IF_NOT(UInt4x2::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), + "Failed to pack Tensor into Tensor"); + + return Status::OK(); + } + + return TransposeImpl(permutations, input, output, input_shape_override, tp); } Status Transpose::Compute(OpKernelContext* ctx) const { @@ -388,27 +435,11 @@ Status Transpose::Compute(OpKernelContext* ctx) const { TensorShape output_shape{output_dims}; Tensor& Y = *ctx->Output(0, output_shape); - if (output_shape.Size() == 0) + if (output_shape.Size() == 0) { return Status::OK(); - - if (IsTransposeReshape(*p_perm, input_dims)) { - // As long as the dims with values > 1 stay in the same order, it's a reshape. - // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). 
- CopyCpuTensor(&X, &Y); - return Status::OK(); - } - - size_t from = 0, to = 0; - bool moving_single_axis = IsTransposeMovingSingleAxis(*p_perm, from, to); - - if (moving_single_axis && !X.IsDataTypeString()) { - SingleAxisTranspose(*p_perm, X, Y, from, to, nullptr, ctx->GetOperatorThreadPool()); - } else { - // fall back to default implementation - status = DoUntypedTranspose(*p_perm, X, Y); } - return status; + return DoTranspose(*p_perm, X, Y, nullptr, ctx->GetOperatorThreadPool()); } ONNX_CPU_OPERATOR_VERSIONED_KERNEL( diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h index 133b35ac80fe5..fda41c28a2567 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.h +++ b/onnxruntime/core/providers/cpu/tensor/transpose.h @@ -33,7 +33,8 @@ class TransposeBase { Both Tensors must have the same data type. `input_shape_override` overrides the shape of `input` for compute purposes. */ static Status DoTranspose(const gsl::span& permutations, const Tensor& input, Tensor& output, - const TensorShape* input_shape_override = nullptr); + const TensorShape* input_shape_override = nullptr, + concurrency::ThreadPool* tp = nullptr); protected: TransposeBase(const OpKernelInfo& info) { From 8c79905af76302932c25c66d09a00597ee237b71 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 17:39:46 -0700 Subject: [PATCH 14/72] Update provider bridge with int4 apis --- .../core/providers/shared_library/provider_api.h | 12 ++++++++++-- .../shared_library/provider_bridge_provider.cc | 12 ++++++++++++ .../providers/shared_library/provider_interfaces.h | 12 ++++++++++++ onnxruntime/core/session/provider_bridge_ort.cc | 12 ++++++++++++ 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 1cebe4a256fd4..5cc5c2302df6d 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -24,6 +24,7 @@ #include "core/framework/allocator.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" #include "core/framework/tensor_shape.h" #include "core/providers/providers.h" #include "core/common/path_string.h" @@ -68,7 +69,9 @@ enum TensorProto_DataType : int { TensorProto_DataType_FLOAT8E4M3FN = 17, TensorProto_DataType_FLOAT8E4M3FNUZ = 18, TensorProto_DataType_FLOAT8E5M2 = 19, - TensorProto_DataType_FLOAT8E5M2FNUZ = 20 + TensorProto_DataType_FLOAT8E5M2FNUZ = 20, + TensorProto_DataType_UINT4 = 21, + TensorProto_DataType_INT4 = 22, }; enum TensorProto_DataLocation : int { @@ -86,7 +89,8 @@ enum Version : int { IR_VERSION_2019_9_19 = 6, IR_VERSION_2020_5_8 = 7, IR_VERSION_2021_7_31 = 8, - IR_VERSION = 9 + IR_VERSION_2023_5_5 = 9, + IR_VERSION = 10 }; enum OperatorStatus : int { @@ -345,6 +349,10 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { template <> constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; } #endif +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; } +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; } } // namespace utils // This is a replacement for Ort::InitApi() to be called before any other onnxruntime API calls. 
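// With these additions the provider bridge exposes the 4-bit element types consistently: the
// internal TensorProto_DataType enum gains UINT4 = 21 / INT4 = 22, the public
// ONNXTensorElementDataType enum gained UINT4/INT4 in the earlier onnxruntime_c_api.h patch,
// and the GetONNXTensorElementDataType<Int4x2> / <UInt4x2> specializations above map the ORT
// types onto those values.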
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 7b73ab36b3742..112dc9abb5f97 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -167,6 +167,10 @@ MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->Data template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Float8E5M2FNUZ(); } #endif +template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Int4x2(); } +template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_UInt4x2(); } template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_string(); } @@ -207,6 +211,10 @@ MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost() template <> MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Float8E5M2FNUZ(); } #endif +template <> +MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Int4x2(); } +template <> +MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_UInt4x2(); } #if !defined(DISABLE_SPARSE_TENSORS) template <> @@ -248,6 +256,10 @@ MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_Get template <> MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ(); } #endif +template <> +MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_Int4x2(); } +template <> +MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_UInt4x2(); } #endif diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 8c8d5b1fd460a..90aa577fba0f6 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -602,6 +602,8 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetType_Float8E5M2FNUZ() = 0; #endif + virtual MLDataType DataTypeImpl__GetType_Int4x2() = 0; + virtual MLDataType DataTypeImpl__GetType_UInt4x2() = 0; virtual MLDataType DataTypeImpl__GetTensorType_bool() = 0; virtual MLDataType DataTypeImpl__GetTensorType_int8() = 0; @@ -622,6 +624,8 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetTensorType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetTensorType_Float8E5M2FNUZ() = 0; #endif + virtual MLDataType DataTypeImpl__GetTensorType_Int4x2() = 0; + virtual MLDataType DataTypeImpl__GetTensorType_UInt4x2() = 0; #if !defined(DISABLE_SPARSE_TENSORS) virtual MLDataType DataTypeImpl__GetSparseTensorType_bool() = 0; @@ -644,6 +648,8 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ() = 0; #endif + virtual MLDataType DataTypeImpl__GetSparseTensorType_Int4x2() = 0; + virtual MLDataType DataTypeImpl__GetSparseTensorType_UInt4x2() = 0; #endif virtual const char* DataTypeImpl__ToString(MLDataType type) = 0; @@ -943,6 +949,8 @@ struct ProviderHost { virtual Float8E5M2* 
Tensor__MutableData_Float8E5M2(Tensor* p) = 0; virtual Float8E5M2FNUZ* Tensor__MutableData_Float8E5M2FNUZ(Tensor* p) = 0; #endif + virtual Int4x2* Tensor__MutableData_Int4x2(Tensor* p) = 0; + virtual UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) = 0; virtual const bool* Tensor__Data_bool(const Tensor* p) = 0; virtual const int8_t* Tensor__Data_int8(const Tensor* p) = 0; @@ -964,6 +972,8 @@ struct ProviderHost { virtual const Float8E5M2* Tensor__Data_Float8E5M2(const Tensor* p) = 0; virtual const Float8E5M2FNUZ* Tensor__Data_Float8E5M2FNUZ(const Tensor* p) = 0; #endif + virtual const Int4x2* Tensor__Data_Int4x2(const Tensor* p) = 0; + virtual const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) = 0; virtual gsl::span Tensor__DataAsSpan_int64(const Tensor* p) = 0; @@ -995,6 +1005,8 @@ struct ProviderHost { virtual bool Tensor__IsDataType_Float8E5M2(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_Float8E5M2FNUZ(const Tensor* p) noexcept = 0; #endif + virtual bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept = 0; + virtual bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept = 0; virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0; virtual void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) = 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index fda41161ac40a..a5dce2e9cd3d7 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -797,6 +797,8 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetType_Float8E5M2() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_Float8E5M2FNUZ() override { return DataTypeImpl::GetType(); } #endif + MLDataType DataTypeImpl__GetType_Int4x2() override { return DataTypeImpl::GetType(); } + MLDataType DataTypeImpl__GetType_UInt4x2() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetTensorType_bool() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_int8() override { return DataTypeImpl::GetTensorType(); } @@ -818,6 +820,8 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetTensorType_Float8E5M2() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_Float8E5M2FNUZ() override { return DataTypeImpl::GetTensorType(); } #endif + MLDataType DataTypeImpl__GetTensorType_Int4x2() override { return DataTypeImpl::GetTensorType(); } + MLDataType DataTypeImpl__GetTensorType_UInt4x2() override { return DataTypeImpl::GetTensorType(); } #if !defined(DISABLE_SPARSE_TENSORS) MLDataType DataTypeImpl__GetSparseTensorType_bool() override { return DataTypeImpl::GetSparseTensorType(); } @@ -840,6 +844,8 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ() override { return DataTypeImpl::GetSparseTensorType(); } #endif + MLDataType DataTypeImpl__GetSparseTensorType_Int4x2() override { return DataTypeImpl::GetSparseTensorType(); } + MLDataType DataTypeImpl__GetSparseTensorType_UInt4x2() override { return DataTypeImpl::GetSparseTensorType(); } #endif const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); } @@ -1201,6 +1207,8 @@ struct ProviderHostImpl : ProviderHost { Float8E5M2* Tensor__MutableData_Float8E5M2(Tensor* p) override { return 
p->MutableData(); } Float8E5M2FNUZ* Tensor__MutableData_Float8E5M2FNUZ(Tensor* p) override { return p->MutableData(); } #endif + Int4x2* Tensor__MutableData_Int4x2(Tensor* p) override { return p->MutableData(); } + UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) override { return p->MutableData(); } const bool* Tensor__Data_bool(const Tensor* p) override { return p->Data(); } const int8_t* Tensor__Data_int8(const Tensor* p) override { return p->Data(); } @@ -1222,6 +1230,8 @@ struct ProviderHostImpl : ProviderHost { const Float8E5M2* Tensor__Data_Float8E5M2(const Tensor* p) override { return p->Data(); } const Float8E5M2FNUZ* Tensor__Data_Float8E5M2FNUZ(const Tensor* p) override { return p->Data(); } #endif + const Int4x2* Tensor__Data_Int4x2(const Tensor* p) override { return p->Data(); } + const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) override { return p->Data(); } gsl::span Tensor__DataAsSpan_int64(const Tensor* p) override { return p->DataAsSpan(); } @@ -1251,6 +1261,8 @@ struct ProviderHostImpl : ProviderHost { bool Tensor__IsDataType_Float8E5M2(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_Float8E5M2FNUZ(const Tensor* p) noexcept override { return p->IsDataType(); } #endif + bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept override { return p->IsDataType(); } + bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept override { return p->IsDataType(); } const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); } void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) override { return p->Reshape(new_shape); } From fc695caa318df0f406f0baea774f0c80aa1d52b9 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 17:50:47 -0700 Subject: [PATCH 15/72] Update quantizer tool with int4 --- .../tools/quantization/base_quantizer.py | 29 ++++++++--- .../python/tools/quantization/quant_utils.py | 50 ++++++++++++++++++- 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 625cab25b9c46..b6d0ed4d8adc2 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -339,6 +339,10 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}" f"\nraw={str(q_weight_initializer)[:200]}." 
) + elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + q_weight_initializer = onnx.helper.make_tensor( + q_weight_name, qType, weight.dims, q_weight_data + ) else: q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape( weight.dims @@ -396,7 +400,9 @@ def quantize_weight_per_channel_impl( symmetric = quant_overrides_for_channels[0].get( "symmetric", - (self.is_weight_symmetric or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN)), + (self.is_weight_symmetric or weight_qType in (onnx.TensorProto.INT8, + onnx.TensorProto.FLOAT8E4M3FN, + onnx.TensorProto.INT4)), ) reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range) zero_point_list = [] @@ -447,7 +453,8 @@ def quantize_weight_per_channel_impl( quantized_per_channel_data_list.append(quantized_per_channel_data) # combine per_channel_data into one - reshape_dims = list(weights.shape) # deep copy + weights_shape = list(weights.shape) + reshape_dims = list(weights_shape) # deep copy reshape_dims[channel_axis] = 1 # only one per channel for reshape quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims) for i in range(1, len(quantized_per_channel_data_list)): @@ -470,12 +477,18 @@ def quantize_weight_per_channel_impl( self.model.initializer_extend([scale_initializer, zero_initializer]) if not keep_float_weight: - quantized_weights = np.asarray( - quantized_weights, - dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType], - ).reshape(initializer.dims) - q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name) - self.model.initializer_extend([q_weight_initializer]) + if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + q_weight_initializer = onnx.helper.make_tensor( + q_weight_name, weight_qType, weights_shape, quantized_weights + ) + self.model.initializer_extend([q_weight_initializer]) + else: + quantized_weights = np.asarray( + quantized_weights, + dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType], + ).reshape(initializer.dims) + q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name) + self.model.initializer_extend([q_weight_initializer]) return q_weight_name, zp_name, scale_name diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 35b5e1c8ba825..bdf6d5a355206 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -21,7 +21,7 @@ from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions try: - from onnx.reference.custom_element_types import float8e4m3fn + from onnx.reference.custom_element_types import float8e4m3fn, int4, uint4 except ImportError: float8e4m3fn = None @@ -81,6 +81,8 @@ class QuantType(Enum): QFLOAT8E4M3FN = 2 QInt16 = 3 QUInt16 = 4 + QInt4 = 5 + QUInt4 = 6 def __str__(self): return self.name @@ -104,6 +106,10 @@ def tensor_type(self): return TensorProto.INT16 if self == QuantType.QFLOAT8E4M3FN: return TensorProto.FLOAT8E4M3FN + if self == QuantType.QUInt4: + return TensorProto.UINT4 + if self == QuantType.QInt4: + return TensorProto.INT4 raise ValueError(f"Unexpected value qtype={self!r}.") @@ -128,6 +134,8 @@ def from_string(format): onnx_proto.TensorProto.INT16: numpy.dtype("int16"), onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"), onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn, + onnx_proto.TensorProto.INT4: 
int4, + onnx_proto.TensorProto.UINT4: uint4, } ONNX_INT_TYPE_RANGE = { @@ -135,6 +143,8 @@ def from_string(format): onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)), onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)), onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)), + onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(15, dtype=uint4)), + onnx_proto.TensorProto.INT4: (numpy.array(-8, dtype=int4), numpy.array(7, dtype=int4)), } ONNX_INT_TYPE_SYMMETRIC_RANGE = { @@ -202,6 +212,35 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): ) ref = ReferenceEvaluator(onnx_model) return _check_type(ref.run(None, {"X": arr, "scale": scale})[0]) + elif qType in ( + onnx_proto.TensorProto.INT4, + onnx_proto.TensorProto.UINT4, + ): + if arr.dtype == numpy.float32: + onnx_type = TensorProto.FLOAT + elif arr.dtype == numpy.float16: + onnx_type = TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype {arr.dtype}.") + onnx_model = make_model( + make_graph( + [ + make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]), + ], + "qu", + [ + make_tensor_value_info("X", onnx_type, None), + make_tensor_value_info("scale", onnx_type, None), + make_tensor_value_info("zero_point", qType, None), + ], + [make_tensor_value_info("Y", qType, None)], + ) + ) + # The reference ONNX implementation of QuantizeLinear returns "unpacked" int8 numpy values + # because numpy cannot represent 4bit values (although ONNX TensorProto has no problem with this). + # These "unpacked" int8 values are correctly re-packed when passed to onnx.make_tensor(). + ref = ReferenceEvaluator(onnx_model) + return _check_type(ref.run(None, {"X": arr, "scale": scale, "zero_point": zero_point})[0]) else: dtype = ONNX_TYPE_TO_NP_TYPE[qType] (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) @@ -372,7 +411,14 @@ def quantize_data( ) return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) - if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16): + if qType in ( + TensorProto.INT8, + TensorProto.UINT8, + TensorProto.INT16, + TensorProto.UINT16, + TensorProto.INT4, + TensorProto.UINT4, + ): if len(data): qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) From eeacb78b396134767cb1f46ec67a0e6259f0e162 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 18:24:04 -0700 Subject: [PATCH 16/72] Remove duplicate enum --- include/onnxruntime/core/session/onnxruntime_c_api.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d2b4c0c0d7ef6..7c6e9b1dbf63b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -197,7 +197,6 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision - 
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of uint4 values (size == 1 byte) ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of int4 values (size == 1 byte) } ONNXTensorElementDataType; From 6f9da045db2a17af99fba83f3f299ec911bcbd35 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 18:37:54 -0700 Subject: [PATCH 17/72] Remove MatMulSelector constructor arg --- .../qdq_transformer/selectors_actions/qdq_selectors.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 495f26294266c..4446f6f4e6b63 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -335,11 +335,9 @@ class WhereSelector : public BaseSelector { // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not class MatMulSelector : public BaseSelector { public: - MatMulSelector(gsl::span compatible_providers, bool int8_allowed, bool allow_16bit = false, - bool allow_4bit = false) + MatMulSelector(bool int8_allowed, bool allow_16bit = false, bool allow_4bit = false) : BaseSelector(std::make_unique(int8_allowed, /*matmulintegertofloat_allowed*/ true, - allow_16bit, allow_4bit), - compatible_providers) {} + allow_16bit, allow_4bit)) {} }; // Input: DQ nodes for A, B and optional C From 7e8c4588888fab5a7e9f16b00ac77b811eb9a517 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 19:18:14 -0700 Subject: [PATCH 18/72] Remove unnecessary explicit template instantiation --- onnxruntime/core/framework/tensorprotoutils.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index ce686a809438d..b9574d3f961ce 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -378,8 +378,6 @@ INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Float8E4M3FNUZ) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Float8E5M2) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Float8E5M2FNUZ) #endif -INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Int4x2) -INSTANTIATE_UNPACK_EXTERNAL_TENSOR(UInt4x2) template <> Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& /*tensor*/, From 3b7ed5f30b2428114e89480f6a74645d4ee4f14c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 19:40:14 -0700 Subject: [PATCH 19/72] Add static_cast --- onnxruntime/core/framework/tensorprotoutils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index b9574d3f961ce..0a51b2cfffe4e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1738,7 +1738,7 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T #define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ - size_t element_count = tensor_shape.Size(); \ + size_t element_count = static_cast(tensor_shape.Size()); \ size_t packed_element_count = (element_count + 1) / 2; \ 
unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ return onnxruntime::utils::UnpackTensor( \ From c7086a5edfe00562b691c97c193589c561cd8ed8 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 19:49:50 -0700 Subject: [PATCH 20/72] Add temporary CPU EP Int4 test (qdq conv) --- .../cpu/cpu_execution_provider_test.cc | 52 ++++++++++++++++++ .../test/testdata/conv.int4.int8.qdq.onnx | Bin 0 -> 1783 bytes 2 files changed, 52 insertions(+) create mode 100644 onnxruntime/test/testdata/conv.int4.int8.qdq.onnx diff --git a/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc b/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc index 8b9dcbd943b4a..325f305dad537 100644 --- a/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc +++ b/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc @@ -2,7 +2,16 @@ // Licensed under the MIT License. #include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "gtest/gtest.h" +#include "gmock/gmock.h" + +#define ORT_MODEL_FOLDER ORT_TSTR("testdata/") + +// in test_main.cc +extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { @@ -12,5 +21,48 @@ TEST(CPUExecutionProviderTest, MetadataTest) { EXPECT_TRUE(provider != nullptr); ASSERT_EQ(provider->GetOrtDeviceByMemType(OrtMemTypeDefault).Type(), OrtDevice::CPU); } + +// TODO: Remove. This is a throwaway test for Int4 +TEST(CPUExecutionProviderTest, Example_Conv_Int4) { + Ort::SessionOptions so; + + // Ensure all type/shape inference warnings result in errors! + so.AddConfigEntry(kOrtSessionOptionsConfigStrictShapeTypeInference, "1"); + so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv.int4.int8.qdq.onnx"; + Ort::Session session(*ort_env, ort_model_path, so); + + std::array input0_data = {}; + for (size_t i = 0; i < input0_data.size(); i++) { + input0_data[i] = 0.2f; + } + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add input0 + std::array inputs_shape{1, 3, 8, 8}; + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), inputs_shape.data(), inputs_shape.size())); + ort_input_names.push_back("input_0"); + + // Run session and get outputs + std::array output_names{"output_0"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output shape. 
+ Ort::Value& ort_output = ort_outputs[0]; + auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); + std::vector output_shape = typeshape.GetShape(); + + EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 5, 6, 6)); + const float* results = ort_output.GetTensorData(); + + for (size_t i = 0; i < typeshape.GetElementCount(); i++) { + std::cout << i << ": " << results[i] << std::endl; + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/conv.int4.int8.qdq.onnx b/onnxruntime/test/testdata/conv.int4.int8.qdq.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a5f83ac76489a42aa5b0fced2100297f49daa8ce GIT binary patch literal 1783 zcmdr6v|B380H;adWXGR%8}SFf3qX;wt81 z&&(?*Er~ba;)ODx77D?*NH)pB!~&rnfmjmnUs?h*LrM(0v=Sexa;|PhE`jpY%=C;B zByaOTMIhc4gz}IqmxBtSdkAK`I1YJ2LD0|0h2LMqgp54Q>$qGeflC3V^MH{;$T%(n zfx!h%6GF22;E0Jgh)0cqcu!EAOR+lV=am^Kv4AM9USKfbRnEl+mxjcVFr1GZOK?Fn zEBzs9L>z}aMk1-u5;>^)NLsI$;YSEeVjWg$owgI-!N9mIenaP)_9R86UgO4h}nH z+6g6hXvqLoifJ)3Ffce_W)!VvW)4;kW)K7_dcw5hKNK+jgR=yJ+TrZia2CH}DR#xv z;Yv=yS^WNGE8(0!aMlMnD+Pxx>{bkE_pm!*lxJF6z&K)aLMz<8M2yb;|8H~T|9@cI zF&wdGU^r*Uz@TLSig_oj1;3X_#w=f}*@B!_$h4Lz*mT^Lf|m%$0(20i2rur1#JEH_ z7=;A5m^c`Lm>GyUKsZT>3!Y`r)UX2O*+AHdg^NKzv`C(d3tX#Z=B1?;2?>JOMWw*{ KEjJZX(E$J;l>Np4 literal 0 HcmV?d00001 From 34dfa17131c16838b0c99afee8f388a16f3606cc Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 21:15:21 -0700 Subject: [PATCH 21/72] Update operator docs --- docs/ContribOperators.md | 2 +- docs/OperatorKernels.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 3d984a54c0495..1be16b81a846e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1372,7 +1372,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)
+T1 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32), tensor(int4), tensor(uint4)
Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.
T2 : tensor(float16), tensor(float)
Constrain 'y', 'x_scale' to float tensors.
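
For reference, the 4-bit tensors these DequantizeLinear kernels consume can be produced directly from Python. The snippet below is a minimal sketch and is not part of the patch; it assumes onnx>=1.16 (TensorProto.INT4/UINT4, opset 21), where onnx.helper.make_tensor re-packs the unpacked integer values two per byte, low nibble first, matching the Int4x2 layout used by the new kernels. The data mirrors the DequantizeLinearOpTest.Int4 test added later in this series.

import onnx
from onnx import TensorProto, helper

# 5 INT4 values (odd count: the last byte carries one padded nibble), scale 2.0, zero point -1.
w = helper.make_tensor("x", TensorProto.INT4, [5], [-8, -3, 1, 7, 2])
scale = helper.make_tensor("x_scale", TensorProto.FLOAT, [], [2.0])
zp = helper.make_tensor("x_zero_point", TensorProto.INT4, [], [-1])

node = helper.make_node("DequantizeLinear", ["x", "x_scale", "x_zero_point"], ["y"])
graph = helper.make_graph([node], "dq_int4", [],
                          [helper.make_tensor_value_info("y", TensorProto.FLOAT, [5])],
                          initializer=[w, scale, zp])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
onnx.checker.check_model(model)

# With a build that includes these kernels, the model runs on the CPU EP:
# import onnxruntime as ort
# sess = ort.InferenceSession(model.SerializeToString())
# print(sess.run(None, {})[0])  # [-14., -4., 4., 16., 6.]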
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5bae5ea626576..e602d65ac76c9 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -87,7 +87,7 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)| @@ -468,7 +468,7 @@ Do not modify directly.* |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| -|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float)| +|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float)| From e7bec9cfd9c30670704b952afeb0792dd9a923f1 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 21:41:39 -0700 Subject: [PATCH 22/72] Update testing version of tensorprotoutils with int4 helpers --- onnxruntime/test/onnx/tensorprotoutils.cc | 90 +++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index b15b1769a69c4..b98717116280d 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -75,6 +75,48 @@ static void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_length memcpy(p_data, raw_data, raw_data_length); } +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ Int4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + + if (num_packed_pairs != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed int4 pairs", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); +} + +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ UInt4x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + + if (num_packed_pairs != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed int4 pairs", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); +} + // This macro doesn't work for Float16/bool/string tensors #define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \ template <> \ @@ -268,6 +310,41 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2, TensorProto_DataType_FLOAT8E5M2) DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) #endif +#define DEFINE_UNPACK_TENSOR_INT4(INT4_TYPE, ONNX_TYPE) \ + template <> \ + void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ + /*out*/ INT4_TYPE* p_data, size_t expected_num_elems) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? 
raw_data_len : tensor.int32_data_size(); \ + if (size == 0) { \ + return; \ + } \ + ORT_CXX_API_THROW("p_data == nullptr, but size != 0", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + if (ONNX_NAMESPACE::ONNX_TYPE != tensor.data_type()) { \ + ORT_CXX_API_THROW("TensorProto data type is not INT4", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + \ + size_t expected_int4_pairs = (expected_num_elems + 1) / 2; \ + \ + if (raw_data != nullptr) { \ + UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \ + return; \ + } \ + \ + if (static_cast(tensor.int32_data_size()) != expected_int4_pairs) { \ + ORT_CXX_API_THROW("UnpackTensor: the pre-allocated size does not match the size in proto", \ + OrtErrorCode::ORT_FAIL); \ + } \ + \ + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ + p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ + } \ + } + +DEFINE_UNPACK_TENSOR_INT4(Int4x2, TensorProto_DataType_INT4) +DEFINE_UNPACK_TENSOR_INT4(UInt4x2, TensorProto_DataType_UINT4) + #define CASE_PROTO_TRACE(X, Y) \ case onnx::TensorProto_DataType::TensorProto_DataType_##X: \ if (!CalcMemSizeForArrayWithAlignment(size, sizeof(Y), alignment, out)) { \ @@ -275,6 +352,13 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) } \ break; +#define CASE_PROTO_TRACE_INT4(X) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!CalcMemSizeForArrayWithAlignment((size + 1) / 2, 1, alignment, out)) { \ + ORT_CXX_API_THROW("Invalid TensorProto", OrtErrorCode::ORT_FAIL); \ + } \ + break; + template Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) { const auto& dims = tensor_proto.dims(); @@ -308,6 +392,8 @@ Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_p CASE_PROTO_TRACE(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif CASE_PROTO_TRACE(STRING, std::string); + CASE_PROTO_TRACE_INT4(UINT4); + CASE_PROTO_TRACE_INT4(INT4); default: return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } @@ -392,6 +478,8 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { CASE_TYPE(FLOAT8E4M3FNUZ) CASE_TYPE(FLOAT8E5M2) CASE_TYPE(FLOAT8E5M2FNUZ) + CASE_TYPE(UINT4) + CASE_TYPE(INT4) default: return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -456,6 +544,8 @@ Status TensorProtoToMLValue(const onnx::TensorProto& tensor_proto, const MemBuff CASE_PROTO(FLOAT8E5M2, Float8E5M2); CASE_PROTO(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif + CASE_PROTO(INT4, Int4x2); + CASE_PROTO(UINT4, UInt4x2); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: if (preallocated != nullptr) { OrtStatus* status = OrtInitializeBufferForTensor(preallocated, preallocated_size, ele_type); From cd8912e0cb1ab945c4a9653bf7d6511eab8bbc25 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 22:04:51 -0700 Subject: [PATCH 23/72] Run lintrunner --- .../python/tools/quantization/base_quantizer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index b6d0ed4d8adc2..aff18a8b361c3 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -340,9 +340,7 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa f"\nraw={str(q_weight_initializer)[:200]}." 
) elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, qType, weight.dims, q_weight_data - ) + q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data) else: q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape( weight.dims @@ -400,9 +398,10 @@ def quantize_weight_per_channel_impl( symmetric = quant_overrides_for_channels[0].get( "symmetric", - (self.is_weight_symmetric or weight_qType in (onnx.TensorProto.INT8, - onnx.TensorProto.FLOAT8E4M3FN, - onnx.TensorProto.INT4)), + ( + self.is_weight_symmetric + or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4) + ), ) reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range) zero_point_list = [] From 4cf3a751651c81ebab5c133550fea0af8acaa758 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 22:30:12 -0700 Subject: [PATCH 24/72] Fix api to create int4 ort value --- onnxruntime/core/session/onnxruntime_c_api.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 49e3f0a0213ba..4a4e7bcbd5610 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -216,6 +216,12 @@ ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t } auto elem_count = narrow(tensor_shape.Size()); + + // TODO: Handle this more cleanly. + if (utils::IsPrimitiveDataType(ml_type) || utils::IsPrimitiveDataType(ml_type)) { + elem_count = (elem_count + 1) / 2; + } + size_t size_to_allocate; if (!IAllocator::CalcMemSizeForArray(ml_type->Size(), elem_count, &size_to_allocate)) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "size overflow"); From ca785c2a2bfd935f7e6d0352ca1634b7a57113dc Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 17 Apr 2024 22:59:31 -0700 Subject: [PATCH 25/72] Wrap long lines in tensorprotoutils --- onnxruntime/core/framework/tensorprotoutils.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 0a51b2cfffe4e..9e7cf37255158 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -326,7 +326,8 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& t size_t num_packed_pairs = (expected_num_elements + 1) / 2; ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); - gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), num_packed_pairs); + gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), + num_packed_pairs); gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); @@ -347,7 +348,8 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& size_t num_packed_pairs = (expected_num_elements + 1) / 2; ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); - gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), num_packed_pairs); + gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), + num_packed_pairs); gsl::span dst_span 
= gsl::make_span(p_data, expected_num_elements); std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); From f87e785c196179b4a7ed72c3ee2e345391f6eb0a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 18 Apr 2024 00:28:44 -0700 Subject: [PATCH 26/72] Add operator unit tests for Dequant int4/uint4 --- onnxruntime/test/providers/base_tester.h | 12 +++++-- .../cpu/tensor/quantize_linear_test.cc | 31 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index c276ae494df43..e00855d1e9eac 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/framework/customregistry.h" #include "core/framework/prepacked_weights_container.h" @@ -690,8 +691,15 @@ class BaseTester { if (!is_optional_type_tensor || (is_optional_type_tensor && values != nullptr)) { // In case values is nullptr for optional type tensor, it means we are creating // an optional type tensor which is None and we hence skip values count validation - ORT_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ", - shape.Size()); + if constexpr (std::is_same_v || std::is_same_v) { + int64_t expected_values_count = shape.Size(); + expected_values_count = (expected_values_count + 1) / 2; + ORT_ENFORCE(expected_values_count == values_count, values_count, + " input values doesn't match tensor size of ", expected_values_count); + } else { + ORT_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ", + shape.Size()); + } // If it is an optional tensor type with no values (i.e.) None, // we won't even pass it in to Run() as part of the feeds, diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 34f6455f33853..6169862f9ff95 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -5,6 +5,7 @@ #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "core/framework/int4.h" namespace onnxruntime { namespace test { @@ -32,6 +33,36 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar zero & scale with int4 +TEST(DequantizeLinearOpTest, Int4) { + OpTester test("DequantizeLinear", 21); + std::vector dims{5}; + constexpr int unused_val = 0; + + // Odd number of int4 values to test packing/unpacking + test.AddInput("x", dims, {Int4x2(-8, -3), Int4x2(1, 7), Int4x2(2, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {Int4x2(-1, unused_val)}); + test.AddOutput("y", dims, {-14.0f, -4.0f, 4.0f, 16.0f, 6.0f}); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// scalar zero & scale with uint4 +TEST(DequantizeLinearOpTest, UInt4) { + OpTester test("DequantizeLinear", 21); + std::vector dims{5}; + constexpr int unused_val = 0; + + // Odd number of uint4 values to test packing/unpacking + test.AddInput("x", dims, {UInt4x2(0, 1), UInt4x2(3, 15), UInt4x2(2, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {UInt4x2(1, unused_val)}); + test.AddOutput("y", dims, {-2.0f, 0.0f, 4.0f, 28.0f, 2.0f}); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // Test int16 DequantizeLinear (per tensor) TEST(DequantizeLinearOpTest, Int16) { OpTester test("DequantizeLinear", 21); From d028f2f72c001d5442d3e37e8c21aa8c47e4c5f1 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 18 Apr 2024 00:29:37 -0700 Subject: [PATCH 27/72] Remove comments --- onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 6169862f9ff95..ae2fd4581cd4a 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -44,7 +44,6 @@ TEST(DequantizeLinearOpTest, Int4) { test.AddInput("x_scale", {}, {2.0f}); test.AddInput("x_zero_point", {}, {Int4x2(-1, unused_val)}); test.AddOutput("y", dims, {-14.0f, -4.0f, 4.0f, 16.0f, 6.0f}); - // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } @@ -59,7 +58,6 @@ TEST(DequantizeLinearOpTest, UInt4) { test.AddInput("x_scale", {}, {2.0f}); test.AddInput("x_zero_point", {}, {UInt4x2(1, unused_val)}); test.AddOutput("y", dims, {-2.0f, 0.0f, 4.0f, 28.0f, 2.0f}); - // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } From 24cc617236db578a1a84794d2847e10e416a3bf7 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 18 Apr 2024 04:39:49 -0700 Subject: [PATCH 28/72] Add QuantizeLinear int4 impl --- include/onnxruntime/core/framework/int4.h | 16 ++ .../contrib_ops/cpu/cpu_contrib_kernels.cc | 4 + .../core/framework/tensorprotoutils.cc | 2 +- .../graph/contrib_ops/quantization_defs.cc | 3 +- .../providers/cpu/cpu_execution_provider.cc | 6 + .../cpu/quantization/quantize_linear.cc | 161 +++++++++++++++++- onnxruntime/test/providers/checkers.cc | 42 +++++ .../cpu/tensor/quantize_linear_test.cc | 46 +++++ 8 files changed, 277 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index 26be0c02fae81..f059386dce263 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -25,6 +25,14 @@ struct Int4x2 { return index == 0 ? 
val_0 : val_1; } + inline bool operator==(const Int4x2& other) const { + return this->val_0 == other.val_0 && this->val_1 == other.val_1; + } + + inline bool operator!=(const Int4x2& other) const { + return !(*this == other); + } + inline uint8_t ToBits() const { return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); } @@ -81,6 +89,14 @@ struct UInt4x2 { return index == 0 ? val_0 : val_1; } + inline bool operator==(const UInt4x2& other) const { + return this->val_0 == other.val_0 && this->val_1 == other.val_1; + } + + inline bool operator!=(const UInt4x2& other) const { + return !(*this == other); + } + inline uint8_t ToBits() const { return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); } diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index f0e39779a0532..b6d86cfafb035 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -69,6 +69,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid); @@ -212,6 +214,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 9e7cf37255158..d4608d6e1bc4c 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -734,7 +734,7 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); } - if (ONNX_NAMESPACE::TensorProto_DataType_INT4 != tensor.data_type()) { + if (ONNX_NAMESPACE::TensorProto_DataType_UINT4 != tensor.data_type()) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); } diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 80dab9594dd79..762d892c45ce8 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -164,7 +164,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T2", OpSchema::Optional) .Output(0, "y", "N-D quantized output tensor. 
It has same shape as input 'x'.", "T2") .TypeConstraint("T1", {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.") - .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)"}, + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)", "tensor(int4)", + "tensor(uint4)"}, "Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.") .SetDoc(QuantizeLinear_ver1_doc) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index be5eea06fc7c7..b8d5a7852a968 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1083,6 +1083,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint16_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, QuantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, QuantizeLinear); @@ -2680,6 +2682,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { QuantizeLinear)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index f98959b713cdc..403c5fd091c5c 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -388,6 +388,8 @@ REGISTER_QUANTIZELINEAR(int8_t) REGISTER_QUANTIZELINEAR(uint8_t) REGISTER_QUANTIZELINEAR(int16_t) REGISTER_QUANTIZELINEAR(uint16_t) +REGISTER_QUANTIZELINEAR(Int4x2) +REGISTER_QUANTIZELINEAR(UInt4x2) #if !defined(DISABLE_FLOAT8_TYPES) REGISTER_QUANTIZELINEAR(Float8E4M3FN) @@ -451,6 +453,24 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + Int4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + UInt4x2, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); } // namespace contrib #endif // !defined(DISABLE_CONTRIB_OPS) @@ -476,7 +496,8 @@ void ParQuantizeLinear(const InputType* Input, } template -void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool 
saturate) { +void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t N, + int64_t broadcast_dim, int64_t block_size, bool saturate) { for (size_t n = 0; n < static_cast(N); n++) { for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { ParQuantizeLinear(input, output, static_cast(block_size), scale[bd], bd, zero_point, saturate, ctx->GetOperatorThreadPool()); @@ -486,6 +507,144 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const } } +template <> +void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const Int4x2* zero_point, + Int4x2* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + // Quantize as 8bit and then copy to output as packed int4s. + // TODO: Can be done in-place without copying if block_size is even. + size_t total_size = static_cast(N * broadcast_dim * block_size); + auto tmp_buf = std::make_unique(total_size); + size_t tmp_buf_index = 0; + + for (size_t n = 0; n < static_cast(N); n++) { + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + size_t bd_i = bd >> 1; /*bd / 2*/ + size_t bd_j = bd & 0x1; /*bd % 2*/ + int8_t zp = zero_point ? zero_point[bd_i][bd_j] : 0; + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], + zp, ctx->GetOperatorThreadPool()); + input += block_size; + tmp_buf_index += block_size; + } + } + + // Clamp quantized value to 4bit range. + // TODO: This can be combined with the packing step. + for (size_t i = 0; i < total_size; i++) { + tmp_buf[i] = std::min(7, std::max(-8, tmp_buf[i])); + } + + size_t num_int4_pairs = (total_size + 1) / 2; + auto dst = gsl::make_span(output, num_int4_pairs); + auto src = gsl::make_span(tmp_buf.get(), total_size); + Int4x2::Pack(dst, src); +} + +template <> +void ComputeLoop(OpKernelContext* ctx, const MLFloat16* input, const MLFloat16* scale, + const Int4x2* zero_point, Int4x2* output, + int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + // Quantize as 8bit and then copy to output as packed int4s. + // TODO: Can be done in-place without copying if block_size is even. + size_t total_size = static_cast(N * broadcast_dim * block_size); + auto tmp_buf = std::make_unique(total_size); + size_t tmp_buf_index = 0; + + for (size_t n = 0; n < static_cast(N); n++) { + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + size_t bd_i = bd >> 1; /*bd / 2*/ + size_t bd_j = bd & 0x1; /*bd % 2*/ + int8_t zp = zero_point ? zero_point[bd_i][bd_j] : 0; + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], + zp, ctx->GetOperatorThreadPool()); + input += block_size; + tmp_buf_index += block_size; + } + } + + // Clamp quantized value to 4bit range. + // TODO: This can be combined with the packing step. + for (size_t i = 0; i < total_size; i++) { + tmp_buf[i] = std::min(7, std::max(-8, tmp_buf[i])); + } + + size_t num_int4_pairs = (total_size + 1) / 2; + auto dst = gsl::make_span(output, num_int4_pairs); + auto src = gsl::make_span(tmp_buf.get(), total_size); + Int4x2::Pack(dst, src); +} + +template <> +void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const UInt4x2* zero_point, + UInt4x2* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + // Quantize as 8bit and then copy to output as packed int4s. 
+ // TODO: Can be done in-place without copying if block_size is even. + size_t total_size = static_cast(N * broadcast_dim * block_size); + auto tmp_buf = std::make_unique(total_size); + size_t tmp_buf_index = 0; + + for (size_t n = 0; n < static_cast(N); n++) { + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + size_t bd_i = bd >> 1; /*bd / 2*/ + size_t bd_j = bd & 0x1; /*bd % 2*/ + uint8_t zp = zero_point ? zero_point[bd_i][bd_j] : 0; + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], + zp, ctx->GetOperatorThreadPool()); + input += block_size; + tmp_buf_index += block_size; + } + } + + // Clamp quantized value to 4bit range. + // TODO: This can be combined with the packing step. + for (size_t i = 0; i < total_size; i++) { + tmp_buf[i] = std::min(15, std::max(0, tmp_buf[i])); + } + + size_t num_int4_pairs = (total_size + 1) / 2; + auto dst = gsl::make_span(output, num_int4_pairs); + auto src = gsl::make_span(tmp_buf.get(), total_size); + UInt4x2::Pack(dst, src); +} + +template <> +void ComputeLoop(OpKernelContext* ctx, const MLFloat16* input, const MLFloat16* scale, + const UInt4x2* zero_point, UInt4x2* output, + int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + // Quantize as 8bit and then copy to output as packed int4s. + // TODO: Can be done in-place without copying if block_size is even. + size_t total_size = static_cast(N * broadcast_dim * block_size); + auto tmp_buf = std::make_unique(total_size); + size_t tmp_buf_index = 0; + + for (size_t n = 0; n < static_cast(N); n++) { + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + size_t bd_i = bd >> 1; /*bd / 2*/ + size_t bd_j = bd & 0x1; /*bd % 2*/ + uint8_t zp = zero_point ? zero_point[bd_i][bd_j] : 0; + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], + zp, ctx->GetOperatorThreadPool()); + input += block_size; + tmp_buf_index += block_size; + } + } + + // Clamp quantized value to 4bit range. + // TODO: This can be combined with the packing step. 
+ for (size_t i = 0; i < total_size; i++) { + tmp_buf[i] = std::min(15, std::max(0, tmp_buf[i])); + } + + size_t num_int4_pairs = (total_size + 1) / 2; + auto dst = gsl::make_span(output, num_int4_pairs); + auto src = gsl::make_span(tmp_buf.get(), total_size); + UInt4x2::Pack(dst, src); +} + // formula is Y = X / Scale + ZeroPoint template Status QuantizeLinear::Compute(OpKernelContext* ctx) const { diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 47c18c478dd9c..981388c21dfeb 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -7,6 +7,7 @@ #include "core/graph/constants.h" #include "core/framework/TensorSeq.h" +#include "core/framework/int4.h" #include "test/framework/test_utils.h" #include "test/providers/provider_test_utils.h" @@ -162,6 +163,46 @@ struct TensorCheck { } }; +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + Tensor expected_sorted, actual_sorted; + const Int4x2* cur_expected; + const Int4x2* cur_actual; + const auto size = actual.Shape().Size(); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < static_cast(size); ++i) { + size_t r = i >> 2; + size_t c = i & 0x1; + EXPECT_EQ(cur_expected[r][c], cur_actual[r][c]) << "i:" << i; + } + } +}; + +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + Tensor expected_sorted, actual_sorted; + const UInt4x2* cur_expected; + const UInt4x2* cur_actual; + const auto size = actual.Shape().Size(); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < static_cast(size); ++i) { + size_t r = i >> 2; + size_t c = i & 0x1; + EXPECT_EQ(cur_expected[r][c], cur_actual[r][c]) << "i:" << i; + } + } +}; + template <> struct TensorCheck { void operator()(const Tensor& expected, @@ -437,6 +478,7 @@ void Check(std::string_view name, const OrtValue& expected, const Tensor utils::MLTypeCallDispatcher dims{7}; + constexpr int8_t unused_val = 0; + test.AddInput("x", dims, { + -20.0f, // Clamp to qmin + -16.0f, // Close to qmin + -3.0f, // round + 0.0f, // Zero-point + 3.0f, // round + 12.0f, // qmax + 20.0f, // Clamp to qmax + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {Int4x2(1, unused_val)}, true); + test.AddOutput("y", dims, + {Int4x2(-8, -7), Int4x2(-1, 1), Int4x2(2, 7), + Int4x2(7, unused_val)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint4 QuantizeLinear (per tensor) +TEST(QuantizeLinearOpTest, UInt4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{7}; + constexpr uint8_t unused_val = 0; + test.AddInput("x", dims, { + -20.0f, // Clamp to qmin + -8.0f, // qmin + -3.0f, // round + 0.0f, // Zero-point + 3.0f, // round + 22.0f, // qmax + 20.0f, // Clamp to qmax + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {UInt4x2(4, unused_val)}, true); + test.AddOutput("y", dims, + {UInt4x2(0, 0), UInt4x2(2, 4), UInt4x2(6, 15), + UInt4x2(15, unused_val)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // quantize with scalar zero point and scale TEST(QuantizeLinearOpTest, 
Int8_NegativeZeroPoint) { // TODO: Unskip when fixed #41968513 From f35b09effb9ce4b6e36de8a4cee40c2ec33150f0 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 18 Apr 2024 09:06:03 -0700 Subject: [PATCH 29/72] Update operator docs --- docs/ContribOperators.md | 2 +- docs/OperatorKernels.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1be16b81a846e..0be2b9b18e4c9 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -4823,7 +4823,7 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float16), tensor(float)
Constrain 'x', 'y_scale' to float tensors.
-
T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)
+
T2 : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int4), tensor(uint4)
Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index e602d65ac76c9..d296abedddf6f 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -259,7 +259,7 @@ Do not modify directly.* |||[7, 11]|**T** = tensor(double), tensor(float)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| @@ -503,7 +503,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| From e33f198e3bf4ab8c53f147551d6ffa37ce1d3894 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 18 Apr 2024 09:15:54 -0700 Subject: [PATCH 30/72] Disable potentially bugged onnx tests --- onnxruntime/test/onnx/TestCase.cc | 6 +++++- .../test/testdata/onnx_backend_test_series_filters.jsonc | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index e12e9401413be..baf79dfce7bfe 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1021,7 +1021,11 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {}}, {"dequantizelinear_blocked", "blocked quantization (onnx 1.16.0) not supported", {}}, {"quantizelinear_blocked_asymmetric", "blocked quantization (onnx 1.16.0) not supported", {}}, - {"quantizelinear_blocked_symmetric", "blocked quantization (onnx 1.16.0) not supported", {}}}); + {"quantizelinear_blocked_symmetric", "blocked quantization (onnx 1.16.0) not supported", {}}, + {"dequantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"dequantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"quantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}}); // Some EPs may fail to pass some specific testcases. // For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16. diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 0d141d634e051..ee8c7c79cc3d6 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -312,7 +312,12 @@ // DequantizeLinear(21) blocked quantization from ONNX 1.16.0 is not implemented in ORT yet. "^test_dequantizelinear_blocked", "^test_quantizelinear_blocked_asymmetric", - "^test_quantizelinear_blocked_symmetric" + "^test_quantizelinear_blocked_symmetric", + // Bug with test model: node's input name does not match the model's input name (x_zero_point vs zero_point) + "^test_dequantizelinear_int4", + "^test_dequantizelinear_uint4", + "^test_quantizelinear_int4", + "^test_quantizelinear_uint4" ], "current_failing_tests_x86": [ "^test_vgg19", From 10f28aa787fd5a34b21cc875b05f42cfffa3c60a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 18 Apr 2024 09:16:55 -0700 Subject: [PATCH 31/72] Add TODO username --- onnxruntime/core/framework/tensor.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index 8d83198ffe03d..b7017297df4ce 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -35,7 +35,7 @@ size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape if (shape_size > 0) { SafeInt len = 0; - // TODO: Handle more cleanly. Add virtual function to MLDataType: ByteSizeFromShape(TensorShape) ?? + // TODO(adrianlizarraga): Handle more cleanly. if (utils::IsPrimitiveDataType(elt_type) || utils::IsPrimitiveDataType(elt_type)) { shape_size = (shape_size + 1) / 2; } @@ -112,7 +112,7 @@ size_t Tensor::SizeInBytes() const { #endif size_t ret = 0; - // TODO: Handle more cleanly. 
Add virtual function to MLDataType: ByteSizeFromShape(TensorShape) ?? + // TODO(adrianlizarraga): Handle more cleanly. if (utils::IsPrimitiveDataType(dtype_) || utils::IsPrimitiveDataType(dtype_)) { size = (size + 1) / 2; } From de1ded4c5161594aec679f3e270a9aaeba3ee180 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 18 Apr 2024 17:17:08 -0700 Subject: [PATCH 32/72] Fix warning as error and clean up --- include/onnxruntime/core/framework/int4.h | 24 +-- .../cpu/quantization/quantize_linear.cc | 175 ++++-------------- 2 files changed, 48 insertions(+), 151 deletions(-) diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index f059386dce263..0268f57c0457b 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -10,6 +10,10 @@ namespace onnxruntime { struct Int4x2 { + using unpacked_type = int8_t; + static constexpr unpacked_type min_val = -8; + static constexpr unpacked_type max_val = 7; + int8_t val_0 : 4; int8_t val_1 : 4; @@ -25,14 +29,6 @@ struct Int4x2 { return index == 0 ? val_0 : val_1; } - inline bool operator==(const Int4x2& other) const { - return this->val_0 == other.val_0 && this->val_1 == other.val_1; - } - - inline bool operator!=(const Int4x2& other) const { - return !(*this == other); - } - inline uint8_t ToBits() const { return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); } @@ -74,6 +70,10 @@ struct Int4x2 { static_assert(sizeof(Int4x2) == sizeof(int8_t)); struct UInt4x2 { + using unpacked_type = uint8_t; + static constexpr unpacked_type min_val = 0; + static constexpr unpacked_type max_val = 15; + uint8_t val_0 : 4; uint8_t val_1 : 4; @@ -89,14 +89,6 @@ struct UInt4x2 { return index == 0 ? val_0 : val_1; } - inline bool operator==(const UInt4x2& other) const { - return this->val_0 == other.val_0 && this->val_1 == other.val_1; - } - - inline bool operator!=(const UInt4x2& other) const { - return !(*this == other); - } - inline uint8_t ToBits() const { return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); } diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 403c5fd091c5c..cccded643d279 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -507,143 +507,48 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const } } -template <> -void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const Int4x2* zero_point, - Int4x2* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { - ORT_UNUSED_PARAMETER(saturate); - // Quantize as 8bit and then copy to output as packed int4s. - // TODO: Can be done in-place without copying if block_size is even. - size_t total_size = static_cast(N * broadcast_dim * block_size); - auto tmp_buf = std::make_unique(total_size); - size_t tmp_buf_index = 0; - - for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { - size_t bd_i = bd >> 1; /*bd / 2*/ - size_t bd_j = bd & 0x1; /*bd % 2*/ - int8_t zp = zero_point ? 
zero_point[bd_i][bd_j] : 0; - ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], - zp, ctx->GetOperatorThreadPool()); - input += block_size; - tmp_buf_index += block_size; - } +#define DEFINE_COMPUTE_LOOP_INT4(INT4_TYPE, FLOAT_TYPE) \ + template <> \ + void ComputeLoop(OpKernelContext * ctx, const FLOAT_TYPE* input, const FLOAT_TYPE* scale, \ + const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t N, \ + int64_t broadcast_dim, int64_t block_size, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + \ + /* Quantize as 8bit and then copy to output as packed int4s. */ \ + /* TODO(adrianlizarraga): Can be done in - place without copying if block_size is even.*/ \ + size_t total_size = static_cast(N * broadcast_dim * block_size); \ + auto tmp_buf = std::make_unique(total_size); \ + size_t tmp_buf_index = 0; \ + \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i][bd_j] : 0; \ + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ + static_cast(block_size), scale[bd], \ + zp, ctx->GetOperatorThreadPool()); \ + input += block_size; \ + tmp_buf_index += static_cast(block_size); \ + } \ + } \ + \ + for (size_t i = 0; i < total_size; i++) { \ + tmp_buf[i] = std::min(INT4_TYPE::max_val, \ + std::max(INT4_TYPE::min_val, \ + tmp_buf[i])); \ + } \ + \ + size_t num_int4_pairs = (total_size + 1) / 2; \ + auto dst = gsl::make_span(output, num_int4_pairs); \ + auto src = gsl::make_span(tmp_buf.get(), total_size); \ + INT4_TYPE::Pack(dst, src); \ } - // Clamp quantized value to 4bit range. - // TODO: This can be combined with the packing step. - for (size_t i = 0; i < total_size; i++) { - tmp_buf[i] = std::min(7, std::max(-8, tmp_buf[i])); - } - - size_t num_int4_pairs = (total_size + 1) / 2; - auto dst = gsl::make_span(output, num_int4_pairs); - auto src = gsl::make_span(tmp_buf.get(), total_size); - Int4x2::Pack(dst, src); -} - -template <> -void ComputeLoop(OpKernelContext* ctx, const MLFloat16* input, const MLFloat16* scale, - const Int4x2* zero_point, Int4x2* output, - int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { - ORT_UNUSED_PARAMETER(saturate); - // Quantize as 8bit and then copy to output as packed int4s. - // TODO: Can be done in-place without copying if block_size is even. - size_t total_size = static_cast(N * broadcast_dim * block_size); - auto tmp_buf = std::make_unique(total_size); - size_t tmp_buf_index = 0; - - for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { - size_t bd_i = bd >> 1; /*bd / 2*/ - size_t bd_j = bd & 0x1; /*bd % 2*/ - int8_t zp = zero_point ? zero_point[bd_i][bd_j] : 0; - ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], - zp, ctx->GetOperatorThreadPool()); - input += block_size; - tmp_buf_index += block_size; - } - } - - // Clamp quantized value to 4bit range. - // TODO: This can be combined with the packing step. 
- for (size_t i = 0; i < total_size; i++) { - tmp_buf[i] = std::min(7, std::max(-8, tmp_buf[i])); - } - - size_t num_int4_pairs = (total_size + 1) / 2; - auto dst = gsl::make_span(output, num_int4_pairs); - auto src = gsl::make_span(tmp_buf.get(), total_size); - Int4x2::Pack(dst, src); -} - -template <> -void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const UInt4x2* zero_point, - UInt4x2* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { - ORT_UNUSED_PARAMETER(saturate); - // Quantize as 8bit and then copy to output as packed int4s. - // TODO: Can be done in-place without copying if block_size is even. - size_t total_size = static_cast(N * broadcast_dim * block_size); - auto tmp_buf = std::make_unique(total_size); - size_t tmp_buf_index = 0; - - for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { - size_t bd_i = bd >> 1; /*bd / 2*/ - size_t bd_j = bd & 0x1; /*bd % 2*/ - uint8_t zp = zero_point ? zero_point[bd_i][bd_j] : 0; - ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], - zp, ctx->GetOperatorThreadPool()); - input += block_size; - tmp_buf_index += block_size; - } - } - - // Clamp quantized value to 4bit range. - // TODO: This can be combined with the packing step. - for (size_t i = 0; i < total_size; i++) { - tmp_buf[i] = std::min(15, std::max(0, tmp_buf[i])); - } - - size_t num_int4_pairs = (total_size + 1) / 2; - auto dst = gsl::make_span(output, num_int4_pairs); - auto src = gsl::make_span(tmp_buf.get(), total_size); - UInt4x2::Pack(dst, src); -} - -template <> -void ComputeLoop(OpKernelContext* ctx, const MLFloat16* input, const MLFloat16* scale, - const UInt4x2* zero_point, UInt4x2* output, - int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { - ORT_UNUSED_PARAMETER(saturate); - // Quantize as 8bit and then copy to output as packed int4s. - // TODO: Can be done in-place without copying if block_size is even. - size_t total_size = static_cast(N * broadcast_dim * block_size); - auto tmp_buf = std::make_unique(total_size); - size_t tmp_buf_index = 0; - - for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { - size_t bd_i = bd >> 1; /*bd / 2*/ - size_t bd_j = bd & 0x1; /*bd % 2*/ - uint8_t zp = zero_point ? zero_point[bd_i][bd_j] : 0; - ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, static_cast(block_size), scale[bd], - zp, ctx->GetOperatorThreadPool()); - input += block_size; - tmp_buf_index += block_size; - } - } - - // Clamp quantized value to 4bit range. - // TODO: This can be combined with the packing step. - for (size_t i = 0; i < total_size; i++) { - tmp_buf[i] = std::min(15, std::max(0, tmp_buf[i])); - } - - size_t num_int4_pairs = (total_size + 1) / 2; - auto dst = gsl::make_span(output, num_int4_pairs); - auto src = gsl::make_span(tmp_buf.get(), total_size); - UInt4x2::Pack(dst, src); -} +DEFINE_COMPUTE_LOOP_INT4(Int4x2, float) +DEFINE_COMPUTE_LOOP_INT4(Int4x2, MLFloat16) +DEFINE_COMPUTE_LOOP_INT4(UInt4x2, float) +DEFINE_COMPUTE_LOOP_INT4(UInt4x2, MLFloat16) // formula is Y = X / Scale + ZeroPoint template From 746312b8ef9c02083f9decc99dc1b10a1a80a6ff Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sat, 20 Apr 2024 14:44:13 -0700 Subject: [PATCH 33/72] Mlas kernels to quantize int4 (not blocked). 
Missing powerpc --- onnxruntime/core/mlas/inc/mlas.h | 20 ++ onnxruntime/core/mlas/lib/mlasi.h | 24 ++ onnxruntime/core/mlas/lib/platform.cpp | 4 + onnxruntime/core/mlas/lib/quantize.cpp | 250 ++++++++++++++++++ .../cpu/quantization/quantize_linear.cc | 103 +++++--- onnxruntime/core/util/qmath.h | 80 ++++++ .../mlas/unittest/test_quantizelinear.cpp | 112 ++++++++ onnxruntime/test/providers/checkers.cc | 4 +- .../cpu/tensor/quantize_linear_test.cc | 8 +- 9 files changed, 559 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ce7838556fbf0..f8966657e2109 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1222,6 +1222,26 @@ MlasQuantizeLinear( OutputType ZeroPoint ); +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ); + +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 4b93dde1bcef9..3e08d7e1fa06b 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -680,6 +680,24 @@ void float Scale, int16_t ZeroPoint); +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_U4_KERNEL)( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_S4_KERNEL)( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -826,6 +844,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel; MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel; MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; @@ -1077,6 +1097,8 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; #endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; @@ -1106,6 +1128,8 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index a53c5085b10cf..6e3511d1b996d 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -274,6 +274,8 @@ Return Value: this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = 
MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; this->NchwcBlockSize = 8; this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; @@ -542,6 +544,8 @@ Return Value: this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; #if defined(__linux__) unsigned long hwcap2 = getauxval(AT_HWCAP2); diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index ffecc2dbeff9e..5d3a551f4addf 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -519,6 +519,132 @@ Return Value: } } +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = -8; + constexpr int32_t MaximumValue = 7; + + auto ScaleVector = MlasBroadcastFloat32x4(Scale); + auto MinimumValueVector = MlasBroadcastFloat32x4(float(MinimumValue - ZeroPoint)); + auto MaximumValueVector = MlasBroadcastFloat32x4(float(MaximumValue - ZeroPoint)); + auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); + + // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. + std::array TmpOutput = {}; + + while (N >= 4) { + + auto FloatVector = MlasLoadFloat32x4(Input); + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); + MlasQuantizeLinearStore4PackedValues(IntegerVector, TmpOutput.data()); + + Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); + Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); + + Input += 4; + Output += 2; + N -= 4; + } + + for (size_t n = 0; n < N; n++) { + +#if defined(MLAS_NEON64_INTRINSICS) + auto FloatVector = vld1q_dup_f32(Input + n); +#elif defined(MLAS_LSX_INTRINSICS) + MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); +#else + auto FloatVector = _mm_load_ss(Input + n); +#endif + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); + + size_t output_index = n >> 1; // which byte + size_t nibble_index = n & 0x1; // which 4-bit elem in the byte + + if (nibble_index == 0) { + Output[output_index] = static_cast(TmpOutput[0] & 0xF); + } else { + Output[output_index] |= static_cast((TmpOutput[0] & 0xF) << 4); + } + } +} + +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = 0; + constexpr int32_t MaximumValue = 15; + + auto ScaleVector = MlasBroadcastFloat32x4(Scale); + auto MinimumValueVector = MlasBroadcastFloat32x4(float(MinimumValue - ZeroPoint)); + auto MaximumValueVector = MlasBroadcastFloat32x4(float(MaximumValue - ZeroPoint)); + auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); + + // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. 
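+    // Each vectorized iteration below quantizes four floats; TmpOutput[0]/[1] share the first
+    // output byte (low/high nibble) and TmpOutput[2]/[3] share the second.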
+ std::array TmpOutput = {}; + + while (N >= 4) { + + auto FloatVector = MlasLoadFloat32x4(Input); + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); + MlasQuantizeLinearStore4PackedValues(IntegerVector, TmpOutput.data()); + + Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); + Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); + + Input += 4; + Output += 2; + N -= 4; + } + + for (size_t n = 0; n < N; n++) { + +#if defined(MLAS_NEON64_INTRINSICS) + auto FloatVector = vld1q_dup_f32(Input + n); +#elif defined(MLAS_LSX_INTRINSICS) + MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); +#else + auto FloatVector = _mm_load_ss(Input + n); +#endif + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); + + size_t output_index = n >> 1; // which byte + size_t nibble_index = n & 0x1; // which 4-bit elem in the byte + + if (nibble_index == 0) { + Output[output_index] = static_cast(TmpOutput[0] & 0xF); + } else { + Output[output_index] |= static_cast((TmpOutput[0] & 0xF) << 4); + } + } +} + void MLASCALL MlasQuantizeLinearS8Kernel( @@ -571,6 +697,42 @@ MlasQuantizeLinearS16Kernel( MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearS4Kernel( +#else + MlasQuantizeLinearS4Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearU4Kernel( +#else + MlasQuantizeLinearU4Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + template<> void MLASCALL @@ -707,6 +869,31 @@ MlasQuantizeLinear( GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearS4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearU4Kernel(Input, Output, N, Scale, ZeroPoint); +} #endif // @@ -805,6 +992,69 @@ MlasQuantizeLinear( uint16_t ZeroPoint ); +// QuantizeLinear INT4 implementation using the C++ runtime. 
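+// Each value is computed as nearbyint(x / Scale) + ZeroPoint and clamped to the signed range
+// [-8, 7] before being written into its nibble; e.g. with Scale = 2.0f and ZeroPoint = 1,
+// an input of 2.9f quantizes to 2 and an input of -20.0f clamps to -8.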
+void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = -8; + constexpr int32_t MaximumValue = 7; + + for (size_t n = 0; n < N; n++) { + + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + int8_t IntValue = static_cast(FloatValue); + + size_t i = n >> 1; // which byte + size_t j = n & 0x1; // which 4-bit elem in the byte + + if (j == 0) { + Output[i] = IntValue & 0xF; + } else { + Output[i] |= static_cast((IntValue & 0xF) << 4); + } + } +} + +// QuantizeLinear UINT4 implementation using the C++ runtime. +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = 0; + constexpr int32_t MaximumValue = 15; + + for (size_t n = 0; n < N; n++) { + + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + uint8_t IntValue = static_cast(FloatValue); + + size_t i = n >> 1; // which byte + size_t j = n & 0x1; // which 4-bit elem in the byte + + if (j == 0) { + Output[i] = IntValue & 0xF; + } else { + Output[i] |= static_cast((IntValue & 0xF) << 4); + } + } +} #endif #endif diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index cccded643d279..425befc05b805 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -507,48 +507,71 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const } } -#define DEFINE_COMPUTE_LOOP_INT4(INT4_TYPE, FLOAT_TYPE) \ - template <> \ - void ComputeLoop(OpKernelContext * ctx, const FLOAT_TYPE* input, const FLOAT_TYPE* scale, \ - const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t N, \ - int64_t broadcast_dim, int64_t block_size, bool saturate) { \ - ORT_UNUSED_PARAMETER(saturate); \ - \ - /* Quantize as 8bit and then copy to output as packed int4s. */ \ - /* TODO(adrianlizarraga): Can be done in - place without copying if block_size is even.*/ \ - size_t total_size = static_cast(N * broadcast_dim * block_size); \ - auto tmp_buf = std::make_unique(total_size); \ - size_t tmp_buf_index = 0; \ - \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i][bd_j] : 0; \ - ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ - static_cast(block_size), scale[bd], \ - zp, ctx->GetOperatorThreadPool()); \ - input += block_size; \ - tmp_buf_index += static_cast(block_size); \ - } \ - } \ - \ - for (size_t i = 0; i < total_size; i++) { \ - tmp_buf[i] = std::min(INT4_TYPE::max_val, \ - std::max(INT4_TYPE::min_val, \ - tmp_buf[i])); \ - } \ - \ - size_t num_int4_pairs = (total_size + 1) / 2; \ - auto dst = gsl::make_span(output, num_int4_pairs); \ - auto src = gsl::make_span(tmp_buf.get(), total_size); \ - INT4_TYPE::Pack(dst, src); \ +// Quantizes float32 to INT4 (in-place) using MLAS kernel. 
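+// The generated specializations walk the N x broadcast_dim blocks, pick up each block's scale and
+// zero point, and hand block_size elements at a time to the MLAS-backed ParQuantizeLinearStdS4/U4
+// helpers; output_index counts int4 elements so each call knows which nibble to start at.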
+#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \ + template <> \ + void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \ + INT4_TYPE* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + size_t output_index = 0; \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i][bd_j] : 0; \ + QUANT_FUNC(input, output, output_index, output_index + static_cast(block_size), \ + scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ + input += block_size; \ + output_index += block_size; \ + } \ + } \ + assert(output_index == static_cast(N * broadcast_dim * block_size)); \ } -DEFINE_COMPUTE_LOOP_INT4(Int4x2, float) -DEFINE_COMPUTE_LOOP_INT4(Int4x2, MLFloat16) -DEFINE_COMPUTE_LOOP_INT4(UInt4x2, float) -DEFINE_COMPUTE_LOOP_INT4(UInt4x2, MLFloat16) +DEFINE_COMPUTE_LOOP_FP32_TO_INT4(Int4x2, ParQuantizeLinearStdS4) +DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4) + +// Defines functions to quantize MLFloat16 to INT4. +// This is not an efficient implementation: we allocate a buffer, quantize to INT8, and then copy/clamp/pack +// into output INT4 buffer. +#define DEFINE_COMPUTE_LOOP_FP16_TO_INT4(INT4_TYPE) \ + template <> \ + void ComputeLoop(OpKernelContext * ctx, const MLFloat16* input, const MLFloat16* scale, \ + const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t N, \ + int64_t broadcast_dim, int64_t block_size, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + \ + size_t total_size = static_cast(N * broadcast_dim * block_size); \ + auto tmp_buf = std::make_unique(total_size); \ + size_t tmp_buf_index = 0; \ + \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + INT4_TYPE::unpacked_type zp = zero_point ? 
zero_point[bd_i][bd_j] : 0; \ + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ + static_cast(block_size), scale[bd], \ + zp, ctx->GetOperatorThreadPool()); \ + input += block_size; \ + tmp_buf_index += static_cast(block_size); \ + } \ + } \ + \ + for (size_t i = 0; i < total_size; i++) { \ + tmp_buf[i] = std::min(INT4_TYPE::max_val, \ + std::max(INT4_TYPE::min_val, \ + tmp_buf[i])); \ + } \ + \ + size_t num_int4_pairs = (total_size + 1) / 2; \ + auto dst = gsl::make_span(output, num_int4_pairs); \ + auto src = gsl::make_span(tmp_buf.get(), total_size); \ + INT4_TYPE::Pack(dst, src); \ + } + +DEFINE_COMPUTE_LOOP_FP16_TO_INT4(Int4x2) +DEFINE_COMPUTE_LOOP_FP16_TO_INT4(UInt4x2) // formula is Y = X / Scale + ZeroPoint template diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 173ab632d59cf..d5ba90b6fc90a 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -8,6 +8,7 @@ #include "core/common/narrow.h" #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" +#include "core/framework/int4.h" #include namespace onnxruntime { @@ -129,9 +130,88 @@ ParQuantizeLinearStd(const float* Input, auto begin_idx = begin * block_size; auto end_idx = std::min(static_cast(N), end * block_size); MlasQuantizeLinear(&(Input[begin_idx]), &(Output[begin_idx]), end_idx - begin_idx, Scale, ZeroPoint); + N -= (end_idx - begin_idx); }); } +#define DEFINE_PAR_QUANT_LINEAR_STD_4BIT(FUNC_NAME, INT4_TYPE, MLAS_FUNC) \ + inline void FUNC_NAME(const float* Input, \ + INT4_TYPE* Output, \ + size_t out_start, \ + size_t out_end, \ + float Scale, \ + INT4_TYPE ZeroPoint, \ + concurrency::ThreadPool* thread_pool) { \ + size_t inp_start = 0; \ + size_t inp_end = out_end - out_start; \ + \ + /* If starting at an int4 element in the middle of a byte, quantize it by itself. */ \ + if (out_start & 0x1) { \ + int32_t ival = static_cast(std::nearbyintf(Input[inp_start] / Scale)) + \ + static_cast(ZeroPoint.val_0); \ + size_t output_index = out_start >> 1; \ + \ + Output[output_index].val_1 = static_cast( \ + std::min(static_cast(INT4_TYPE::max_val), \ + std::max(static_cast(INT4_TYPE::min_val), ival))); \ + \ + out_start += 1; \ + inp_start += 1; \ + } \ + \ + /* If ending at element that ends in the middle of a byte, quantize it by itself. */ \ + if (out_end & 0x1) { \ + int32_t ival = static_cast(std::nearbyintf(Input[inp_end - 1] / Scale)) + \ + static_cast(ZeroPoint.val_0); \ + size_t output_index = (out_end - 1) >> 1; \ + \ + Output[output_index].val_0 = static_cast( \ + std::min(static_cast(INT4_TYPE::max_val), \ + std::max(static_cast(INT4_TYPE::min_val), ival))); \ + \ + out_end -= 1; \ + inp_end -= 1; \ + } \ + \ + if (out_start == out_end) { \ + return; \ + } \ + \ + /* At this point, should only need to quantize an *even* number of int4 elements that start and end at */ \ + /* a byte boundary. This is necessary to ensure that no two threads write to different int4 elements that */ \ + /* are stored in the same byte. */ \ + size_t N = out_end - out_start; \ + assert(N % 2 == 0); /* Should be guaranteed by previous code that quantizes boundary elements. 
*/ \ + \ + constexpr std::ptrdiff_t block_size = 128; \ + static_assert(block_size % 2 == 0, \ + "Block size must also be even to ensure no two threads write to the same byte."); \ + \ + const std::ptrdiff_t num_blocks = (N + block_size - 1) / block_size; \ + const TensorOpCost unit_cost{static_cast(block_size * sizeof(float)), \ + static_cast(block_size * sizeof(INT4_TYPE::unpacked_type)) / 2.0, \ + static_cast(block_size) * 2.0}; \ + concurrency::ThreadPool::TryParallelFor( \ + thread_pool, num_blocks, unit_cost, \ + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { \ + auto begin_idx = begin * block_size; \ + auto end_idx = std::min(static_cast(N), end * block_size); \ + auto inp_idx = begin_idx + static_cast(inp_start); \ + auto out_idx = begin_idx + static_cast(out_start); \ + \ + MLAS_FUNC(&(Input[inp_idx]), \ + reinterpret_cast(&(Output[out_idx >> 1])), \ + end_idx - begin_idx, \ + Scale, \ + ZeroPoint.val_0); \ + \ + N -= (end_idx - begin_idx); \ + }); \ + } + +DEFINE_PAR_QUANT_LINEAR_STD_4BIT(ParQuantizeLinearStdS4, Int4x2, MlasQuantizeLinearS4) +DEFINE_PAR_QUANT_LINEAR_STD_4BIT(ParQuantizeLinearStdU4, UInt4x2, MlasQuantizeLinearU4) + // This implementation could be more efficient however the cast from float16 to other types // usually happens on GPU. template diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index 986d158d2b1b9..d261207a0d532 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -71,6 +71,116 @@ class MlasQuantizeLinearTest : public MlasTestBase { } }; +template +class MlasQuantizeLinear4BitTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + UnpackedType MinVal() const { + if constexpr (std::is_same_v) { + return -8; + } else if (std::is_same_v) { + return 0; + } + } + + UnpackedType MaxVal() const { + if constexpr (std::is_same_v) { + return 7; + } else { + static_assert(std::is_same_v); + return 15; + } + } + + void GenerateReference(const float* Input, UnpackedType* OutputReference, size_t N, float Scale, + UnpackedType ZeroPoint) { + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinVal())); + FloatValue = std::min(FloatValue, static_cast(MaxVal())); + + size_t i = n >> 1; + size_t j = n & 0x1; + + UnpackedType IntValue = static_cast(FloatValue); + + if (j == 0) { + OutputReference[i] = IntValue & 0xF; + } else { + OutputReference[i] |= static_cast((IntValue & 0xF) << 4); + } + } + } + + void Test(size_t N) { + size_t OutBufLen = (N + 1) / 2; + float* Input = BufferInput.GetBuffer(N); + UnpackedType* Output = BufferOutput.GetBuffer(OutBufLen); + UnpackedType* OutputReference = BufferOutputReference.GetBuffer(OutBufLen); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 32.f; + + std::uniform_int_distribution zp_distribution(MinVal(), MaxVal()); + UnpackedType ZeroPoint = static_cast(zp_distribution(generator)); + + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + for (size_t n = 0; n < N; n++) { + Input[n] = 
distribution(generator); + } + + GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); + + if constexpr (std::is_same_v) { + MlasQuantizeLinearS4(Input, Output, N, Scale, ZeroPoint); + } else { + static_assert(std::is_same_v); + MlasQuantizeLinearU4(Input, Output, N, Scale, ZeroPoint); + } + + for (size_t n = 0; n < N; n++) { + size_t i = n >> 1; + size_t j = n & 0x1; + + if (j == 0) { + ASSERT_EQ(Output[i] & 0xF, OutputReference[i] & 0xF) << ", size=" << N + << ", index=" << n + << ", nibble=" << j; + } else { + ASSERT_EQ((Output[i] >> 4) & 0xF, (OutputReference[i] >> 4) & 0xF) << ", size=" << N + << ", index=" << n + << ", nibble=" << j; + } + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (std::is_same_v) { + return "QuantizeLinearS4"; + } else { + static_assert(std::is_same_v); + return "QuantizeLinearU4"; + } + } + + void ExecuteShort(void) override { + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { @@ -78,6 +188,8 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe count += MlasDirectShortExecuteTests>::RegisterShortExecute(); count += MlasDirectShortExecuteTests>::RegisterShortExecute(); count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; }); diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 981388c21dfeb..983153d4b1e24 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -176,7 +176,7 @@ struct TensorCheck { cur_actual = actual.Data(); for (size_t i = 0; i < static_cast(size); ++i) { - size_t r = i >> 2; + size_t r = i >> 1; size_t c = i & 0x1; EXPECT_EQ(cur_expected[r][c], cur_actual[r][c]) << "i:" << i; } @@ -196,7 +196,7 @@ struct TensorCheck { cur_actual = actual.Data(); for (size_t i = 0; i < static_cast(size); ++i) { - size_t r = i >> 2; + size_t r = i >> 1; size_t c = i & 0x1; EXPECT_EQ(cur_expected[r][c], cur_actual[r][c]) << "i:" << i; } diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index dad8ae6f47180..4ac79a2d02bdb 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -388,7 +388,7 @@ TEST(QuantizeLinearOpTest, Int4) { -16.0f, // Close to qmin -3.0f, // round 0.0f, // Zero-point - 3.0f, // round + 2.9f, // round 12.0f, // qmax 20.0f, // Clamp to qmax }); @@ -411,14 +411,14 @@ TEST(QuantizeLinearOpTest, UInt4) { -8.0f, // qmin -3.0f, // round 0.0f, // Zero-point - 3.0f, // round + 2.9f, // round 22.0f, // qmax - 20.0f, // Clamp to qmax + 30.0f, // Clamp to qmax }); test.AddInput("scale", {}, {2.0f}, true); test.AddInput("zero_point", {}, {UInt4x2(4, unused_val)}, true); test.AddOutput("y", dims, - {UInt4x2(0, 0), UInt4x2(2, 4), UInt4x2(6, 15), + {UInt4x2(0, 0), UInt4x2(2, 4), UInt4x2(5, 15), UInt4x2(15, unused_val)}); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); From 2935f79702edfdaa22a9af8813716ff4d5df6c36 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 21 Apr 2024 11:02:20 -0700 Subject: [PATCH 34/72] branchless update of 4-bit element --- onnxruntime/core/mlas/lib/quantize.cpp | 26 
+++++++++---------- onnxruntime/core/util/qmath.h | 3 +-- .../mlas/unittest/test_quantizelinear.cpp | 13 +++++----- 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index 5d3a551f4addf..12599b8b582e9 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -571,14 +571,13 @@ MlasQuantizeLinearS4Kernel( MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); - size_t output_index = n >> 1; // which byte - size_t nibble_index = n & 0x1; // which 4-bit elem in the byte + size_t OutputIndex = n >> 1; // which byte + size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte + uint8_t Shift = 4 * static_cast(NibbleIndex); + int8_t Mask = 0xF << Shift; - if (nibble_index == 0) { - Output[output_index] = static_cast(TmpOutput[0] & 0xF); - } else { - Output[output_index] |= static_cast((TmpOutput[0] & 0xF) << 4); - } + Output[OutputIndex] &= ~Mask; // Clear 4-bit lane + Output[OutputIndex] |= static_cast((TmpOutput[0] & 0xF) << Shift); // Set 4-bit lane } } @@ -634,14 +633,13 @@ MlasQuantizeLinearU4Kernel( MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); - size_t output_index = n >> 1; // which byte - size_t nibble_index = n & 0x1; // which 4-bit elem in the byte + size_t OutputIndex = n >> 1; // which byte + size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte + uint8_t Shift = 4 * static_cast(NibbleIndex); + uint8_t Mask = 0xF << Shift; - if (nibble_index == 0) { - Output[output_index] = static_cast(TmpOutput[0] & 0xF); - } else { - Output[output_index] |= static_cast((TmpOutput[0] & 0xF) << 4); - } + Output[OutputIndex] &= ~Mask; // Clear 4-bit lane + Output[OutputIndex] |= static_cast((TmpOutput[0] & 0xF) << Shift); // Set 4-bit lane } } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index d5ba90b6fc90a..588974e6eba43 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -130,7 +130,6 @@ ParQuantizeLinearStd(const float* Input, auto begin_idx = begin * block_size; auto end_idx = std::min(static_cast(N), end * block_size); MlasQuantizeLinear(&(Input[begin_idx]), &(Output[begin_idx]), end_idx - begin_idx, Scale, ZeroPoint); - N -= (end_idx - begin_idx); }); } @@ -205,7 +204,7 @@ ParQuantizeLinearStd(const float* Input, Scale, \ ZeroPoint.val_0); \ \ - N -= (end_idx - begin_idx); \ + N -= static_cast(end_idx - begin_idx); \ }); \ } diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index d261207a0d532..876a000014bc9 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -102,16 +102,15 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { FloatValue = std::max(FloatValue, static_cast(MinVal())); FloatValue = std::min(FloatValue, static_cast(MaxVal())); + UnpackedType IntValue = static_cast(FloatValue); + size_t i = n >> 1; size_t j = n & 0x1; + uint8_t Shift = 4 * static_cast(j); + UnpackedType Mask = 0xF << Shift; - UnpackedType IntValue = static_cast(FloatValue); - - if (j == 0) { - OutputReference[i] = IntValue & 0xF; - } else { - OutputReference[i] |= static_cast((IntValue & 0xF) << 4); - } + OutputReference[i] &= ~Mask; // Clear 4-bit lane + OutputReference[i] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane } } From 807537cd8b4c8806b8c4ddfaa2192219c85d3a76 Mon Sep 17 00:00:00 2001 From: 
adrianlizarraga Date: Sun, 21 Apr 2024 11:05:41 -0700 Subject: [PATCH 35/72] more branchless update of int4 lane --- onnxruntime/core/mlas/lib/quantize.cpp | 28 +++++++++++--------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index 12599b8b582e9..702c58bf57627 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -1005,20 +1005,18 @@ MlasQuantizeLinearS4( constexpr int32_t MaximumValue = 7; for (size_t n = 0; n < N; n++) { - float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); FloatValue = std::max(FloatValue, static_cast(MinimumValue)); FloatValue = std::min(FloatValue, static_cast(MaximumValue)); int8_t IntValue = static_cast(FloatValue); - size_t i = n >> 1; // which byte - size_t j = n & 0x1; // which 4-bit elem in the byte + size_t OutputIndex = n >> 1; // which byte + size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte + uint8_t Shift = 4 * static_cast(NibbleIndex); + int8_t Mask = 0xF << Shift; - if (j == 0) { - Output[i] = IntValue & 0xF; - } else { - Output[i] |= static_cast((IntValue & 0xF) << 4); - } + Output[OutputIndex] &= ~Mask; // Clear 4-bit lane + Output[OutputIndex] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane } } @@ -1037,20 +1035,18 @@ MlasQuantizeLinearU4( constexpr int32_t MaximumValue = 15; for (size_t n = 0; n < N; n++) { - float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); FloatValue = std::max(FloatValue, static_cast(MinimumValue)); FloatValue = std::min(FloatValue, static_cast(MaximumValue)); uint8_t IntValue = static_cast(FloatValue); - size_t i = n >> 1; // which byte - size_t j = n & 0x1; // which 4-bit elem in the byte + size_t OutputIndex = n >> 1; // which byte + size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte + uint8_t Shift = 4 * static_cast(NibbleIndex); + uint8_t Mask = 0xF << Shift; - if (j == 0) { - Output[i] = IntValue & 0xF; - } else { - Output[i] |= static_cast((IntValue & 0xF) << 4); - } + Output[OutputIndex] &= ~Mask; // Clear 4-bit lane + Output[OutputIndex] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane } } #endif From a36a128d35526e5ac5ace244fedd729c5590bf75 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 21 Apr 2024 12:24:55 -0700 Subject: [PATCH 36/72] Fix cast warning as error --- onnxruntime/core/providers/cpu/quantization/quantize_linear.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 425befc05b805..a14034ef511fe 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -522,7 +522,7 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const QUANT_FUNC(input, output, output_index, output_index + static_cast(block_size), \ scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ input += block_size; \ - output_index += block_size; \ + output_index += static_cast(block_size); \ } \ } \ assert(output_index == static_cast(N * broadcast_dim * block_size)); \ From bc44557104666646a4b24b04a8ba6b1144fc1943 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Apr 2024 00:40:58 -0700 Subject: [PATCH 37/72] Remove decrement of N --- onnxruntime/core/util/qmath.h | 2 -- 1 file changed, 2 deletions(-) diff 
--git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 588974e6eba43..061fb2d057037 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -203,8 +203,6 @@ ParQuantizeLinearStd(const float* Input, end_idx - begin_idx, \ Scale, \ ZeroPoint.val_0); \ - \ - N -= static_cast(end_idx - begin_idx); \ }); \ } From f40992dc51fa11d911b1aabf65a31f40d26c9e27 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Apr 2024 04:14:34 -0700 Subject: [PATCH 38/72] Clean up Int4x2 class --- include/onnxruntime/core/framework/int4.h | 131 ++++++++---------- .../cpu/quantization/quantize_linear.cc | 48 ++++--- onnxruntime/core/util/qmath.h | 12 +- .../mlas/unittest/test_quantizelinear.cpp | 4 +- onnxruntime/test/providers/checkers.cc | 4 +- .../cpu/tensor/quantize_linear_test.cc | 68 +++++++++ 6 files changed, 165 insertions(+), 102 deletions(-) diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index 0268f57c0457b..c303e370a9744 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -4,96 +4,84 @@ #pragma once #include +#include #include "endian.h" #include "core/common/common.h" #include "core/common/gsl.h" namespace onnxruntime { -struct Int4x2 { - using unpacked_type = int8_t; - static constexpr unpacked_type min_val = -8; - static constexpr unpacked_type max_val = 7; - - int8_t val_0 : 4; - int8_t val_1 : 4; - - Int4x2() : val_0{0}, val_1{0} {} - Int4x2(uint8_t bits) { - val_0 = static_cast(bits & 0xF); - val_1 = static_cast((bits >> 4) & 0xF); - } - Int4x2(int8_t lo, int8_t hi) : val_0{lo}, val_1{hi} {} - inline int8_t operator[](size_t index) const { - assert(index <= 1); - return index == 0 ? val_0 : val_1; - } - - inline uint8_t ToBits() const { - return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); - } +template +struct UnpackedTypeTraits; - static bool Unpack(gsl::span dst, gsl::span src) { - if (((dst.size() + 1) / 2) != src.size()) { - return false; - } +template <> +struct UnpackedTypeTraits { + static constexpr int8_t min_val = -8; + static constexpr int8_t max_val = 7; +}; - for (size_t i = 0; i < dst.size(); i++) { - size_t r = i >> 1; // i / 2; - size_t c = i & 0x1; // i % 2; - dst[i] = src[r][c]; - } +template <> +struct UnpackedTypeTraits { + static constexpr uint8_t min_val = 0; + static constexpr uint8_t max_val = 15; +}; - return true; - } +template +struct Int4x2Base { + using unpacked_type = T; + static constexpr unpacked_type min_val = UnpackedTypeTraits::min_val; + static constexpr unpacked_type max_val = UnpackedTypeTraits::max_val; - static bool Pack(gsl::span dst, gsl::span src) { - if (((src.size() + 1) / 2) != dst.size()) { - return false; - } + unpacked_type elems{}; - size_t src_i = 0; - size_t dst_i = 0; - - for (; src_i < src.size() - 1; src_i += 2) { - dst[dst_i++] = Int4x2(src[src_i], src[src_i + 1]); - } + Int4x2Base() = default; + Int4x2Base(uint8_t bits) { + elems = static_cast(bits); + } + Int4x2Base(unpacked_type val0, unpacked_type val1) { + elems = static_cast(((val1 & 0xF) << 4) | (val0 & 0xF)); + } - if (src_i < src.size()) { - dst[dst_i] = Int4x2(src[src_i], 0); + inline unpacked_type GetElem0() const { + if constexpr (std::is_same_v) { + // Need to sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift. 
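+      // For example, a stored low nibble of 0xF comes back as int8_t(0xF0) >> 4 == -1,
+      // while a stored low nibble of 0x7 comes back as 7.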
+ return static_cast(static_cast((elems << 4)) >> 4); + } else { + return static_cast(elems & 0xF); } - - return true; } -}; - -static_assert(sizeof(Int4x2) == sizeof(int8_t)); -struct UInt4x2 { - using unpacked_type = uint8_t; - static constexpr unpacked_type min_val = 0; - static constexpr unpacked_type max_val = 15; - - uint8_t val_0 : 4; - uint8_t val_1 : 4; + inline unpacked_type GetElem1() const { + return static_cast(elems >> 4); + } - UInt4x2() : val_0{0}, val_1{0} {} - UInt4x2(uint8_t bits) { - val_0 = bits & 0xF; - val_1 = (bits >> 4) & 0xF; + inline unpacked_type GetElem(size_t index) const { + assert(index <= 1); + const uint8_t shift = 4 * static_cast(index); + + if constexpr (std::is_same_v) { + // if index is 0, need to sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift. + const uint8_t unshift = 4 - shift; + return static_cast(static_cast((elems >> shift) << unshift) >> unshift); + } else { + return static_cast((elems >> shift) & 0xF); + } } - UInt4x2(uint8_t lo, uint8_t hi) : val_0{lo}, val_1{hi} {} - inline uint8_t operator[](size_t index) const { + inline void SetElem(size_t index, unpacked_type val) { assert(index <= 1); - return index == 0 ? val_0 : val_1; + const uint8_t shift = 4 * static_cast(index); + const unpacked_type mask = 0xF << shift; + + elems &= ~mask; // Clear 4-bit element to 0 + elems |= static_cast((val & 0xF) << shift); // Set 4-bit element to val } inline uint8_t ToBits() const { - return (static_cast(val_1) << 4) | (static_cast(val_0) & 0xF); + return static_cast(elems); } - static bool Unpack(gsl::span dst, gsl::span src) { + static bool Unpack(gsl::span dst, gsl::span> src) { if (((dst.size() + 1) / 2) != src.size()) { return false; } @@ -101,13 +89,13 @@ struct UInt4x2 { for (size_t i = 0; i < dst.size(); i++) { size_t r = i >> 1; // i / 2; size_t c = i & 0x1; // i % 2; - dst[i] = src[r][c]; + dst[i] = src[r].GetElem(c); } return true; } - static bool Pack(gsl::span dst, gsl::span src) { + static bool Pack(gsl::span> dst, gsl::span src) { if (((src.size() + 1) / 2) != dst.size()) { return false; } @@ -116,16 +104,19 @@ struct UInt4x2 { size_t dst_i = 0; for (; src_i < src.size() - 1; src_i += 2) { - dst[dst_i++] = UInt4x2(src[src_i], src[src_i + 1]); + dst[dst_i++] = Int4x2Base(src[src_i], src[src_i + 1]); } if (src_i < src.size()) { - dst[dst_i] = UInt4x2(src[src_i], 0); + dst[dst_i] = Int4x2Base(src[src_i], 0); } return true; } }; +using Int4x2 = Int4x2Base; +using UInt4x2 = Int4x2Base; +static_assert(sizeof(Int4x2) == sizeof(int8_t)); static_assert(sizeof(UInt4x2) == sizeof(uint8_t)); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index a14034ef511fe..f2572f6b35a51 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -238,27 +238,29 @@ struct DequantizeLinearApply { } }; -#define DEQUANTIZE_LINEAR_APPLY_INT4(T) \ - template \ - struct DequantizeLinearApply { \ - void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T* zero_point) { \ - size_t input_index = 0; \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - auto zp = zero_point ? 
static_cast(zero_point[bd_i][bd_j]) : 0; \ - auto sc = static_cast(scale[bd]); \ - for (size_t bs = 0; bs < static_cast(block_size); bs++) { \ - size_t input_i = input_index >> 1; \ - size_t input_j = input_index & 0x1; \ - *output++ = static_cast(static_cast(static_cast(input[input_i][input_j]) - zp) * sc); \ - input_index += 1; \ - } \ - } \ - } \ - assert(input_index == static_cast(N * broadcast_dim * block_size)); \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_INT4(T) \ + template \ + struct DequantizeLinearApply { \ + void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, \ + OutT* output, const T* zero_point) { \ + size_t input_index = 0; \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + auto zp = zero_point ? static_cast(zero_point[bd_i].GetElem(bd_j)) : 0; \ + auto sc = static_cast(scale[bd]); \ + for (size_t bs = 0; bs < static_cast(block_size); bs++) { \ + size_t input_i = input_index >> 1; \ + size_t input_j = input_index & 0x1; \ + int32_t val = static_cast(input[input_i].GetElem(input_j)); \ + *output++ = static_cast(static_cast(val - zp) * sc); \ + input_index += 1; \ + } \ + } \ + } \ + assert(input_index == static_cast(N * broadcast_dim * block_size)); \ + } \ }; DEQUANTIZE_LINEAR_APPLY_INT4(Int4x2); @@ -518,7 +520,7 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ size_t bd_i = bd >> 1; /*bd / 2*/ \ size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i][bd_j] : 0; \ + INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ QUANT_FUNC(input, output, output_index, output_index + static_cast(block_size), \ scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ input += block_size; \ @@ -549,7 +551,7 @@ DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4) for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ size_t bd_i = bd >> 1; /*bd / 2*/ \ size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i][bd_j] : 0; \ + INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ static_cast(block_size), scale[bd], \ zp, ctx->GetOperatorThreadPool()); \ diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 061fb2d057037..bd2e410020c5a 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -147,12 +147,13 @@ ParQuantizeLinearStd(const float* Input, /* If starting at an int4 element in the middle of a byte, quantize it by itself. */ \ if (out_start & 0x1) { \ int32_t ival = static_cast(std::nearbyintf(Input[inp_start] / Scale)) + \ - static_cast(ZeroPoint.val_0); \ + static_cast(ZeroPoint.GetElem0()); \ size_t output_index = out_start >> 1; \ \ - Output[output_index].val_1 = static_cast( \ + INT4_TYPE::unpacked_type quant_val = static_cast( \ std::min(static_cast(INT4_TYPE::max_val), \ std::max(static_cast(INT4_TYPE::min_val), ival))); \ + Output[output_index].SetElem(1, quant_val); \ \ out_start += 1; \ inp_start += 1; \ @@ -161,12 +162,13 @@ ParQuantizeLinearStd(const float* Input, /* If ending at element that ends in the middle of a byte, quantize it by itself. 
*/ \ if (out_end & 0x1) { \ int32_t ival = static_cast(std::nearbyintf(Input[inp_end - 1] / Scale)) + \ - static_cast(ZeroPoint.val_0); \ + static_cast(ZeroPoint.GetElem0()); \ size_t output_index = (out_end - 1) >> 1; \ \ - Output[output_index].val_0 = static_cast( \ + INT4_TYPE::unpacked_type quant_val = static_cast( \ std::min(static_cast(INT4_TYPE::max_val), \ std::max(static_cast(INT4_TYPE::min_val), ival))); \ + Output[output_index].SetElem(0, quant_val); \ \ out_end -= 1; \ inp_end -= 1; \ @@ -202,7 +204,7 @@ ParQuantizeLinearStd(const float* Input, reinterpret_cast(&(Output[out_idx >> 1])), \ end_idx - begin_idx, \ Scale, \ - ZeroPoint.val_0); \ + ZeroPoint.GetElem0()); \ }); \ } diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index 876a000014bc9..e11985ce82333 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -109,8 +109,8 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { uint8_t Shift = 4 * static_cast(j); UnpackedType Mask = 0xF << Shift; - OutputReference[i] &= ~Mask; // Clear 4-bit lane - OutputReference[i] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane + OutputReference[i] &= ~Mask; // Clear 4-bit lane + OutputReference[i] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane } } diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 983153d4b1e24..d0e08448ce456 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -178,7 +178,7 @@ struct TensorCheck { for (size_t i = 0; i < static_cast(size); ++i) { size_t r = i >> 1; size_t c = i & 0x1; - EXPECT_EQ(cur_expected[r][c], cur_actual[r][c]) << "i:" << i; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; } } }; @@ -198,7 +198,7 @@ struct TensorCheck { for (size_t i = 0; i < static_cast(size); ++i) { size_t r = i >> 1; size_t c = i & 0x1; - EXPECT_EQ(cur_expected[r][c], cur_actual[r][c]) << "i:" << i; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; } } }; diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 4ac79a2d02bdb..98c8f866b5475 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -424,6 +424,74 @@ TEST(QuantizeLinearOpTest, UInt4) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +template +static void GetExpectedInt4Quant(const float* input, Int4x2Base* output, size_t num_elems, float scale, + T zero_point) { + for (size_t n = 0; n < num_elems; n++) { + float float_val = std::nearbyintf(input[n] / scale) + float(zero_point); + float_val = std::max(float_val, static_cast(Int4x2Base::min_val)); + float_val = std::min(float_val, static_cast(Int4x2Base::max_val)); + + T int_val = static_cast(float_val); + + size_t i = n >> 1; + size_t j = n & 0x1; + output[i].SetElem(j, int_val); + } +} + +// Test int4 QuantizeLinear (per tensor) with a "large" and odd number of input elements. +// This exercises the TryParallelFor call which splits the input into blocks of even size. 
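+// Expected outputs come from GetExpectedInt4Quant above, i.e.
+// q = clamp(round(x / scale) + zero_point, min_val, max_val) applied to the unpacked 4-bit elements.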
+TEST(QuantizeLinearOpTest, OddLarge_Int4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{1017}; + constexpr int8_t unused_val = 0; + constexpr std::array pattern = {-20.0f, -14.0f, -4.1f, -0.0f, 3.0f, 3.3f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output((input_f32s.size() + 1) / 2); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + int8_t zp = 1; + GetExpectedInt4Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {Int4x2(zp, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint4 QuantizeLinear (per tensor) with a "large" and odd number of input elements. +// This exercises the TryParallelFor call which splits the input into blocks of even size. +TEST(QuantizeLinearOpTest, OddLarge_UInt4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{1017}; + constexpr uint8_t unused_val = 0; + constexpr std::array pattern = {-20.0f, -14.0f, -4.1f, -0.0f, 3.0f, 3.3f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output((input_f32s.size() + 1) / 2); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + uint8_t zp = 1; + GetExpectedInt4Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {UInt4x2(zp, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // quantize with scalar zero point and scale TEST(QuantizeLinearOpTest, Int8_NegativeZeroPoint) { // TODO: Unskip when fixed #41968513 From d0e17e212c0c778ee94d3aaa0613e96c424f8020 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Apr 2024 14:09:47 -0700 Subject: [PATCH 39/72] Github linter fixes --- include/onnxruntime/core/framework/int4.h | 3 +- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 39 ++++++++++++------- onnxruntime/core/mlas/lib/quantize.cpp | 8 ++-- .../selectors_actions/qdq_selectors.h | 1 + .../core/providers/cpu/tensor/transpose.cc | 1 + .../providers/shared_library/provider_api.h | 8 +++- .../provider_bridge_provider.cc | 8 +++- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- .../core/session/provider_bridge_ort.cc | 8 +++- .../mlas/unittest/test_quantizelinear.cpp | 19 ++++----- .../cpu/cpu_execution_provider_test.cc | 2 +- .../cpu/tensor/quantize_linear_test.cc | 2 +- 12 files changed, 62 insertions(+), 39 deletions(-) diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index c303e370a9744..7fc7d9836d312 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -5,7 +5,6 @@ #include #include -#include "endian.h" #include "core/common/common.h" #include "core/common/gsl.h" @@ -35,7 +34,7 @@ struct Int4x2Base { unpacked_type elems{}; Int4x2Base() = default; - Int4x2Base(uint8_t bits) { + explicit Int4x2Base(uint8_t bits) { elems = static_cast(bits); } Int4x2Base(unpacked_type val0, unpacked_type val1) { diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 
b6d86cfafb035..2802e93d5867f 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -203,19 +203,32 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index 702c58bf57627..12df31b42a7e4 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -533,8 +533,8 @@ MlasQuantizeLinearS4Kernel( constexpr int32_t MaximumValue = 7; auto ScaleVector = MlasBroadcastFloat32x4(Scale); - auto MinimumValueVector = MlasBroadcastFloat32x4(float(MinimumValue - ZeroPoint)); - auto MaximumValueVector = MlasBroadcastFloat32x4(float(MaximumValue - ZeroPoint)); + auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); + auto MaximumValueVector = MlasBroadcastFloat32x4(static_cast(MaximumValue - ZeroPoint)); auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. @@ -595,8 +595,8 @@ MlasQuantizeLinearU4Kernel( constexpr int32_t MaximumValue = 15; auto ScaleVector = MlasBroadcastFloat32x4(Scale); - auto MinimumValueVector = MlasBroadcastFloat32x4(float(MinimumValue - ZeroPoint)); - auto MaximumValueVector = MlasBroadcastFloat32x4(float(MaximumValue - ZeroPoint)); + auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); + auto MaximumValueVector = MlasBroadcastFloat32x4(static_cast(MaximumValue - ZeroPoint)); auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. 
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 4446f6f4e6b63..5a40f0fbde595 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -5,6 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#include #include "core/framework/node_unit.h" #include "core/optimizer/selectors_actions/selector_action_transformer.h" diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 3aec1605527be..849adb73e784d 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -3,6 +3,7 @@ #include "core/providers/cpu/tensor/transpose.h" +#include #include "core/framework/element_type_lists.h" #include "core/framework/utils.h" #include "core/framework/transpose_helper.h" diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 5cc5c2302df6d..8b4b1c785ee3b 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -350,9 +350,13 @@ template <> constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; } #endif template <> -constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; } +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; +} template <> -constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; } +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; +} } // namespace utils // This is a replacement for Ort::InitApi() to be called before any other onnxruntime API calls. 
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 112dc9abb5f97..a6ed7098a26cc 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -257,9 +257,13 @@ template <> MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ(); } #endif template <> -MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_Int4x2(); } +MLDataType DataTypeImpl::GetSparseTensorType() { + return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_Int4x2(); +} template <> -MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_UInt4x2(); } +MLDataType DataTypeImpl::GetSparseTensorType() { + return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_UInt4x2(); +} #endif diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4de53a19a7c32..2336b4a5048ab 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -217,7 +217,7 @@ ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t auto elem_count = narrow(tensor_shape.Size()); - // TODO: Handle this more cleanly. + // TODO(adrianlizarraga): Handle this more cleanly. if (utils::IsPrimitiveDataType(ml_type) || utils::IsPrimitiveDataType(ml_type)) { elem_count = (elem_count + 1) / 2; } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ec96edaa2456e..349d6161c5029 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -844,8 +844,12 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ() override { return DataTypeImpl::GetSparseTensorType(); } #endif - MLDataType DataTypeImpl__GetSparseTensorType_Int4x2() override { return DataTypeImpl::GetSparseTensorType(); } - MLDataType DataTypeImpl__GetSparseTensorType_UInt4x2() override { return DataTypeImpl::GetSparseTensorType(); } + MLDataType DataTypeImpl__GetSparseTensorType_Int4x2() override { + return DataTypeImpl::GetSparseTensorType(); + } + MLDataType DataTypeImpl__GetSparseTensorType_UInt4x2() override { + return DataTypeImpl::GetSparseTensorType(); + } #endif const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); } diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index e11985ce82333..d771ae4fb4c19 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -98,7 +98,7 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { void GenerateReference(const float* Input, UnpackedType* OutputReference, size_t N, float Scale, UnpackedType ZeroPoint) { for (size_t n = 0; n < N; n++) { - float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); FloatValue = std::max(FloatValue, static_cast(MinVal())); FloatValue = 
std::min(FloatValue, static_cast(MaxVal())); @@ -150,16 +150,13 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { for (size_t n = 0; n < N; n++) { size_t i = n >> 1; size_t j = n & 0x1; - - if (j == 0) { - ASSERT_EQ(Output[i] & 0xF, OutputReference[i] & 0xF) << ", size=" << N - << ", index=" << n - << ", nibble=" << j; - } else { - ASSERT_EQ((Output[i] >> 4) & 0xF, (OutputReference[i] >> 4) & 0xF) << ", size=" << N - << ", index=" << n - << ", nibble=" << j; - } + const uint8_t Shift = 4 * static_cast(j); + const uint8_t Unshift = 4 - Shift; + UnpackedType actual_val = static_cast((Output[i] >> Shift) << Unshift) >> Unshift; + UnpackedType expected_val = static_cast((OutputReference[i] >> Shift) << Unshift) >> Unshift; + ASSERT_EQ(actual_val, expected_val) << ", size=" << N + << ", index=" << n + << ", nibble=" << j; } } diff --git a/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc b/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc index 325f305dad537..eaadbb6bd9455 100644 --- a/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc +++ b/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc @@ -22,7 +22,7 @@ TEST(CPUExecutionProviderTest, MetadataTest) { ASSERT_EQ(provider->GetOrtDeviceByMemType(OrtMemTypeDefault).Type(), OrtDevice::CPU); } -// TODO: Remove. This is a throwaway test for Int4 +// TODO(adrianlizarraga): Remove. This is a throwaway test for Int4 TEST(CPUExecutionProviderTest, Example_Conv_Int4) { Ort::SessionOptions so; diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 98c8f866b5475..b4e30f7aa22f4 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -428,7 +428,7 @@ template static void GetExpectedInt4Quant(const float* input, Int4x2Base* output, size_t num_elems, float scale, T zero_point) { for (size_t n = 0; n < num_elems; n++) { - float float_val = std::nearbyintf(input[n] / scale) + float(zero_point); + float float_val = std::nearbyintf(input[n] / scale) + static_cast(zero_point); float_val = std::max(float_val, static_cast(Int4x2Base::min_val)); float_val = std::min(float_val, static_cast(Int4x2Base::max_val)); From 6568d48a5bf67537f0071652cfa0c43b98e46d0c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Apr 2024 14:13:30 -0700 Subject: [PATCH 40/72] Remove temporary unittest --- .../cpu/cpu_execution_provider_test.cc | 52 ------------------ .../test/testdata/conv.int4.int8.qdq.onnx | Bin 1783 -> 0 bytes 2 files changed, 52 deletions(-) delete mode 100644 onnxruntime/test/testdata/conv.int4.int8.qdq.onnx diff --git a/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc b/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc index eaadbb6bd9455..8b9dcbd943b4a 100644 --- a/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc +++ b/onnxruntime/test/providers/cpu/cpu_execution_provider_test.cc @@ -2,16 +2,7 @@ // Licensed under the MIT License. 
#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/session/onnxruntime_cxx_api.h" -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/onnxruntime_run_options_config_keys.h" #include "gtest/gtest.h" -#include "gmock/gmock.h" - -#define ORT_MODEL_FOLDER ORT_TSTR("testdata/") - -// in test_main.cc -extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { @@ -21,48 +12,5 @@ TEST(CPUExecutionProviderTest, MetadataTest) { EXPECT_TRUE(provider != nullptr); ASSERT_EQ(provider->GetOrtDeviceByMemType(OrtMemTypeDefault).Type(), OrtDevice::CPU); } - -// TODO(adrianlizarraga): Remove. This is a throwaway test for Int4 -TEST(CPUExecutionProviderTest, Example_Conv_Int4) { - Ort::SessionOptions so; - - // Ensure all type/shape inference warnings result in errors! - so.AddConfigEntry(kOrtSessionOptionsConfigStrictShapeTypeInference, "1"); - so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv.int4.int8.qdq.onnx"; - Ort::Session session(*ort_env, ort_model_path, so); - - std::array input0_data = {}; - for (size_t i = 0; i < input0_data.size(); i++) { - input0_data[i] = 0.2f; - } - - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector ort_inputs; - std::vector ort_input_names; - - // Add input0 - std::array inputs_shape{1, 3, 8, 8}; - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input0_data.data(), input0_data.size(), inputs_shape.data(), inputs_shape.size())); - ort_input_names.push_back("input_0"); - - // Run session and get outputs - std::array output_names{"output_0"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check output shape. 
- Ort::Value& ort_output = ort_outputs[0]; - auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); - std::vector output_shape = typeshape.GetShape(); - - EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 5, 6, 6)); - const float* results = ort_output.GetTensorData(); - - for (size_t i = 0; i < typeshape.GetElementCount(); i++) { - std::cout << i << ": " << results[i] << std::endl; - } -} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/conv.int4.int8.qdq.onnx b/onnxruntime/test/testdata/conv.int4.int8.qdq.onnx deleted file mode 100644 index a5f83ac76489a42aa5b0fced2100297f49daa8ce..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1783 zcmdr6v|B380H;adWXGR%8}SFf3qX;wt81 z&&(?*Er~ba;)ODx77D?*NH)pB!~&rnfmjmnUs?h*LrM(0v=Sexa;|PhE`jpY%=C;B zByaOTMIhc4gz}IqmxBtSdkAK`I1YJ2LD0|0h2LMqgp54Q>$qGeflC3V^MH{;$T%(n zfx!h%6GF22;E0Jgh)0cqcu!EAOR+lV=am^Kv4AM9USKfbRnEl+mxjcVFr1GZOK?Fn zEBzs9L>z}aMk1-u5;>^)NLsI$;YSEeVjWg$owgI-!N9mIenaP)_9R86UgO4h}nH z+6g6hXvqLoifJ)3Ffce_W)!VvW)4;kW)K7_dcw5hKNK+jgR=yJ+TrZia2CH}DR#xv z;Yv=yS^WNGE8(0!aMlMnD+Pxx>{bkE_pm!*lxJF6z&K)aLMz<8M2yb;|8H~T|9@cI zF&wdGU^r*Uz@TLSig_oj1;3X_#w=f}*@B!_$h4Lz*mT^Lf|m%$0(20i2rur1#JEH_ z7=;A5m^c`Lm>GyUKsZT>3!Y`r)UX2O*+AHdg^NKzv`C(d3tX#Z=B1?;2?>JOMWw*{ KEjJZX(E$J;l>Np4 From b8d5869f57da0095126c256704f6be58a5d5422c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Apr 2024 14:34:11 -0700 Subject: [PATCH 41/72] Case statement missing : --- include/onnxruntime/core/framework/data_types_internal.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index d30bcbe77beb0..05f4c10995ef2 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -218,10 +218,10 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ function(__VA_ARGS__); \ break; \ - case ONNX_NAMESPACE::TensorProto_DataType_UINT4 \ - function(__VA_ARGS__); \ - break; \ - default: \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + function(__VA_ARGS__); \ + break; \ + default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } From 934a0634e5d564b2fb561b5678585949d7ebc111 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 2 May 2024 00:52:02 -0700 Subject: [PATCH 42/72] Add powerpc int4 quant kernel --- onnxruntime/core/mlas/lib/mlasi.h | 36 +++++ .../core/mlas/lib/power/QuantizePower.cpp | 140 ++++++++++++++++ onnxruntime/core/mlas/lib/quantize.cpp | 151 ++++++------------ 3 files changed, 229 insertions(+), 98 deletions(-) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 8d58b658eaa5c..2ae2fd26f9e73 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -2521,3 +2521,39 @@ MlasThreadedBufAlloc(size_t size) ThreadedBufSize = size; } } + +// +// Utilities for INT4 quantization. 
+// + +template +struct Int4Range; + +template <> +struct Int4Range { + static constexpr int8_t Min = -8; + static constexpr int8_t Max = 7; +}; + +template <> +struct Int4Range { + static constexpr uint8_t Min = 0; + static constexpr uint8_t Max = 15; +}; + +template +MLAS_FORCEINLINE +void +MlasSetInt4Element(UnpackedType* Output, size_t Index, UnpackedType Value) +{ + static_assert(std::is_same_v || std::is_same_v); + + const size_t OutputIndex = Index >> 1; // which byte + const size_t NibbleIndex = Index & 0x1; // which 4-bit elem in the byte + const uint8_t Shift = 4 * static_cast(NibbleIndex); + const UnpackedType Mask = 0xF << Shift; + + Output[OutputIndex] &= ~Mask; // Clear 4-bit lane + Output[OutputIndex] |= static_cast((Value & 0xF) << Shift); // Set 4-bit lane +} + diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 1fed8af21b31c..06d8fc0a62f3e 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -107,6 +107,120 @@ Return Value: } } +template +void +MLASCALL +MlasQuantizeLinearInt4Kernel( + const float* Input, + UnpackedType* Output, + size_t N, + float Scale, + UnpackedType ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer as int4 using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. Contains packed 4-bit elements. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + static_assert(std::is_same_v || std::is_same_v); + constexpr int32_t MinimumValue = Int4Range::Min; + constexpr int32_t MaximumValue = Int4Range::Max; + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + // Holds 16 quantized 8-bit values that will be packed into the output as packed 4-bit values. 
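+    // Each loop iteration below quantizes 16 floats and packs them into 8 output bytes:
+    // element 2*n lands in the low nibble and element 2*n+1 in the high nibble of output byte n.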
+ UnpackedType TmpOutput[16] = {}; + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 = vec_div(FloatVector0, ScaleVector); + FloatVector1 = vec_div(FloatVector1, ScaleVector); + FloatVector2 = vec_div(FloatVector2, ScaleVector); + FloatVector3 = vec_div(FloatVector3, ScaleVector); + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = vec_add(FloatVector0, ZeroPointVector); + FloatVector1 = vec_add(FloatVector1, ZeroPointVector); + FloatVector2 = vec_add(FloatVector2, ZeroPointVector); + FloatVector3 = vec_add(FloatVector3, ZeroPointVector); + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, static_cast(&TmpOutput[0])); + + Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); + Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); + Output[2] = static_cast(((TmpOutput[5] & 0xF) << 4) | (TmpOutput[4] & 0xF)); + Output[3] = static_cast(((TmpOutput[7] & 0xF) << 4) | (TmpOutput[6] & 0xF)); + Output[4] = static_cast(((TmpOutput[9] & 0xF) << 4) | (TmpOutput[8] & 0xF)); + Output[5] = static_cast(((TmpOutput[11] & 0xF) << 4) | (TmpOutput[10] & 0xF)); + Output[6] = static_cast(((TmpOutput[13] & 0xF) << 4) | (TmpOutput[12] & 0xF)); + Output[7] = static_cast(((TmpOutput[15] & 0xF) << 4) | (TmpOutput[14] & 0xF)); + + Output += 8; + Input += 16; + N -= 16; + } + + for (size_t n = 0; n < N; n++) { + + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + UnpackedType IntValue = static_cast(FloatValue); + + MlasSetInt4Element(Output, n, IntValue); + } +} + void MLASCALL MlasQuantizeLinearU8Kernel( @@ -159,3 +273,29 @@ MlasQuantizeLinearS16Kernel( MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index 12df31b42a7e4..e1aa92a7306fe 100644 --- 
a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -519,18 +519,21 @@ Return Value: } } +template void MLASCALL -MlasQuantizeLinearS4Kernel( +MlasQuantizeLinearInt4Kernel( const float* Input, - int8_t* Output, + UnpackedType* Output, size_t N, float Scale, - int8_t ZeroPoint + UnpackedType ZeroPoint ) { - constexpr int32_t MinimumValue = -8; - constexpr int32_t MaximumValue = 7; + static_assert(std::is_same_v || std::is_same_v); + + constexpr int32_t MinimumValue = Int4Range::Min; + constexpr int32_t MaximumValue = Int4Range::Max; auto ScaleVector = MlasBroadcastFloat32x4(Scale); auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); @@ -538,7 +541,7 @@ MlasQuantizeLinearS4Kernel( auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. - std::array TmpOutput = {}; + std::array TmpOutput = {}; while (N >= 4) { @@ -546,11 +549,11 @@ MlasQuantizeLinearS4Kernel( auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, MinimumValueVector, MaximumValueVector, ZeroPointVector); - IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); + IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); MlasQuantizeLinearStore4PackedValues(IntegerVector, TmpOutput.data()); - Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); - Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); + Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); + Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); Input += 4; Output += 2; @@ -570,17 +573,23 @@ MlasQuantizeLinearS4Kernel( MinimumValueVector, MaximumValueVector, ZeroPointVector); MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); - - size_t OutputIndex = n >> 1; // which byte - size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte - uint8_t Shift = 4 * static_cast(NibbleIndex); - int8_t Mask = 0xF << Shift; - - Output[OutputIndex] &= ~Mask; // Clear 4-bit lane - Output[OutputIndex] |= static_cast((TmpOutput[0] & 0xF) << Shift); // Set 4-bit lane + MlasSetInt4Element(Output, n, TmpOutput[0]); } } +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + void MLASCALL MlasQuantizeLinearU4Kernel( @@ -591,56 +600,7 @@ MlasQuantizeLinearU4Kernel( uint8_t ZeroPoint ) { - constexpr int32_t MinimumValue = 0; - constexpr int32_t MaximumValue = 15; - - auto ScaleVector = MlasBroadcastFloat32x4(Scale); - auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); - auto MaximumValueVector = MlasBroadcastFloat32x4(static_cast(MaximumValue - ZeroPoint)); - auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); - - // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. 
- std::array TmpOutput = {}; - - while (N >= 4) { - - auto FloatVector = MlasLoadFloat32x4(Input); - auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, - MinimumValueVector, MaximumValueVector, ZeroPointVector); - - IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); - MlasQuantizeLinearStore4PackedValues(IntegerVector, TmpOutput.data()); - - Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); - Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); - - Input += 4; - Output += 2; - N -= 4; - } - - for (size_t n = 0; n < N; n++) { - -#if defined(MLAS_NEON64_INTRINSICS) - auto FloatVector = vld1q_dup_f32(Input + n); -#elif defined(MLAS_LSX_INTRINSICS) - MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); -#else - auto FloatVector = _mm_load_ss(Input + n); -#endif - auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, - MinimumValueVector, MaximumValueVector, ZeroPointVector); - - MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); - - size_t OutputIndex = n >> 1; // which byte - size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte - uint8_t Shift = 4 * static_cast(NibbleIndex); - uint8_t Mask = 0xF << Shift; - - Output[OutputIndex] &= ~Mask; // Clear 4-bit lane - Output[OutputIndex] |= static_cast((TmpOutput[0] & 0xF) << Shift); // Set 4-bit lane - } + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); } void @@ -990,36 +950,47 @@ MlasQuantizeLinear( uint16_t ZeroPoint ); -// QuantizeLinear INT4 implementation using the C++ runtime. +template void MLASCALL -MlasQuantizeLinearS4( +MlasQuantizeLinearInt4( const float* Input, - int8_t* Output, + UnpackedType* Output, size_t N, float Scale, - int8_t ZeroPoint + UnpackedType ZeroPoint ) { - constexpr int32_t MinimumValue = -8; - constexpr int32_t MaximumValue = 7; + static_assert(std::is_same_v || std::is_same_v); + + constexpr int32_t MinimumValue = Int4Range::Min; + constexpr int32_t MaximumValue = Int4Range::Max; + for (size_t n = 0; n < N; n++) { float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); FloatValue = std::max(FloatValue, static_cast(MinimumValue)); FloatValue = std::min(FloatValue, static_cast(MaximumValue)); - int8_t IntValue = static_cast(FloatValue); - - size_t OutputIndex = n >> 1; // which byte - size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte - uint8_t Shift = 4 * static_cast(NibbleIndex); - int8_t Mask = 0xF << Shift; + UnpackedType IntValue = static_cast(FloatValue); - Output[OutputIndex] &= ~Mask; // Clear 4-bit lane - Output[OutputIndex] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane + MlasSetInt4Element(Output, n, IntValue); } } +// QuantizeLinear INT4 implementation using the C++ runtime. +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); +} + // QuantizeLinear UINT4 implementation using the C++ runtime. 
void MLASCALL @@ -1031,23 +1002,7 @@ MlasQuantizeLinearU4( uint8_t ZeroPoint ) { - constexpr int32_t MinimumValue = 0; - constexpr int32_t MaximumValue = 15; - - for (size_t n = 0; n < N; n++) { - float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); - FloatValue = std::max(FloatValue, static_cast(MinimumValue)); - FloatValue = std::min(FloatValue, static_cast(MaximumValue)); - uint8_t IntValue = static_cast(FloatValue); - - size_t OutputIndex = n >> 1; // which byte - size_t NibbleIndex = n & 0x1; // which 4-bit elem in the byte - uint8_t Shift = 4 * static_cast(NibbleIndex); - uint8_t Mask = 0xF << Shift; - - Output[OutputIndex] &= ~Mask; // Clear 4-bit lane - Output[OutputIndex] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane - } + MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); } #endif From 48538532ed992c2cb199e8863d04ae084dc73136 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 3 May 2024 10:12:40 -0700 Subject: [PATCH 43/72] Try to exclude MLAS C++ code from Github's cpplint workflow. MLAS has its own formatting style. The excessive linter warnings cause reviewdog to crash due to too much output. --- .github/workflows/lint.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 12b772ceff282..aadfe746f829f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -93,7 +93,10 @@ jobs: github_token: ${{ secrets.github_token }} reporter: github-pr-check level: warning - flags: --linelength=120 --exclude=java/src/main/native/*.c + flags: --linelength=120 \ + --exclude=java/src/main/native/*.c \ + --exclude=onnxruntime/core/mlas/inc/* \ + --exclude=onnxruntime/core/mlas/lib/* filter: "-runtime/references" lint-js: From 51193846fc80d7c919e4a8940eeede23f31a243e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 3 May 2024 10:17:52 -0700 Subject: [PATCH 44/72] Remove backslash from cpplint flags --- .github/workflows/lint.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index aadfe746f829f..34911cfc7972e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -93,9 +93,9 @@ jobs: github_token: ${{ secrets.github_token }} reporter: github-pr-check level: warning - flags: --linelength=120 \ - --exclude=java/src/main/native/*.c \ - --exclude=onnxruntime/core/mlas/inc/* \ + flags: --linelength=120 + --exclude=java/src/main/native/*.c + --exclude=onnxruntime/core/mlas/inc/* --exclude=onnxruntime/core/mlas/lib/* filter: "-runtime/references" From 279a50d5f2af77c40bc6942bde2ae78946735153 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 3 May 2024 16:22:19 -0700 Subject: [PATCH 45/72] Template on sign instead of unpacked type --- include/onnxruntime/core/framework/int4.h | 88 +++++++++++-------- onnxruntime/core/mlas/inc/mlas.h | 4 +- onnxruntime/core/mlas/lib/mlasi.h | 48 ++++++---- .../core/mlas/lib/power/QuantizePower.cpp | 39 ++++---- onnxruntime/core/mlas/lib/quantize.cpp | 67 +++++++------- .../cpu/quantization/quantize_linear.cc | 20 ++--- onnxruntime/core/util/qmath.h | 10 +-- .../mlas/unittest/test_quantizelinear.cpp | 58 ++++++------ .../cpu/tensor/quantize_linear_test.cc | 12 +-- 9 files changed, 185 insertions(+), 161 deletions(-) diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index 7fc7d9836d312..6f0311a39f591 100644 --- 
a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -10,77 +10,91 @@ namespace onnxruntime { -template -struct UnpackedTypeTraits; +template +struct Int4Traits; template <> -struct UnpackedTypeTraits { +struct Int4Traits { + using UnpackedType = int8_t; static constexpr int8_t min_val = -8; static constexpr int8_t max_val = 7; }; template <> -struct UnpackedTypeTraits { +struct Int4Traits { + using UnpackedType = uint8_t; static constexpr uint8_t min_val = 0; static constexpr uint8_t max_val = 15; }; -template +template struct Int4x2Base { - using unpacked_type = T; - static constexpr unpacked_type min_val = UnpackedTypeTraits::min_val; - static constexpr unpacked_type max_val = UnpackedTypeTraits::max_val; + using UnpackedType = typename Int4Traits::UnpackedType; + static constexpr UnpackedType min_val = Int4Traits::min_val; + static constexpr UnpackedType max_val = Int4Traits::max_val; - unpacked_type elems{}; + uint8_t bits_{}; Int4x2Base() = default; + explicit Int4x2Base(uint8_t bits) { - elems = static_cast(bits); + bits_ = bits; + } + + Int4x2Base(UnpackedType val0, UnpackedType val1) { + bits_ = static_cast(((val1 & 0xF) << 4) | (val0 & 0xF)); } - Int4x2Base(unpacked_type val0, unpacked_type val1) { - elems = static_cast(((val1 & 0xF) << 4) | (val0 & 0xF)); + + static inline int8_t SignExtendLower4Bits(uint8_t bits) { + // Sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift. + constexpr uint8_t Shift = (sizeof(int32_t) * 8) - 4; + return static_cast((static_cast(bits) << Shift) >> Shift); } - inline unpacked_type GetElem0() const { - if constexpr (std::is_same_v) { - // Need to sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift. - return static_cast(static_cast((elems << 4)) >> 4); + inline UnpackedType GetElem0() const { + if constexpr (Signed) { + return SignExtendLower4Bits(bits_); } else { - return static_cast(elems & 0xF); + return static_cast(bits_ & 0xF); } } - inline unpacked_type GetElem1() const { - return static_cast(elems >> 4); + inline UnpackedType GetElem1() const { + const uint8_t val = static_cast((bits_ >> 4) & 0xF); + + if constexpr (Signed) { + return SignExtendLower4Bits(val); + } else { + return val; + } } - inline unpacked_type GetElem(size_t index) const { + inline UnpackedType GetElem(size_t index) const { assert(index <= 1); const uint8_t shift = 4 * static_cast(index); + const uint8_t val = static_cast((bits_ >> shift) & 0xF); - if constexpr (std::is_same_v) { - // if index is 0, need to sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift. 
- const uint8_t unshift = 4 - shift; - return static_cast(static_cast((elems >> shift) << unshift) >> unshift); + if constexpr (Signed) { + return SignExtendLower4Bits(val); } else { - return static_cast((elems >> shift) & 0xF); + return val; } } - inline void SetElem(size_t index, unpacked_type val) { + inline void SetElem(size_t index, UnpackedType val) { assert(index <= 1); const uint8_t shift = 4 * static_cast(index); - const unpacked_type mask = 0xF << shift; + const uint8_t mask = 0xF << shift; - elems &= ~mask; // Clear 4-bit element to 0 - elems |= static_cast((val & 0xF) << shift); // Set 4-bit element to val + bits_ &= ~mask; // Clear 4-bit element to 0 + bits_ |= static_cast((val & 0xF) << shift); // Set 4-bit element to val } inline uint8_t ToBits() const { - return static_cast(elems); + return bits_; } - static bool Unpack(gsl::span dst, gsl::span> src) { + static bool Unpack(gsl::span dst, gsl::span> src) { if (((dst.size() + 1) / 2) != src.size()) { return false; } @@ -94,7 +108,7 @@ struct Int4x2Base { return true; } - static bool Pack(gsl::span> dst, gsl::span src) { + static bool Pack(gsl::span> dst, gsl::span src) { if (((src.size() + 1) / 2) != dst.size()) { return false; } @@ -103,19 +117,19 @@ struct Int4x2Base { size_t dst_i = 0; for (; src_i < src.size() - 1; src_i += 2) { - dst[dst_i++] = Int4x2Base(src[src_i], src[src_i + 1]); + dst[dst_i++] = Int4x2Base(src[src_i], src[src_i + 1]); } if (src_i < src.size()) { - dst[dst_i] = Int4x2Base(src[src_i], 0); + dst[dst_i] = Int4x2Base(src[src_i], 0); } return true; } }; -using Int4x2 = Int4x2Base; -using UInt4x2 = Int4x2Base; -static_assert(sizeof(Int4x2) == sizeof(int8_t)); +using Int4x2 = Int4x2Base; +using UInt4x2 = Int4x2Base; +static_assert(sizeof(Int4x2) == sizeof(uint8_t)); static_assert(sizeof(UInt4x2) == sizeof(uint8_t)); } // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index f8966657e2109..cdfd283899c8c 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1229,14 +1229,14 @@ MlasQuantizeLinearU4( uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + int8_t ZeroPoint ); void MLASCALL MlasQuantizeLinearS4( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 2ae2fd26f9e73..403460c310fb7 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -687,13 +687,13 @@ void uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint); + int8_t ZeroPoint); typedef void (MLASCALL MLAS_QUANTIZE_LINEAR_S4_KERNEL)( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint); @@ -2526,34 +2526,46 @@ MlasThreadedBufAlloc(size_t size) // Utilities for INT4 quantization. 
// -template -struct Int4Range; +template +struct Int4Traits; -template <> -struct Int4Range { +template<> +struct Int4Traits { + using UnpackedType = int8_t; static constexpr int8_t Min = -8; static constexpr int8_t Max = 7; }; -template <> -struct Int4Range { - static constexpr uint8_t Min = 0; - static constexpr uint8_t Max = 15; +template<> +struct Int4Traits { + using UnpackedType = uint8_t; + static constexpr int8_t Min = 0; + static constexpr int8_t Max = 15; }; -template +template MLAS_FORCEINLINE void -MlasSetInt4Element(UnpackedType* Output, size_t Index, UnpackedType Value) +MlasSetInt4Element(uint8_t* Output, size_t ElemIndex, UnpackedType Value) { static_assert(std::is_same_v || std::is_same_v); - const size_t OutputIndex = Index >> 1; // which byte - const size_t NibbleIndex = Index & 0x1; // which 4-bit elem in the byte - const uint8_t Shift = 4 * static_cast(NibbleIndex); - const UnpackedType Mask = 0xF << Shift; + const size_t OutputIndex = ElemIndex >> 1; // which byte + const size_t NibbleIndex = ElemIndex & 0x1; // which 4-bit elem in the byte + const uint8_t Shift = static_cast(4 * NibbleIndex); + const uint8_t Mask = static_cast(0xF << Shift); + uint8_t* Dst = &Output[OutputIndex]; - Output[OutputIndex] &= ~Mask; // Clear 4-bit lane - Output[OutputIndex] |= static_cast((Value & 0xF) << Shift); // Set 4-bit lane + *Dst &= ~Mask; // Clear 4-bit lane + *Dst |= static_cast((Value & 0xF) << Shift); // Set 4-bit lane +} + +template +MLAS_FORCEINLINE +void +MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueHigh) +{ + static_assert(std::is_same_v || std::is_same_v); + *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); } diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 06d8fc0a62f3e..0cfa56740edfb 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -107,15 +107,15 @@ Return Value: } } -template +template void MLASCALL MlasQuantizeLinearInt4Kernel( const float* Input, - UnpackedType* Output, + uint8_t* Output, size_t N, float Scale, - UnpackedType ZeroPoint + int8_t ZeroPoint ) /*++ @@ -142,9 +142,9 @@ Return Value: --*/ { - static_assert(std::is_same_v || std::is_same_v); - constexpr int32_t MinimumValue = Int4Range::Min; - constexpr int32_t MaximumValue = Int4Range::Max; + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; auto ScaleVector = vec_splats(Scale); auto MinimumValueVector = vec_splats(float(MinimumValue)); @@ -196,16 +196,15 @@ Return Value: auto CharVector = vec_pack(ShortVector0, ShortVector1); vec_xst(CharVector, 0, static_cast(&TmpOutput[0])); - Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); - Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); - Output[2] = static_cast(((TmpOutput[5] & 0xF) << 4) | (TmpOutput[4] & 0xF)); - Output[3] = static_cast(((TmpOutput[7] & 0xF) << 4) | (TmpOutput[6] & 0xF)); - Output[4] = static_cast(((TmpOutput[9] & 0xF) << 4) | (TmpOutput[8] & 0xF)); - Output[5] = static_cast(((TmpOutput[11] & 0xF) << 4) | (TmpOutput[10] & 0xF)); - Output[6] = static_cast(((TmpOutput[13] & 0xF) << 4) | (TmpOutput[12] & 0xF)); - Output[7] = static_cast(((TmpOutput[15] & 0xF) << 4) | (TmpOutput[14] & 0xF)); + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, 
TmpOutput[2], TmpOutput[3]); + MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]); + MlasPackInt4Elements(Output++, TmpOutput[6], TmpOutput[7]); + MlasPackInt4Elements(Output++, TmpOutput[8], TmpOutput[9]); + MlasPackInt4Elements(Output++, TmpOutput[10], TmpOutput[11]); + MlasPackInt4Elements(Output++, TmpOutput[12], TmpOutput[13]); + MlasPackInt4Elements(Output++, TmpOutput[14], TmpOutput[15]); - Output += 8; Input += 16; N -= 16; } @@ -217,7 +216,7 @@ Return Value: FloatValue = std::min(FloatValue, static_cast(MaximumValue)); UnpackedType IntValue = static_cast(FloatValue); - MlasSetInt4Element(Output, n, IntValue); + MlasSetInt4Element(Output, n, IntValue); } } @@ -280,22 +279,22 @@ MlasQuantizeLinearU4Kernel( uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + int8_t ZeroPoint ) { - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); } void MLASCALL MlasQuantizeLinearS4Kernel( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint ) { - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); } diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index e1aa92a7306fe..ae638fafee18f 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -519,21 +519,20 @@ Return Value: } } -template +template void MLASCALL MlasQuantizeLinearInt4Kernel( const float* Input, - UnpackedType* Output, + uint8_t* Output, size_t N, float Scale, - UnpackedType ZeroPoint + int8_t ZeroPoint ) { - static_assert(std::is_same_v || std::is_same_v); - - constexpr int32_t MinimumValue = Int4Range::Min; - constexpr int32_t MaximumValue = Int4Range::Max; + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; auto ScaleVector = MlasBroadcastFloat32x4(Scale); auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); @@ -541,7 +540,7 @@ MlasQuantizeLinearInt4Kernel( auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. 
- std::array TmpOutput = {}; + UnpackedType TmpOutput[4] = {}; while (N >= 4) { @@ -550,13 +549,11 @@ MlasQuantizeLinearInt4Kernel( MinimumValueVector, MaximumValueVector, ZeroPointVector); IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); - MlasQuantizeLinearStore4PackedValues(IntegerVector, TmpOutput.data()); - - Output[0] = static_cast(((TmpOutput[1] & 0xF) << 4) | (TmpOutput[0] & 0xF)); - Output[1] = static_cast(((TmpOutput[3] & 0xF) << 4) | (TmpOutput[2] & 0xF)); + MlasQuantizeLinearStore4PackedValues(IntegerVector, &TmpOutput[0]); + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); Input += 4; - Output += 2; N -= 4; } @@ -572,8 +569,8 @@ MlasQuantizeLinearInt4Kernel( auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, MinimumValueVector, MaximumValueVector, ZeroPointVector); - MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); - MlasSetInt4Element(Output, n, TmpOutput[0]); + MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); + MlasSetInt4Element(Output, n, TmpOutput[0]); } } @@ -581,13 +578,13 @@ void MLASCALL MlasQuantizeLinearS4Kernel( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint ) { - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); } void @@ -597,10 +594,10 @@ MlasQuantizeLinearU4Kernel( uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + int8_t ZeroPoint ) { - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); } void @@ -659,7 +656,7 @@ void MLASCALL MlasQuantizeLinearS4( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint @@ -680,7 +677,7 @@ MlasQuantizeLinearU4( uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + int8_t ZeroPoint ) { #if defined(MLAS_TARGET_AMD64) @@ -831,7 +828,7 @@ void MLASCALL MlasQuantizeLinearS4( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint @@ -847,7 +844,7 @@ MlasQuantizeLinearU4( uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + int8_t ZeroPoint ) { GetMlasPlatform().QuantizeLinearU4Kernel(Input, Output, N, Scale, ZeroPoint); @@ -950,22 +947,20 @@ MlasQuantizeLinear( uint16_t ZeroPoint ); -template +template void MLASCALL MlasQuantizeLinearInt4( const float* Input, - UnpackedType* Output, + uint8_t* Output, size_t N, float Scale, - UnpackedType ZeroPoint + int8_t ZeroPoint ) { - static_assert(std::is_same_v || std::is_same_v); - - constexpr int32_t MinimumValue = Int4Range::Min; - constexpr int32_t MaximumValue = Int4Range::Max; - + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; for (size_t n = 0; n < N; n++) { float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); @@ -973,7 +968,7 @@ MlasQuantizeLinearInt4( FloatValue = std::min(FloatValue, static_cast(MaximumValue)); UnpackedType IntValue = static_cast(FloatValue); - MlasSetInt4Element(Output, n, IntValue); + MlasSetInt4Element(Output, n, IntValue); } } @@ -982,13 +977,13 @@ void MLASCALL MlasQuantizeLinearS4( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint ) { - MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); + 
MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); } // QuantizeLinear UINT4 implementation using the C++ runtime. @@ -999,10 +994,10 @@ MlasQuantizeLinearU4( uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + int8_t ZeroPoint ) { - MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); + MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); } #endif diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index f2572f6b35a51..dceef3a300293 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -520,7 +520,7 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ size_t bd_i = bd >> 1; /*bd / 2*/ \ size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ + INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ QUANT_FUNC(input, output, output_index, output_index + static_cast(block_size), \ scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ input += block_size; \ @@ -544,31 +544,31 @@ DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4) ORT_UNUSED_PARAMETER(saturate); \ \ size_t total_size = static_cast(N * broadcast_dim * block_size); \ - auto tmp_buf = std::make_unique(total_size); \ + auto tmp_buf = std::make_unique(total_size); \ size_t tmp_buf_index = 0; \ \ for (size_t n = 0; n < static_cast(N); n++) { \ for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ size_t bd_i = bd >> 1; /*bd / 2*/ \ size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::unpacked_type zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ - ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ - static_cast(block_size), scale[bd], \ - zp, ctx->GetOperatorThreadPool()); \ + INT4_TYPE::UnpackedType zp = zero_point ? 
zero_point[bd_i].GetElem(bd_j) : 0; \ + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ + static_cast(block_size), scale[bd], \ + zp, ctx->GetOperatorThreadPool()); \ input += block_size; \ tmp_buf_index += static_cast(block_size); \ } \ } \ \ for (size_t i = 0; i < total_size; i++) { \ - tmp_buf[i] = std::min(INT4_TYPE::max_val, \ - std::max(INT4_TYPE::min_val, \ - tmp_buf[i])); \ + tmp_buf[i] = std::min(INT4_TYPE::max_val, \ + std::max(INT4_TYPE::min_val, \ + tmp_buf[i])); \ } \ \ size_t num_int4_pairs = (total_size + 1) / 2; \ auto dst = gsl::make_span(output, num_int4_pairs); \ - auto src = gsl::make_span(tmp_buf.get(), total_size); \ + auto src = gsl::make_span(tmp_buf.get(), total_size); \ INT4_TYPE::Pack(dst, src); \ } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index bd2e410020c5a..4aa2ab81216ce 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -150,7 +150,7 @@ ParQuantizeLinearStd(const float* Input, static_cast(ZeroPoint.GetElem0()); \ size_t output_index = out_start >> 1; \ \ - INT4_TYPE::unpacked_type quant_val = static_cast( \ + INT4_TYPE::UnpackedType quant_val = static_cast( \ std::min(static_cast(INT4_TYPE::max_val), \ std::max(static_cast(INT4_TYPE::min_val), ival))); \ Output[output_index].SetElem(1, quant_val); \ @@ -165,7 +165,7 @@ ParQuantizeLinearStd(const float* Input, static_cast(ZeroPoint.GetElem0()); \ size_t output_index = (out_end - 1) >> 1; \ \ - INT4_TYPE::unpacked_type quant_val = static_cast( \ + INT4_TYPE::UnpackedType quant_val = static_cast( \ std::min(static_cast(INT4_TYPE::max_val), \ std::max(static_cast(INT4_TYPE::min_val), ival))); \ Output[output_index].SetElem(0, quant_val); \ @@ -190,7 +190,7 @@ ParQuantizeLinearStd(const float* Input, \ const std::ptrdiff_t num_blocks = (N + block_size - 1) / block_size; \ const TensorOpCost unit_cost{static_cast(block_size * sizeof(float)), \ - static_cast(block_size * sizeof(INT4_TYPE::unpacked_type)) / 2.0, \ + static_cast(block_size * sizeof(INT4_TYPE::UnpackedType)) / 2.0, \ static_cast(block_size) * 2.0}; \ concurrency::ThreadPool::TryParallelFor( \ thread_pool, num_blocks, unit_cost, \ @@ -201,10 +201,10 @@ ParQuantizeLinearStd(const float* Input, auto out_idx = begin_idx + static_cast(out_start); \ \ MLAS_FUNC(&(Input[inp_idx]), \ - reinterpret_cast(&(Output[out_idx >> 1])), \ + reinterpret_cast(&(Output[out_idx >> 1])), \ end_idx - begin_idx, \ Scale, \ - ZeroPoint.GetElem0()); \ + static_cast(ZeroPoint.GetElem0())); \ }); \ } diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index d771ae4fb4c19..7c160b6696265 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -71,54 +71,53 @@ class MlasQuantizeLinearTest : public MlasTestBase { } }; -template +template class MlasQuantizeLinear4BitTest : public MlasTestBase { private: MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; - UnpackedType MinVal() const { - if constexpr (std::is_same_v) { + int32_t MinVal() const { + if constexpr (Signed) { return -8; - } else if (std::is_same_v) { + } else { return 0; } } - UnpackedType MaxVal() const { - if constexpr (std::is_same_v) { + int32_t MaxVal() const { + if constexpr (Signed) { return 7; } else { - static_assert(std::is_same_v); return 15; } } - void 
GenerateReference(const float* Input, UnpackedType* OutputReference, size_t N, float Scale, - UnpackedType ZeroPoint) { + void GenerateReference(const float* Input, uint8_t* OutputReference, size_t N, float Scale, + int8_t ZeroPoint) { for (size_t n = 0; n < N; n++) { float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); FloatValue = std::max(FloatValue, static_cast(MinVal())); FloatValue = std::min(FloatValue, static_cast(MaxVal())); - UnpackedType IntValue = static_cast(FloatValue); + int8_t IntValue = static_cast(FloatValue); size_t i = n >> 1; size_t j = n & 0x1; uint8_t Shift = 4 * static_cast(j); - UnpackedType Mask = 0xF << Shift; + uint8_t Mask = 0xF << Shift; - OutputReference[i] &= ~Mask; // Clear 4-bit lane - OutputReference[i] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane + OutputReference[i] &= ~Mask; // Clear 4-bit lane + OutputReference[i] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane } } void Test(size_t N) { size_t OutBufLen = (N + 1) / 2; float* Input = BufferInput.GetBuffer(N); - UnpackedType* Output = BufferOutput.GetBuffer(OutBufLen); - UnpackedType* OutputReference = BufferOutputReference.GetBuffer(OutBufLen); + uint8_t* Output = BufferOutput.GetBuffer(OutBufLen); + uint8_t* OutputReference = BufferOutputReference.GetBuffer(OutBufLen); std::default_random_engine generator(static_cast(N)); @@ -131,7 +130,7 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { float Scale = (MaximumValue - MinimumValue) / 32.f; std::uniform_int_distribution zp_distribution(MinVal(), MaxVal()); - UnpackedType ZeroPoint = static_cast(zp_distribution(generator)); + int8_t ZeroPoint = static_cast(zp_distribution(generator)); std::uniform_real_distribution distribution(MinimumValue, MaximumValue); for (size_t n = 0; n < N; n++) { @@ -140,10 +139,9 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); - if constexpr (std::is_same_v) { + if constexpr (Signed) { MlasQuantizeLinearS4(Input, Output, N, Scale, ZeroPoint); } else { - static_assert(std::is_same_v); MlasQuantizeLinearU4(Input, Output, N, Scale, ZeroPoint); } @@ -151,9 +149,16 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { size_t i = n >> 1; size_t j = n & 0x1; const uint8_t Shift = 4 * static_cast(j); - const uint8_t Unshift = 4 - Shift; - UnpackedType actual_val = static_cast((Output[i] >> Shift) << Unshift) >> Unshift; - UnpackedType expected_val = static_cast((OutputReference[i] >> Shift) << Unshift) >> Unshift; + + int32_t actual_val = (Output[i] >> Shift) & 0xF; + int32_t expected_val = (OutputReference[i] >> Shift) & 0xF; + + if constexpr (Signed) { + constexpr uint8_t SignExtShift = (sizeof(int32_t) * 8) - 4; + actual_val = (actual_val << SignExtShift) >> SignExtShift; + expected_val = (expected_val << SignExtShift) >> SignExtShift; + } + ASSERT_EQ(actual_val, expected_val) << ", size=" << N << ", index=" << n << ", nibble=" << j; @@ -162,10 +167,9 @@ class MlasQuantizeLinear4BitTest : public MlasTestBase { public: static const char* GetTestSuiteName() { - if constexpr (std::is_same_v) { + if constexpr (Signed) { return "QuantizeLinearS4"; } else { - static_assert(std::is_same_v); return "QuantizeLinearU4"; } } @@ -184,8 +188,8 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe count += MlasDirectShortExecuteTests>::RegisterShortExecute(); count += MlasDirectShortExecuteTests>::RegisterShortExecute(); count += 
MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; }); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index ec35c83fe3687..856c7749e8da5 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -424,15 +424,15 @@ TEST(QuantizeLinearOpTest, UInt4) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -template -static void GetExpectedInt4Quant(const float* input, Int4x2Base* output, size_t num_elems, float scale, - T zero_point) { +template +static void GetExpectedInt4Quant(const float* input, Int4x2Base* output, size_t num_elems, float scale, + int8_t zero_point) { for (size_t n = 0; n < num_elems; n++) { float float_val = std::nearbyintf(input[n] / scale) + static_cast(zero_point); - float_val = std::max(float_val, static_cast(Int4x2Base::min_val)); - float_val = std::min(float_val, static_cast(Int4x2Base::max_val)); + float_val = std::max(float_val, static_cast(Int4x2Base::min_val)); + float_val = std::min(float_val, static_cast(Int4x2Base::max_val)); - T int_val = static_cast(float_val); + Int4x2Base::UnpackedType int_val = static_cast::UnpackedType>(float_val); size_t i = n >> 1; size_t j = n & 0x1; From 488117fd09afa257c3d68d57a3651b2405c3c143 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 3 May 2024 16:48:10 -0700 Subject: [PATCH 46/72] Use typename --- onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 856c7749e8da5..3c6c1022e4bee 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -427,12 +427,14 @@ TEST(QuantizeLinearOpTest, UInt4) { template static void GetExpectedInt4Quant(const float* input, Int4x2Base* output, size_t num_elems, float scale, int8_t zero_point) { + using UnpackedType = typename Int4x2Base::UnpackedType; + for (size_t n = 0; n < num_elems; n++) { float float_val = std::nearbyintf(input[n] / scale) + static_cast(zero_point); float_val = std::max(float_val, static_cast(Int4x2Base::min_val)); float_val = std::min(float_val, static_cast(Int4x2Base::max_val)); - Int4x2Base::UnpackedType int_val = static_cast::UnpackedType>(float_val); + UnpackedType int_val = static_cast(float_val); size_t i = n >> 1; size_t j = n & 0x1; From 4c862bdf6fcaa8879cfefe4858573592df68fd11 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 6 May 2024 18:37:55 -0700 Subject: [PATCH 47/72] Add utils to compute tensor storage size and num elements for sub-byte types --- .../onnxruntime/core/framework/data_types.h | 46 +++++++++++----- include/onnxruntime/core/framework/tensor.h | 44 +++++++++++---- onnxruntime/core/framework/data_types.cc | 4 +- onnxruntime/core/framework/execution_frame.cc | 2 +- .../core/framework/session_state_utils.cc | 2 +- onnxruntime/core/framework/tensor.cc | 55 +++++++++++-------- onnxruntime/core/session/onnxruntime_c_api.cc | 14 ++--- 7 files changed, 107 
insertions(+), 60 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index dad3f4769019e..4d950d17dba7a 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -937,12 +937,21 @@ class PrimitiveDataTypeBase : public DataTypeImpl { return data_type_; } + int32_t GetNumSubElems() const { + return num_sub_elems_; + } + + bool HasSubElems() const { + return num_sub_elems_ > 1; + } + protected: - PrimitiveDataTypeBase(size_t size, int32_t data_type) - : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {} + PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems) + : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type}, num_sub_elems_{num_sub_elems} {} private: const int32_t data_type_; + const int32_t num_sub_elems_; }; /** @@ -968,9 +977,9 @@ class PrimitiveDataType : public PrimitiveDataTypeBase { } private: - PrimitiveDataType() + explicit PrimitiveDataType(int32_t num_sub_elems) : PrimitiveDataTypeBase{sizeof(T), - utils::ToTensorProtoElementType()} { + utils::ToTensorProtoElementType(), num_sub_elems} { } }; @@ -1077,15 +1086,26 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const { return SequenceTensorType::Type(); \ } -#define ORT_REGISTER_PRIM_TYPE(TYPE) \ - template <> \ - MLDataType PrimitiveDataType::Type() { \ - static PrimitiveDataType prim_data_type; \ - return &prim_data_type; \ - } \ - template <> \ - MLDataType DataTypeImpl::GetType() { \ - return PrimitiveDataType::Type(); \ +#define ORT_REGISTER_PRIM_TYPE(TYPE) \ + template <> \ + MLDataType PrimitiveDataType::Type() { \ + static PrimitiveDataType prim_data_type(1); \ + return &prim_data_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return PrimitiveDataType::Type(); \ + } + +#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \ + template <> \ + MLDataType PrimitiveDataType::Type() { \ + static PrimitiveDataType prim_data_type(NUM_SUB_ELEMS); \ + return &prim_data_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return PrimitiveDataType::Type(); \ } #define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \ diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 3c3933024636e..77f4d57e5fdfb 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -145,6 +145,37 @@ class Tensor final { /// Bytes required. static size_t CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape); + /// + /// Calculate the required storage for the tensor. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// The resulting storage size. + /// Status indicating success or failure. + static Status CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t& storage_size); + + /// + /// Get the number of elements for a Tensor of the given element type and shape. + /// For element types smaller than 1 byte (e.g., int4), a single Tensor element stores multiple sub-byte elements. + /// So, this function returns the number of Tensor elements, each of which may contain multiple sub-byte elements. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Number of Tensor elements. Returns -1 if shape has negative dims. 
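For sub-byte element types, the logical element count implied by the shape and the number of packed storage elements differ, and the mapping documented above is a ceiling division by the number of sub-elements per storage element. A standalone sketch of that arithmetic (illustrative only, not the ORT implementation):

    #include <cstdint>

    // Sketch only: logical element count -> packed storage element count.
    // For Int4x2/UInt4x2 the divisor is 2 (two 4-bit values per byte).
    inline int64_t PackedElementCountSketch(int64_t logical_count, int64_t num_sub_elems) {
      if (logical_count < 0) return -1;  // mirrors the negative-shape contract above
      if (num_sub_elems <= 1) return logical_count;
      return (logical_count + num_sub_elems - 1) / num_sub_elems;
    }

    // Example: a {3, 5} int4 tensor has 15 logical values, so
    // PackedElementCountSketch(15, 2) == 8 packed bytes of storage.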
+ static inline int64_t GetNumTensorElems(MLDataType elt_type, const TensorShape& shape) { + return GetNumTensorElems(elt_type, shape.Size()); + } + + /// + /// Get the number of elements for a Tensor of the given element type and shape size. + /// For element types smaller than 1 byte (e.g., int4), a single Tensor element stores multiple sub-byte elements. + /// So, this function returns the number of Tensor elements, each of which may contain multiple sub-byte elements. + /// + /// Data type of the tensor elements. + /// The number of elements indicated by the shape (i.e., shape.Size()). + /// Number of Tensor elements. Returns -1 if shape_size is negative. + static int64_t GetNumTensorElems(MLDataType elt_type, int64_t shape_size); + /** Returns the data type. */ @@ -200,12 +231,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - int64_t num_elems = shape_.Size(); - if (utils::IsPrimitiveDataType(dtype_) || utils::IsPrimitiveDataType(dtype_)) { - num_elems = (num_elems + 1) / 2; - } - - return gsl::make_span(data, static_cast(num_elems)); + return gsl::make_span(data, static_cast(GetNumTensorElems(dtype_, shape_))); } template @@ -222,11 +248,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - int64_t num_elems = shape_.Size(); - if (utils::IsPrimitiveDataType(dtype_) || utils::IsPrimitiveDataType(dtype_)) { - num_elems = (num_elems + 1) / 2; - } - return gsl::make_span(data, static_cast::size_type>(num_elems)); + return gsl::make_span(data, static_cast::size_type>(GetNumTensorElems(dtype_, shape_))); } void* MutableDataRaw(MLDataType type) { diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 4558ed0db6f62..6958a439f225e 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -1221,8 +1221,8 @@ ORT_REGISTER_PRIM_TYPE(Float8E5M2); ORT_REGISTER_PRIM_TYPE(Float8E5M2FNUZ); #endif -ORT_REGISTER_PRIM_TYPE(Int4x2); -ORT_REGISTER_PRIM_TYPE(UInt4x2); +ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2); +ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2); namespace { template diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 32a5f749af084..88921d0c4f886 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -530,7 +530,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va } size_t size; - int64_t len = shape.Size(); + int64_t len = Tensor::GetNumTensorElems(element_type, shape); if (len < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape cannot contain any negative value"); } diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 692ca08772535..d92dcfdd262bb 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -37,7 +37,7 @@ namespace session_state_utils { // It can handle arena-based allocators and non-arena based allocators. 
static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, const DataTypeImpl* type, const AllocatorPtr& alloc, /*out*/ void*& p_data) { - int64_t shape_size = tensor_shape.Size(); + int64_t shape_size = Tensor::GetNumTensorElems(type, tensor_shape); if (shape_size < 0) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "shape.Size() must >=0"); diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index b7017297df4ce..ce40a2d227251 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -27,26 +27,41 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st } // namespace #endif -size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { - int64_t shape_size = shape.Size(); - if (shape_size < 0) - ORT_THROW("shape.Size() must >=0"); +int64_t Tensor::GetNumTensorElems(MLDataType elt_type, int64_t shape_size) { + int64_t num_elems = shape_size; + auto prim_type = elt_type->AsPrimitiveDataType(); - if (shape_size > 0) { - SafeInt len = 0; + if (prim_type != nullptr && prim_type->HasSubElems() && num_elems > 0) { + const int64_t num_sub_elems = prim_type->GetNumSubElems(); + num_elems = (num_elems + (num_sub_elems - 1)) / num_sub_elems; + } - // TODO(adrianlizarraga): Handle more cleanly. - if (utils::IsPrimitiveDataType(elt_type) || utils::IsPrimitiveDataType(elt_type)) { - shape_size = (shape_size + 1) / 2; - } + return num_elems; +} - if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), elt_type->Size(), &len)) - ORT_THROW("tensor failed memory size calculation"); +Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, + /*out*/ size_t& storage_size) { + int64_t num_elems = GetNumTensorElems(elt_type, shape.Size()); + ORT_RETURN_IF(num_elems < 0, "Tensor shape.Size() must be >= 0"); - return len; + if (num_elems > 0) { + if (static_cast(num_elems) > std::numeric_limits::max()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large"); + } + if (!IAllocator::CalcMemSizeForArray(static_cast(num_elems), elt_type->Size(), &storage_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed memory size calculation for Tensor storage size"); + } + } else { + storage_size = 0; } - return 0; + return Status::OK(); +} + +size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { + size_t storage_size = 0; + ORT_THROW_IF_ERROR(CalculateTensorStorageSize(elt_type, shape, storage_size)); + return storage_size; } Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, @@ -106,18 +121,14 @@ void Tensor::InitOrtValue(Tensor&& tensor, OrtValue& ort_value) { size_t Tensor::SizeInBytes() const { #ifdef ENABLE_STRIDED_TENSORS - int64_t size = IsContiguous() ? shape_.Size() : GetSizeFromStrides(shape_, strides_); + int64_t shape_size = IsContiguous() ? shape_.Size() : GetSizeFromStrides(shape_, strides_); #else - int64_t size = shape_.Size(); + int64_t shape_size = shape_.Size(); #endif size_t ret = 0; + const int64_t num_elems = GetNumTensorElems(dtype_, shape_size); - // TODO(adrianlizarraga): Handle more cleanly. 
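Once the packed element count is known, the byte size is an overflow-checked multiply against the element's storage size. A minimal sketch of that checked multiply, under the same size_t limits used above (the real code defers to IAllocator::CalcMemSizeForArray and, in later revisions of this change, an alignment-aware variant):

    #include <cstddef>
    #include <limits>

    // Sketch only: overflow-checked "count * element_size" byte calculation.
    inline bool CheckedStorageBytesSketch(size_t count, size_t elem_size, size_t& out) {
      if (elem_size != 0 && count > std::numeric_limits<size_t>::max() / elem_size) {
        return false;  // the multiplication would overflow
      }
      out = count * elem_size;
      return true;
    }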
- if (utils::IsPrimitiveDataType(dtype_) || utils::IsPrimitiveDataType(dtype_)) { - size = (size + 1) / 2; - } - - if (!IAllocator::CalcMemSizeForArray(SafeInt(size), dtype_->Size(), &ret)) { + if (!IAllocator::CalcMemSizeForArray(SafeInt(num_elems), dtype_->Size(), &ret)) { ORT_THROW("tensor size overflow"); } return ret; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4f94674066e0c..066b79e0e38f6 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -215,16 +215,10 @@ ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); } - auto elem_count = narrow(tensor_shape.Size()); - - // TODO(adrianlizarraga): Handle this more cleanly. - if (utils::IsPrimitiveDataType(ml_type) || utils::IsPrimitiveDataType(ml_type)) { - elem_count = (elem_count + 1) / 2; - } - - size_t size_to_allocate; - if (!IAllocator::CalcMemSizeForArray(ml_type->Size(), elem_count, &size_to_allocate)) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "size overflow"); + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, size_to_allocate); + if (!status.IsOK()) { + return ToOrtStatus(status); } if (size_to_allocate > p_data_len) { std::ostringstream oss; From 27c554e5c1d73ba41d113084044afe425e88257e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 6 May 2024 23:51:44 -0700 Subject: [PATCH 48/72] Add more uses of Tensor::CalcTensorStorageSize() --- include/onnxruntime/core/framework/tensor.h | 15 +++++++-- onnxruntime/core/framework/execution_frame.cc | 13 ++------ .../core/framework/ort_value_tensor_slicer.cc | 8 ++++- .../core/framework/session_state_utils.cc | 16 ++-------- onnxruntime/core/framework/tensor.cc | 32 ++++++++++++++----- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- .../python/onnxruntime_pybind_mlvalue.cc | 7 ++-- 7 files changed, 53 insertions(+), 40 deletions(-) diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 77f4d57e5fdfb..6804b340f0ac6 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -150,9 +150,12 @@ class Tensor final { /// /// Data type of the tensor elements. /// Tensor shape. + /// Power of 2 alignment to include in calculation. + /// Bumps up result to the nearest multiple of alignment. Set to 0 to ignore. /// The resulting storage size. /// Status indicating success or failure. - static Status CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t& storage_size); + static Status CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment, + size_t& storage_size); /// /// Get the number of elements for a Tensor of the given element type and shape. @@ -231,7 +234,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast(GetNumTensorElems(dtype_, shape_))); + return gsl::make_span(data, static_cast(NumElements())); } template @@ -248,7 +251,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. 
", "T ", "!=", dtype_); const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast::size_type>(GetNumTensorElems(dtype_, shape_))); + return gsl::make_span(data, static_cast::size_type>(NumElements())); } void* MutableDataRaw(MLDataType type) { @@ -302,6 +305,12 @@ class Tensor final { byte_offset_ = byte_offset; } + /** + The number of Tensor elements. A single Tensor element may contain multiple sub-elements for + subbyte data types (e.g., int4). + */ + int64_t NumElements() const; + /** The number of bytes of data. */ diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 88921d0c4f886..894e0daae94b6 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -529,17 +529,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs"); } - size_t size; - int64_t len = Tensor::GetNumTensorElems(element_type, shape); - if (len < 0) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape cannot contain any negative value"); - } - if (static_cast(len) > std::numeric_limits::max()) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large"); - } - if (!IAllocator::CalcMemSizeForArrayWithAlignment(static_cast(len), element_type->Size(), &size)) { - return Status(ONNXRUNTIME, FAIL, "size overflow"); - } + size_t size = 0; + ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(element_type, shape, kAllocAlignment, size)); // Lazily get the allocator only if needed. AllocatorPtr alloc = nullptr; diff --git a/onnxruntime/core/framework/ort_value_tensor_slicer.cc b/onnxruntime/core/framework/ort_value_tensor_slicer.cc index cf4020d830bff..2acc1e301d6e0 100644 --- a/onnxruntime/core/framework/ort_value_tensor_slicer.cc +++ b/onnxruntime/core/framework/ort_value_tensor_slicer.cc @@ -14,7 +14,13 @@ OrtValueTensorSlicer OrtValueTensorSlicer::Create(T& ort_value, int64_t sl ORT_ENFORCE(ort_value.IsTensor(), "Can't slice a non-tensor OrtValue. Type was ", ort_value.Type()); ORT_ENFORCE(ort_value.IsAllocated(), "OrtValue has not been allocated so can't be sliced."); - auto& tensor_shape = ort_value.template Get().Shape(); + const Tensor& tensor = ort_value.template Get(); + auto* prim_type = tensor.DataType()->AsPrimitiveDataType(); + if (prim_type != nullptr) { + // TODO(adrianlizarraga): Support slicing Tensors of subbyte element types (e.g., int4). + ORT_ENFORCE(!prim_type->HasSubElems(), "Can't slice a tensor with a subbyte element type"); + } + auto& tensor_shape = tensor.Shape(); ORT_ENFORCE(gsl::narrow_cast(tensor_shape.NumDimensions()) >= slice_dimension, "Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index d92dcfdd262bb..059de8e3c8c4a 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -37,20 +37,10 @@ namespace session_state_utils { // It can handle arena-based allocators and non-arena based allocators. 
static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, const DataTypeImpl* type, const AllocatorPtr& alloc, /*out*/ void*& p_data) { - int64_t shape_size = Tensor::GetNumTensorElems(type, tensor_shape); - if (shape_size < 0) - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "shape.Size() must >=0"); + size_t mem_size = 0; + ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(type, tensor_shape, /*alignment*/ 0, mem_size)); - p_data = nullptr; - if (shape_size > 0) { - SafeInt mem_size = 0; - - if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), type->Size(), &mem_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed memory size calculation"); - } - - p_data = alloc->Reserve(mem_size); - } + p_data = alloc->Reserve(mem_size); return Status::OK(); } diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index ce40a2d227251..68cf73955de24 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -31,7 +31,7 @@ int64_t Tensor::GetNumTensorElems(MLDataType elt_type, int64_t shape_size) { int64_t num_elems = shape_size; auto prim_type = elt_type->AsPrimitiveDataType(); - if (prim_type != nullptr && prim_type->HasSubElems() && num_elems > 0) { + if (prim_type != nullptr && num_elems > 0 && prim_type->HasSubElems()) { const int64_t num_sub_elems = prim_type->GetNumSubElems(); num_elems = (num_elems + (num_sub_elems - 1)) / num_sub_elems; } @@ -39,7 +39,7 @@ int64_t Tensor::GetNumTensorElems(MLDataType elt_type, int64_t shape_size) { return num_elems; } -Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, +Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment, /*out*/ size_t& storage_size) { int64_t num_elems = GetNumTensorElems(elt_type, shape.Size()); ORT_RETURN_IF(num_elems < 0, "Tensor shape.Size() must be >= 0"); @@ -48,8 +48,9 @@ Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape if (static_cast(num_elems) > std::numeric_limits::max()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large"); } - if (!IAllocator::CalcMemSizeForArray(static_cast(num_elems), elt_type->Size(), &storage_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed memory size calculation for Tensor storage size"); + if (!IAllocator::CalcMemSizeForArrayWithAlignment(static_cast(num_elems), elt_type->Size(), alignment, + &storage_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Calculation for Tensor storage size overflowed"); } } else { storage_size = 0; @@ -60,7 +61,7 @@ Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { size_t storage_size = 0; - ORT_THROW_IF_ERROR(CalculateTensorStorageSize(elt_type, shape, storage_size)); + ORT_THROW_IF_ERROR(CalculateTensorStorageSize(elt_type, shape, 0, storage_size)); return storage_size; } @@ -119,14 +120,25 @@ void Tensor::InitOrtValue(Tensor&& tensor, OrtValue& ort_value) { ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } +int64_t Tensor::NumElements() const { + int64_t num_elems = shape_.Size(); + + if (dtype_ != nullptr && num_elems > 0 && dtype_->HasSubElems()) { + const int64_t num_sub_elems = dtype_->GetNumSubElems(); + num_elems = (num_elems + (num_sub_elems - 1)) / num_sub_elems; + } + + return num_elems; +} + size_t 
Tensor::SizeInBytes() const { #ifdef ENABLE_STRIDED_TENSORS - int64_t shape_size = IsContiguous() ? shape_.Size() : GetSizeFromStrides(shape_, strides_); + int64_t size = IsContiguous() ? shape_.Size() : GetSizeFromStrides(shape_, strides_); #else - int64_t shape_size = shape_.Size(); + int64_t size = shape_.Size(); #endif size_t ret = 0; - const int64_t num_elems = GetNumTensorElems(dtype_, shape_size); + const int64_t num_elems = GetNumTensorElems(dtype_, size); if (!IAllocator::CalcMemSizeForArray(SafeInt(num_elems), dtype_->Size(), &ret)) { ORT_THROW("tensor size overflow"); @@ -161,6 +173,8 @@ void Tensor::Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_dat ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match tensor dimension size."); strides_.assign(strides.begin(), strides.end()); is_contiguous_ = CheckIsContiguous(); + ORT_ENFORCE(is_contiguous_ || !dtype_->HasSubElems(), + "Do not support subbyte element types with non-contiguous strided tensors."); } #else ORT_UNUSED_PARAMETER(strides); @@ -277,6 +291,8 @@ void Tensor::SetShapeAndStrides(const TensorShape& new_shape, gsl::spanHasSubElems(), + "Do not support subbyte element types with non-contiguous strided tensors."); } #endif diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 066b79e0e38f6..5cf5ff9b3bd0a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -216,7 +216,7 @@ ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t } size_t size_to_allocate = 0; - Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, size_to_allocate); + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); if (!status.IsOK()) { return ToOrtStatus(status); } diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 78297df185d68..23744c24d1a21 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -586,9 +586,10 @@ static void CopyDataToTensor(PyArrayObject* darray, int npy_type, Tensor& tensor } } else { void* buffer = tensor.MutableDataRaw(); - size_t len; - if (!IAllocator::CalcMemSizeForArray(tensor.DataType()->Size(), tensor.Shape().Size(), &len)) { - throw std::runtime_error("length overflow"); + size_t len = 0; + Status status = Tensor::CalculateTensorStorageSize(tensor.DataType(), tensor.Shape(), /*alignment*/ 0, len); + if (!status.IsOK()) { + throw std::runtime_error(status.ErrorMessage()); } mem_cpy_to_device(buffer, PyArray_DATA(darray), len); } From 2e8c3b91d4ba25cfa559d4f876c6227f1362be6a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 00:26:15 -0700 Subject: [PATCH 49/72] Add new PrimitiveDataTypeBase methods to provider api --- .../core/providers/shared_library/provider_interfaces.h | 2 ++ .../core/providers/shared_library/provider_wrappedtypes.h | 8 ++++++++ onnxruntime/core/session/provider_bridge_ort.cc | 2 ++ 3 files changed, 12 insertions(+) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 514370d65d173..0356ea265166f 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -575,6 +575,8 @@ struct ProviderHost { // 
PrimitiveDataTypeBase virtual int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) = 0; + virtual int32_t PrimitiveDataTypeBase__GetNumSubElems(const PrimitiveDataTypeBase* p) = 0; + virtual bool PrimitiveDataTypeBase__HasSubElems(const PrimitiveDataTypeBase* p) = 0; // DataTypeImpl virtual MLDataType DataTypeImpl__GetType_Tensor() = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index a74630cdf7edd..b46358a72009e 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -584,6 +584,14 @@ struct KernelRegistry final { struct PrimitiveDataTypeBase final { int32_t GetDataType() const { return g_host->PrimitiveDataTypeBase__GetDataType(this); } + int32_t GetNumSubElems() const { + return g_host->PrimitiveDataTypeBase__GetNumSubElems(this); + } + + bool HasSubElems() const { + return g_host->PrimitiveDataTypeBase__HasSubElems(this); + } + PROVIDER_DISALLOW_ALL(PrimitiveDataTypeBase) }; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index f73abd92ee6f5..0a2ec37a0560d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -769,6 +769,8 @@ struct ProviderHostImpl : ProviderHost { // PrimitiveDataTypeBase (wrapped) int32_t PrimitiveDataTypeBase__GetDataType(const PrimitiveDataTypeBase* p) override { return p->GetDataType(); } + int32_t PrimitiveDataTypeBase__GetNumSubElems(const PrimitiveDataTypeBase* p) override { return p->GetNumSubElems(); } + bool PrimitiveDataTypeBase__HasSubElems(const PrimitiveDataTypeBase* p) override { return p->HasSubElems(); } // DataTypeImpl (wrapped) MLDataType DataTypeImpl__GetType_Tensor() override { return DataTypeImpl::GetType(); } From f26d8856ad090e188398393bed7664ba5d26c7e7 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 01:09:48 -0700 Subject: [PATCH 50/72] Remove SparseTensor registrations for int4 types --- include/onnxruntime/core/framework/data_types.h | 3 +-- onnxruntime/core/framework/data_types.cc | 8 -------- .../providers/shared_library/provider_bridge_provider.cc | 8 -------- .../core/providers/shared_library/provider_interfaces.h | 2 -- onnxruntime/core/session/provider_bridge_ort.cc | 6 ------ 5 files changed, 1 insertion(+), 26 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index 4d950d17dba7a..c0db9be98931f 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -297,8 +297,7 @@ struct IsTensorContainedType : public IsAnyOf struct IsSparseTensorContainedType : public IsAnyOf& reg_fn) { REGISTER_SPARSE_TENSOR_PROTO(Float8E5M2, reg_fn); REGISTER_SPARSE_TENSOR_PROTO(Float8E5M2FNUZ, reg_fn); #endif - REGISTER_SPARSE_TENSOR_PROTO(Int4x2, reg_fn); - REGISTER_SPARSE_TENSOR_PROTO(UInt4x2, reg_fn); #endif #if !defined(DISABLE_ML_OPS) @@ -1174,10 +1170,6 @@ const SparseTensorTypeBase* DataTypeImpl::SparseTensorTypeFromONNXEnum(int type) return DataTypeImpl::GetSparseTensorType()->AsSparseTensorType(); #endif - case TensorProto_DataType_INT4: - return DataTypeImpl::GetSparseTensorType()->AsSparseTensorType(); - case TensorProto_DataType_UINT4: - return DataTypeImpl::GetSparseTensorType()->AsSparseTensorType(); default: ORT_NOT_IMPLEMENTED("sparse tensor type 
", type, " is not supported"); diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 42fec5dce6694..27d8a0f06f565 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -256,14 +256,6 @@ MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_Get template <> MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ(); } #endif -template <> -MLDataType DataTypeImpl::GetSparseTensorType() { - return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_Int4x2(); -} -template <> -MLDataType DataTypeImpl::GetSparseTensorType() { - return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_UInt4x2(); -} #endif diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 0356ea265166f..f5dca30c738ed 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -651,8 +651,6 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ() = 0; #endif - virtual MLDataType DataTypeImpl__GetSparseTensorType_Int4x2() = 0; - virtual MLDataType DataTypeImpl__GetSparseTensorType_UInt4x2() = 0; #endif virtual const char* DataTypeImpl__ToString(MLDataType type) = 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 0a2ec37a0560d..596a1717fa1a6 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -847,12 +847,6 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ() override { return DataTypeImpl::GetSparseTensorType(); } #endif - MLDataType DataTypeImpl__GetSparseTensorType_Int4x2() override { - return DataTypeImpl::GetSparseTensorType(); - } - MLDataType DataTypeImpl__GetSparseTensorType_UInt4x2() override { - return DataTypeImpl::GetSparseTensorType(); - } #endif const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); } From f3fdc2e5864821859d513fda36164b5cfd4470d4 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 03:34:58 -0700 Subject: [PATCH 51/72] Support Transpose int4 --- .../core/framework/element_type_lists.h | 7 ++ .../core/providers/cpu/tensor/transpose.cc | 82 +++++++++-------- .../providers/cpu/tensor/transpose_test.cc | 91 ++++++++++++++----- 3 files changed, 120 insertions(+), 60 deletions(-) diff --git a/onnxruntime/core/framework/element_type_lists.h b/onnxruntime/core/framework/element_type_lists.h index 3d956ec26d22e..2478dc27162ac 100644 --- a/onnxruntime/core/framework/element_type_lists.h +++ b/onnxruntime/core/framework/element_type_lists.h @@ -11,6 +11,7 @@ #include "core/common/type_list.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/int4.h" namespace onnxruntime { @@ -82,6 +83,12 @@ using AllIRv9 = using AllIRv9 = AllIRv4; #endif +using AllIRv10 = + boost::mp11::mp_push_back< + AllIRv9, + UInt4x2, + Int4x2>; + using 
All = AllIRv4; #if !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 849adb73e784d..ae62943476009 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -22,7 +22,7 @@ namespace op_kernel_type_control { // we're using one set of types for all opsets ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Transpose, Input, 0, - DefaultDataTypes); + element_type_lists::AllIRv10); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // enable all types for layout transformation @@ -30,12 +30,23 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Transpose, Input, 0, DefaultDataTypes); #endif + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, Transpose, 21, Input, 0, + element_type_lists::AllIRv10); + +ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, Transpose, 21, Input, 0, + element_type_lists::AllIRv10); + } // namespace op_kernel_type_control namespace { // reduce the supported types with any global or op specific lists -using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, - Transpose, Input, 0); +using EnabledDataTypesAllOpsets = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + Transpose, Input, 0); +using EnabledDataTypesOpset21 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + Transpose, 21, Input, 0); } // namespace /* A permutation [a,b,c,...] indicates that @@ -184,7 +195,7 @@ inline void CopyPrim(uint8_t* target, const uint8_t* source) { template static bool TypedDoTransposeEltWise(int64_t num_axes, gsl::span target_dims, size_t num_blocks, const gsl::span& stride, const uint8_t* source, uint8_t* target) { - constexpr bool enabled = utils::HasTypeWithSameSize(); + constexpr bool enabled = utils::HasTypeWithSameSize(); if (enabled) { MultiIndex mindex; @@ -288,7 +299,7 @@ static Status DoUntypedTranspose(const gsl::span& permutations, co Status status = Status::OK(); if (is_string_type) { - constexpr bool string_enabled = utils::HasType(); + constexpr bool string_enabled = utils::HasType(); if (string_enabled) { const auto* input_data = input.Data(); @@ -360,16 +371,14 @@ static Status TransposeImpl(const gsl::span& permutations, const T return DoUntypedTranspose(permutations, input, output, input_shape_override); } -template +template static Status UnpackInt4Tensor(const Tensor& src, Tensor& dst, AllocatorPtr cpu_allocator) { - static_assert(sizeof(PackedType) == 1); - static_assert(sizeof(UnpackedType) == 1); - + using UnpackedType = typename Int4Type::UnpackedType; MLDataType int8_elem_type = DataTypeImpl::GetType(); const TensorShape& shape = src.Shape(); Tensor int8_tensor(int8_elem_type, shape, cpu_allocator); - ORT_RETURN_IF_NOT(PackedType::Unpack(int8_tensor.MutableDataAsSpan(), src.DataAsSpan()), + ORT_RETURN_IF_NOT(Int4Type::Unpack(int8_tensor.MutableDataAsSpan(), src.DataAsSpan()), "Failed to unpack Int4x2 Tensor to an int8_t Tensor"); dst = std::move(int8_tensor); @@ -377,6 +386,27 @@ static Status UnpackInt4Tensor(const Tensor& src, Tensor& dst, AllocatorPtr cpu_ return Status::OK(); } +template +static Status DoTransposeInt4(const gsl::span& permutations, const Tensor& input, Tensor& output, + const TensorShape* 
input_shape_override, concurrency::ThreadPool* tp) { + using Int8Type = typename Int4Type::UnpackedType; + + ORT_RETURN_IF_NOT(input.IsDataType() && output.IsDataType(), + "Expected to transpose int4 tensor"); + + // Convert to Tensor, transpose, and then repack back to Tensor. + AllocatorPtr cpu_allocator = std::make_shared(); + Tensor input_unpacked; + Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); + + ORT_RETURN_IF_ERROR((UnpackInt4Tensor(input, input_unpacked, cpu_allocator))); + ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp)); + ORT_RETURN_IF_NOT(Int4Type::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), + "Failed to pack 8-bit Tensor into 4-bit Tensor"); + + return Status::OK(); +} + //`input_shape_override` overrides the shape of `input` for compute purposes. Status TransposeBase::DoTranspose(const gsl::span& permutations, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { @@ -388,31 +418,11 @@ Status TransposeBase::DoTranspose(const gsl::span& permutations, c input_type, " != ", output_type); } if (input.IsDataType()) { - // Convert to Tensor, transpose, and then repack back to Int4x2. - AllocatorPtr cpu_allocator = std::make_shared(); - Tensor input_unpacked; - Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); - - ORT_RETURN_IF_ERROR((UnpackInt4Tensor(input, input_unpacked, cpu_allocator))); - ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp)); - ORT_RETURN_IF_NOT(Int4x2::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), - "Failed to pack Tensor into Tensor"); - - return Status::OK(); + return DoTransposeInt4(permutations, input, output, input_shape_override, tp); } if (input.IsDataType()) { - // Convert to Tensor, transpose, and then repack back to UInt4x2. - AllocatorPtr cpu_allocator = std::make_shared(); - Tensor input_unpacked; - Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); - - ORT_RETURN_IF_ERROR((UnpackInt4Tensor(input, input_unpacked, cpu_allocator))); - ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp)); - ORT_RETURN_IF_NOT(UInt4x2::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), - "Failed to pack Tensor into Tensor"); - - return Status::OK(); + return DoTransposeInt4(permutations, input, output, input_shape_override, tp); } return TransposeImpl(permutations, input, output, input_shape_override, tp); @@ -447,22 +457,22 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Transpose, 1, 12, - KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Transpose, 13, 20, - KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); // Opset 21 added support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. -// TODO(adrianlizarraga): Implement support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. +// TODO(adrianlizarraga): Implement support for float8e4m3fnuz, float8e5m2, and float8e5m2fnuz. 
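DoTransposeInt4 above realizes transpose for the packed 4-bit types by unpacking into an int8 staging tensor, transposing that, and repacking the result. A self-contained sketch of the same idea for a 2-D row-major matrix (illustrative only; it assumes the low-nibble-first packing used elsewhere in this change and sign-extends each nibble for the signed case):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch only: transpose a rows x cols matrix of packed signed int4 values.
    std::vector<uint8_t> TransposePackedInt4Sketch(const std::vector<uint8_t>& packed,
                                                   size_t rows, size_t cols) {
      const size_t count = rows * cols;
      std::vector<int8_t> unpacked(count);
      for (size_t i = 0; i < count; ++i) {
        const uint8_t byte = packed[i >> 1];
        const uint8_t nib = (i & 1) ? static_cast<uint8_t>(byte >> 4)
                                    : static_cast<uint8_t>(byte & 0xF);
        unpacked[i] = static_cast<int8_t>(static_cast<int8_t>(nib << 4) >> 4);  // sign-extend
      }

      std::vector<int8_t> transposed(count);
      for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < cols; ++c) {
          transposed[c * rows + r] = unpacked[r * cols + c];
        }
      }

      std::vector<uint8_t> out((count + 1) / 2, 0);
      for (size_t i = 0; i < count; ++i) {
        const uint8_t shift = static_cast<uint8_t>(4 * (i & 1));
        out[i >> 1] = static_cast<uint8_t>(out[i >> 1] | ((transposed[i] & 0xF) << shift));
      }
      return out;
    }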
ONNX_CPU_OPERATOR_KERNEL( Transpose, 21, - KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 0e7ac5ed2b2f0..5147ead9ad929 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -49,14 +49,17 @@ void TransposeTest(const std::vector& input_shape, const std::vector* p_perm, const std::vector& expected_shape, const std::vector& expected_vals, - const std::unordered_set& excluded_provider_types = {}) { - OpTester test("Transpose"); - if (nullptr != p_perm) - test.AddAttribute("perm", *p_perm); - test.AddInput("X", input_shape, input_vals); - test.AddOutput("Y", expected_shape, expected_vals); - - test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_provider_types); + const std::unordered_set& excluded_provider_types = {}, + const std::vector& opsets = {7}) { + for (auto opset : opsets) { + OpTester test("Transpose", opset); + if (nullptr != p_perm) + test.AddAttribute("perm", *p_perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_provider_types); + } } // Test 2 dimensional transpose, with no permutation attribute specified @@ -73,7 +76,7 @@ TEST(TransposeOpTest, TwoDimNoAttr) { 3.0f, 6.0f}; TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: SegFault error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: SegFault error } TEST(TransposeOpTest, TwoDimNoAttrStr) { @@ -88,7 +91,7 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) { "2", "5", "3", "6"}; - TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals, {}, {7, 21}); } // Test 2 dimensional transpose, with permutation attribute specified @@ -103,7 +106,47 @@ TEST(TransposeOpTest, TwoDim) { 2.0f, 5.0f, 3.0f, 6.0f}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); +} + +// Test Int4 transpose with odd inner dimension. +TEST(TransposeOpTest, TwoDim_Odd_Int4) { + constexpr int8_t unused_val = 0; + std::vector input_shape({5, 3}); + std::vector input_vals = {Int4x2(1, 2), Int4x2(3, 4), Int4x2(5, 6), Int4x2(7, 8), + Int4x2(9, 10), Int4x2(11, 12), Int4x2(13, 14), Int4x2(15, unused_val)}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 5}); + std::vector expected_vals = {Int4x2(1, 4), Int4x2(7, 10), Int4x2(13, 2), Int4x2(5, 8), + Int4x2(11, 14), Int4x2(3, 6), Int4x2(9, 12), Int4x2(15, unused_val)}; + + OpTester test("Transpose", 21); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(); +} + +// Test UInt4 transpose with odd inner dimension. 
+TEST(TransposeOpTest, TwoDim_Odd_UInt4) { + constexpr int8_t unused_val = 0; + std::vector input_shape({5, 3}); + std::vector input_vals = {UInt4x2(1, 2), UInt4x2(3, 4), UInt4x2(5, 6), UInt4x2(7, 8), + UInt4x2(9, 10), UInt4x2(11, 12), UInt4x2(13, 14), UInt4x2(15, unused_val)}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 5}); + std::vector expected_vals = {UInt4x2(1, 4), UInt4x2(7, 10), UInt4x2(13, 2), UInt4x2(5, 8), + UInt4x2(11, 14), UInt4x2(3, 6), UInt4x2(9, 12), UInt4x2(15, unused_val)}; + + OpTester test("Transpose", 21); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(); } TEST(TransposeOpTest, TwoDim_double) { @@ -131,7 +174,7 @@ TEST(TransposeOpTest, TwoDim_int32) { 2, 5, 3, 6}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); } TEST(TransposeOpTest, TwoDim_int16) { @@ -147,7 +190,7 @@ TEST(TransposeOpTest, TwoDim_int16) { 2, 5, 3, 6}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kOpenVINOExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kOpenVINOExecutionProvider}, {7, 21}); } TEST(TransposeOpTest, TwoDim_mlfloat16) { @@ -163,7 +206,7 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) { MLFloat16::FromBits(static_cast(2)), MLFloat16::FromBits(static_cast(5)), MLFloat16::FromBits(static_cast(3)), MLFloat16::FromBits(static_cast(6))}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } #if defined(USE_DNNL) @@ -264,7 +307,7 @@ TEST(TransposeOpTest, TwoDim_int8) { 2, 5, 3, 6}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } TEST(TransposeOpTest, TwoDimStr) { @@ -280,7 +323,7 @@ TEST(TransposeOpTest, TwoDimStr) { "2", "5", "3", "6"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); } // Test 3 dimensional transpose, with permutation attribute specified @@ -319,7 +362,7 @@ TEST(TransposeOpTest, Transpose021) { 3.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } TEST(TransposeOpTest, Transpose120) { @@ -349,7 +392,7 @@ TEST(TransposeOpTest, Transpose120) { 6.0f, 6.1f, 6.2f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } // test when the suffix size is > 1 (last dimension is not moved) @@ -382,7 +425,7 @@ TEST(TransposeOpTest, Transpose102) { 4.3f, 5.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } TEST(TransposeOpTest, TransposeReshape) { @@ -416,7 +459,7 @@ TEST(TransposeOpTest, 
TransposeReshape) { 4.3f, 5.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kTensorrtExecutionProvider}); // TensorRT: illegal error + {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: illegal error } TEST(TransposeOpTest, ThreeDimStr) { @@ -453,7 +496,7 @@ TEST(TransposeOpTest, ThreeDimStr) { "2", "5", "3", "6"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {}, {7, 21}); } TEST(TransposeOpTest, SixDim) { @@ -478,7 +521,7 @@ TEST(TransposeOpTest, SixDim) { }(); TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, - {kQnnExecutionProvider}); // Error: Failed to finalize QNN graph. + {kQnnExecutionProvider}, {7, 21}); // Error: Failed to finalize QNN graph. } template @@ -522,7 +565,7 @@ TEST(TransposeOpTest, NCHW2NHWCStr) { "3", "7", "11", "4", "8", "12"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } template @@ -582,7 +625,7 @@ TEST(TransposeOpTest, NHWC2NCHW_String) { "2", "5", "8", "11", "3", "6", "9", "12"}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}); + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); } // test to cover memcpy from single axis moving inwards path From 8adbb4a8ea79eda342b59a99e0623bd84407a674 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 04:09:58 -0700 Subject: [PATCH 52/72] Revert to default types for older transpose opsets --- onnxruntime/core/providers/cpu/tensor/transpose.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index ae62943476009..5b904e85848d0 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -22,7 +22,7 @@ namespace op_kernel_type_control { // we're using one set of types for all opsets ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Transpose, Input, 0, - element_type_lists::AllIRv10); + DefaultDataTypes); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // enable all types for layout transformation From 0ac8427a8b53e592eb374f5d60acc2108b45928c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 04:17:40 -0700 Subject: [PATCH 53/72] Update op docs --- docs/OperatorKernels.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 118eee04910b8..9d59b04a92419 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -418,7 +418,7 @@ Do not modify directly.* |TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| |||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| -|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)| From da80e3a3086f10e5176e9e736b16e27c07c43eac Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 04:49:13 -0700 Subject: [PATCH 54/72] Add comments --- include/onnxruntime/core/framework/data_types.h | 9 +++++++-- include/onnxruntime/core/framework/int4.h | 4 ++++ include/onnxruntime/core/session/onnxruntime_c_api.h | 5 +++-- onnxruntime/test/onnx/TestCase.cc | 1 + 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index c0db9be98931f..b197d88090432 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -919,7 +919,8 @@ class OpaqueType : public NonTensorType { * Base class for primitive Tensor contained types * * \details This class contains an integer constant that can be - * used for input data type dispatching + * used for input data type dispatching. This class also stores the number of subelements per size units. + * Example: For int4, the size unit is 1 byte and the number of subelements is 2. * */ class PrimitiveDataTypeBase : public DataTypeImpl { @@ -950,7 +951,7 @@ class PrimitiveDataTypeBase : public DataTypeImpl { private: const int32_t data_type_; - const int32_t num_sub_elems_; + const int32_t num_sub_elems_; // > 1 for subbyte primitives, 1 for normal primitives. }; /** @@ -1096,6 +1097,10 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const { return PrimitiveDataType::Type(); \ } +// Registers a subbyte primitive. +// Examples: +// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2) +// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8) #define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \ template <> \ MLDataType PrimitiveDataType::Type() { \ diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index 6f0311a39f591..8a9ffe97db68f 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -27,6 +27,10 @@ struct Int4Traits { static constexpr uint8_t max_val = 15; }; +/// +/// Stores 2 packed 4-bit elements in 1 byte. +/// +/// Set to true if signed int4, or false if unsigned uint4. template struct Int4x2Base { using UnpackedType = typename Int4Traits::UnpackedType; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a2fec7a75527e..524541a2557e8 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -197,8 +197,9 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of uint4 values (size == 1 byte) - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of int4 values (size == 1 byte) + // Int4 types were introduced in ONNX 1.16. 
See https://onnx.ai/onnx/technical/int4.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index baf79dfce7bfe..1d54a3cfae9bf 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1022,6 +1022,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"dequantizelinear_blocked", "blocked quantization (onnx 1.16.0) not supported", {}}, {"quantizelinear_blocked_asymmetric", "blocked quantization (onnx 1.16.0) not supported", {}}, {"quantizelinear_blocked_symmetric", "blocked quantization (onnx 1.16.0) not supported", {}}, + // See PR that fixes int4 q/dq tests: https://github.com/onnx/onnx/pull/6122 {"dequantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"dequantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"quantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, From f72c3d597ec0488e8a9f9e07809986a913dd1b8a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 09:20:40 -0700 Subject: [PATCH 55/72] Exclude TRT from int4 traspose test --- onnxruntime/test/providers/cpu/tensor/transpose_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 5147ead9ad929..01dba55ceb8ed 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -126,7 +126,7 @@ TEST(TransposeOpTest, TwoDim_Odd_Int4) { test.AddInput("X", input_shape, input_vals); test.AddOutput("Y", expected_shape, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Test UInt4 transpose with odd inner dimension. @@ -146,7 +146,7 @@ TEST(TransposeOpTest, TwoDim_Odd_UInt4) { test.AddInput("X", input_shape, input_vals); test.AddOutput("Y", expected_shape, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(TransposeOpTest, TwoDim_double) { From 66055127f81a87ea7b22a3c16991c90092bc5c34 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 10:43:17 -0700 Subject: [PATCH 56/72] Test C API for creating int4 OrtValues --- onnxruntime/test/shared_lib/test_inference.cc | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 051a93ac8458f..58c185f818df7 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -2827,6 +2827,50 @@ TEST(CApiTest, create_tensor_with_data_float8) { #endif +// Test creating an Ort::Value with INT4 data. +TEST(CApiTest, create_tensor_with_data_int4) { + std::array values = {0x10, 0x32, 0x78, 0x06}; // {0, 1, 2, 3, -8, 7, 6, pad_0} + std::vector dims = {7}; // 7 4-bit elements take up 4 bytes. 
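  // Byte-to-nibble mapping for the packed values above (low nibble first): 0x10 -> {0, 1},
  // 0x32 -> {2, 3}, 0x78 -> {-8, 7} after sign extension, 0x06 -> {6} plus one padding nibble.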
+ + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + + Ort::Value tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4); + const auto* new_pointer = tensor.GetTensorData(); + ASSERT_EQ(new_pointer, values.data()); + auto type_info = tensor.GetTypeInfo(); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + ASSERT_NE(tensor_info, nullptr); + auto query_dims = tensor_info.GetShape(); + ASSERT_EQ(query_dims, dims); + ASSERT_EQ(tensor_info.GetElementType(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4); + + uint8_t pair_2 = tensor.At({2}); + ASSERT_EQ(values[2], pair_2); +} + +// Test creating an Ort::Value with UINT4 data. +TEST(CApiTest, create_tensor_with_data_uint4) { + std::array values = {0x10, 0x32, 0x54, 0x0F}; // {0, 1, 2, 3, 4, 5, 15, pad_0} + std::vector dims = {7}; // 7 4-bit elements take up 4 bytes. + + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + + Ort::Value tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4); + const auto* new_pointer = tensor.GetTensorData(); + ASSERT_EQ(new_pointer, values.data()); + auto type_info = tensor.GetTypeInfo(); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + ASSERT_NE(tensor_info, nullptr); + auto query_dims = tensor_info.GetShape(); + ASSERT_EQ(query_dims, dims); + ASSERT_EQ(tensor_info.GetElementType(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4); + + uint8_t pair_2 = tensor.At({2}); + ASSERT_EQ(values[2], pair_2); +} + TEST(CApiTest, access_tensor_data_elements) { /** * Create a 2x3 data blob that looks like: From ea96d09d48df51db4ffc2ee3e6a639f14ff607fe Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 10:53:22 -0700 Subject: [PATCH 57/72] Add comment to qmath macro for defining the int4 quantization functions --- onnxruntime/core/util/qmath.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 4aa2ab81216ce..15b7eb1063976 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -133,6 +133,22 @@ ParQuantizeLinearStd(const float* Input, }); } +/** + * Defines a function for int4 quantization. Calls MLAS kernel in parallel with the provided thread pool. + * + * \param FUNC_NAME The name of the generated function. + * \param INT4_TYPE The int4 type (i.e., either Int4x2 or UInt4x2) + * \param MLAS_FUNC The MLAS quantization kernel to call. + * \param Input The input float values to quantize. Must contain `out_end - out_start` elements. + * \param Output The output buffer that will contain the quantized values. + * \param out_start The int4 element index at which to start writing to the output buffer. + * Divide by 2 to get index into Output buffer. + * \param out_end The int4 element index at which to stop writing to the output buffer. + * Divide by 2 to get index into Output buffer. + * \param Scale The quantization scale value. + * \param ZeroPoint The quantization zero-point value. + * \param thread_pool The thread pool to use. 
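 *
 * Example (the instantiation and MLAS kernel names here are illustrative):
 *   DEFINE_PAR_QUANT_LINEAR_STD_4BIT(ParQuantizeLinearStdS4, Int4x2, MlasQuantizeLinearS4)
 * would generate a ParQuantizeLinearStdS4 function taking the parameters documented above
 * (Input, Output, out_start, out_end, Scale, ZeroPoint, thread_pool), which partitions
 * [out_start, out_end) across the thread pool and invokes the MLAS kernel on each block.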
+ */ #define DEFINE_PAR_QUANT_LINEAR_STD_4BIT(FUNC_NAME, INT4_TYPE, MLAS_FUNC) \ inline void FUNC_NAME(const float* Input, \ INT4_TYPE* Output, \ From aa029ba1431ed7a7b9a2a85cd502608eeb668031 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 11:30:57 -0700 Subject: [PATCH 58/72] Clean up tensorprotoutils macros --- include/onnxruntime/core/framework/int4.h | 8 +- .../core/framework/tensorprotoutils.cc | 207 +++++++----------- 2 files changed, 83 insertions(+), 132 deletions(-) diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index 8a9ffe97db68f..89ffbb96231bc 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -98,8 +98,12 @@ struct Int4x2Base { return bits_; } + static size_t CalcNumInt4Pairs(size_t num_int4_elems) { + return (num_int4_elems + 1) / 2; + } + static bool Unpack(gsl::span dst, gsl::span> src) { - if (((dst.size() + 1) / 2) != src.size()) { + if (CalcNumInt4Pairs(dst.size()) != src.size()) { return false; } @@ -113,7 +117,7 @@ struct Int4x2Base { } static bool Pack(gsl::span> dst, gsl::span src) { - if (((src.size() + 1) / 2) != dst.size()) { + if (CalcNumInt4Pairs(src.size()) != dst.size()) { return false; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index d4608d6e1bc4c..4c3c6b373a415 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -141,41 +141,28 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t reinterpret_cast(p_data)); } -template <> -Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, - /*out*/ Int4x2* p_data) { - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); - - ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); - - size_t num_packed_pairs = (expected_num_elements + 1) / 2; - ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); - - gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); - gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); - - std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); - - return Status::OK(); -} - -template <> -Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, - /*out*/ UInt4x2* p_data) { - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); - - ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); - - size_t num_packed_pairs = (expected_num_elements + 1) / 2; - ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); - - gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); - gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); - - std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); - - return Status::OK(); -} +#define DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(INT4_TYPE) \ + template <> \ + Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, \ + /*out*/ INT4_TYPE* p_data) { \ + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ + \ + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ + \ + size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ + ORT_RETURN_IF_NOT(num_packed_pairs == 
raw_data_len, "Unexpected number of packed int4 pairs"); \ + \ + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), \ + num_packed_pairs); \ + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); \ + \ + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ + \ + return Status::OK(); \ + } + +DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2) +DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const ORTCHAR_T* tensor_proto_dir, @@ -313,49 +300,31 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, reinterpret_cast(p_data)); } -template <> -Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, - const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, - /*out*/ Int4x2* p_data) { - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); - - ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); - std::vector unpacked_tensor; - ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); - - size_t num_packed_pairs = (expected_num_elements + 1) / 2; - ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); - - gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), - num_packed_pairs); - gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); - - std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); - - return Status::OK(); -} - -template <> -Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, - const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, - /*out*/ UInt4x2* p_data) { - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); - - ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); - std::vector unpacked_tensor; - ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); - - size_t num_packed_pairs = (expected_num_elements + 1) / 2; - ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); - - gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), - num_packed_pairs); - gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); - - std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); - - return Status::OK(); -} +#define DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(INT4_TYPE) \ + template <> \ + Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ + const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, \ + /*out*/ INT4_TYPE* p_data) { \ + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ + \ + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ + std::vector unpacked_tensor; \ + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \ + \ + size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ + ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \ + \ + gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), \ + num_packed_pairs); \ + gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); \ + \ + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ + \ + return Status::OK(); \ + } + 
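// Each line below expands the macro into an explicit specialization, e.g.
// UnpackTensorWithExternalData<Int4x2>(const ONNX_NAMESPACE::TensorProto&, const ORTCHAR_T*, size_t, Int4x2*),
// which reads the external data file and memcpy's the packed int4 pairs into the destination buffer.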
+DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2) +DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2) #define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \ template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const ORTCHAR_T*, size_t, type*); @@ -698,61 +667,39 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d #endif -// UnpackTensor -template <> -Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, - /*out*/ Int4x2* p_data, size_t expected_num_elems) { - if (nullptr == p_data) { - const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); - return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); - } - if (ONNX_NAMESPACE::TensorProto_DataType_INT4 != tensor.data_type()) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); - } - - size_t expected_int4_pairs = (expected_num_elems + 1) / 2; - - if (raw_data != nullptr) { - return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); - } - - ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_int4_pairs, - "UnpackTensor: the pre-allocated size does not match the size in proto"); - - for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { - p_data[i] = Int4x2(static_cast(tensor.int32_data()[i])); +#define DEFINE_INT4_UNPACK_TENSOR_IMPL(INT4_TYPE, ONNX_INT4_TYPE) \ + template <> \ + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ + /*out*/ INT4_TYPE* p_data, size_t expected_num_elems) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); \ + return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + if (ONNX_NAMESPACE::ONNX_INT4_TYPE != tensor.data_type()) { \ + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + \ + size_t expected_int4_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elems); \ + \ + if (raw_data != nullptr) { \ + return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \ + } \ + \ + ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_int4_pairs, \ + "UnpackTensor: the pre-allocated size does not match the size in proto"); \ + \ + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ + p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ + } \ + \ + return Status::OK(); \ } - return Status::OK(); -} +// UnpackTensor +DEFINE_INT4_UNPACK_TENSOR_IMPL(Int4x2, TensorProto_DataType_INT4) // UnpackTensor -template <> -Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, - /*out*/ UInt4x2* p_data, size_t expected_num_elems) { - if (nullptr == p_data) { - const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); - return size == 0 ? 
Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); - } - if (ONNX_NAMESPACE::TensorProto_DataType_UINT4 != tensor.data_type()) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); - } - - size_t expected_int4_pairs = (expected_num_elems + 1) / 2; - - if (raw_data != nullptr) { - return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); - } - - ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_int4_pairs, - "UnpackTensor: the pre-allocated size does not match the size in proto"); - - for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { - p_data[i] = UInt4x2(static_cast(tensor.int32_data()[i])); - } - - return Status::OK(); -} +DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4) // UnpackTensor from raw data, external data or the type specific data field. // Uses the model path to construct the full path for loading external data. In case when model_path is empty @@ -1741,7 +1688,7 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ size_t element_count = static_cast(tensor_shape.Size()); \ - size_t packed_element_count = (element_count + 1) / 2; \ + size_t packed_element_count = ELEMENT_TYPE::CalcNumInt4Pairs(element_count); \ unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ return onnxruntime::utils::UnpackTensor( \ initializer, \ From 2b9f53a49d375faa241c1e7184c1e2fc2605d0a3 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 11:43:15 -0700 Subject: [PATCH 59/72] Use CalcNumInt4Pairs() --- onnxruntime/core/framework/tensorprotoutils.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 4c3c6b373a415..ad9a32a461561 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -760,11 +760,11 @@ INSTANTIATE_UNPACK_TENSOR(UInt4x2) } \ break; -#define CASE_PROTO_TRACE_INT4(X) \ - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ - if (!IAllocator::CalcMemSizeForArrayWithAlignment((size + 1) / 2, 1, out)) { \ - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \ - } \ +#define CASE_PROTO_TRACE_INT4(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!IAllocator::CalcMemSizeForArrayWithAlignment(Y::CalcNumInt4Pairs(size), sizeof(Y), out)) { \ + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \ + } \ break; template @@ -800,8 +800,8 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& CASE_PROTO_TRACE(FLOAT8E5M2, Float8E5M2); CASE_PROTO_TRACE(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif - CASE_PROTO_TRACE_INT4(UINT4); - CASE_PROTO_TRACE_INT4(INT4); + CASE_PROTO_TRACE_INT4(UINT4, UInt4x2); + CASE_PROTO_TRACE_INT4(INT4, Int4x2); default: return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } From 84c6c72895e65c8fc0034bc3d63a0a834953b715 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 11:58:33 -0700 Subject: [PATCH 60/72] Add comment reference to ONNX PR that fixes int4 q/dq onnx node tests --- onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc | 1 + 1 file changed, 1 insertion(+) 
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 6aee9899283db..1885a213bdf32 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -314,6 +314,7 @@ "^test_quantizelinear_blocked_asymmetric", "^test_quantizelinear_blocked_symmetric", // Bug with test model: node's input name does not match the model's input name (x_zero_point vs zero_point) + // PR with fix: https://github.com/onnx/onnx/pull/6122 "^test_dequantizelinear_int4", "^test_dequantizelinear_uint4", "^test_quantizelinear_int4", From 7bee6f1621b1e812f1f50a190bd902facdb4f0a1 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 12:01:10 -0700 Subject: [PATCH 61/72] Add another use of CalcNumInt4Pairs() in base_tester.h --- onnxruntime/test/providers/base_tester.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index e00855d1e9eac..512b3402c5986 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -692,8 +692,7 @@ class BaseTester { // In case values is nullptr for optional type tensor, it means we are creating // an optional type tensor which is None and we hence skip values count validation if constexpr (std::is_same_v || std::is_same_v) { - int64_t expected_values_count = shape.Size(); - expected_values_count = (expected_values_count + 1) / 2; + const int64_t expected_values_count = T::CalcNumInt4Pairs(shape.Size()); ORT_ENFORCE(expected_values_count == values_count, values_count, " input values doesn't match tensor size of ", expected_values_count); } else { From 1435222216f56bbd985bab6e6110e97ce7c5e348 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 12:21:13 -0700 Subject: [PATCH 62/72] Add another use of CalcNumInt4Pairs() in cpu quantize_linear tests --- onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 3c6c1022e4bee..5eeda5a3b8949 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -450,7 +450,7 @@ TEST(QuantizeLinearOpTest, OddLarge_Int4) { constexpr int8_t unused_val = 0; constexpr std::array pattern = {-20.0f, -14.0f, -4.1f, -0.0f, 3.0f, 3.3f}; std::vector input_f32s(static_cast(dims[0])); - std::vector output((input_f32s.size() + 1) / 2); + std::vector output(Int4x2::CalcNumInt4Pairs(input_f32s.size())); for (size_t i = 0; i < input_f32s.size(); ++i) { input_f32s[i] = pattern[i % pattern.size()]; @@ -476,7 +476,7 @@ TEST(QuantizeLinearOpTest, OddLarge_UInt4) { constexpr uint8_t unused_val = 0; constexpr std::array pattern = {-20.0f, -14.0f, -4.1f, -0.0f, 3.0f, 3.3f}; std::vector input_f32s(static_cast(dims[0])); - std::vector output((input_f32s.size() + 1) / 2); + std::vector output(UInt4x2::CalcNumInt4Pairs(input_f32s.size())); for (size_t i = 0; i < input_f32s.size(); ++i) { input_f32s[i] = pattern[i % pattern.size()]; From 5d8b02955eed6bbb052a041a8dec44b5f40f1ae5 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 13:06:34 -0700 Subject: [PATCH 63/72] Temporarily disable the block_size attribute for Q/DQ ops --- 
.../cpu/quantization/quantize_linear.cc | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index dceef3a300293..36f5f4d859b4e 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -20,12 +20,22 @@ class DequantizeLinear final : public OpKernel { if (!info.GetAttr("axis", &axis_).IsOK()) { axis_ = 1; } + + if (!info.GetAttr("block_size", &block_size_).IsOK()) { + block_size_ = 0; + } + + // TODO(adrianlizarraga): Support the block_size attribute added in opset 21. + if (block_size_ != 0) { + ORT_THROW("DequantizeLinear does not yet support the 'block_size' attribute."); + } } Status Compute(OpKernelContext* context) const override; private: int64_t axis_; + int64_t block_size_; }; template @@ -38,6 +48,15 @@ class QuantizeLinear final : public OpKernel { if (!info.GetAttr("saturate", &saturate_).IsOK()) { saturate_ = 1; } + + if (!info.GetAttr("block_size", &block_size_).IsOK()) { + block_size_ = 0; + } + + // TODO(adrianlizarraga): Support the block_size attribute added in opset 21. + if (block_size_ != 0) { + ORT_THROW("QuantizeLinear does not yet support the 'block_size' attribute."); + } } Status Compute(OpKernelContext* context) const override; @@ -45,6 +64,7 @@ class QuantizeLinear final : public OpKernel { private: int64_t axis_; int64_t saturate_; + int64_t block_size_; }; static void PrepareForQDQ(const TensorShape& input_shape, From ca2a1c5f5594b6bd1c807caf2f47f0842a24bbe7 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 15:45:55 -0700 Subject: [PATCH 64/72] Disable QDQ fusions for int4 --- .../selectors_actions/qdq_selectors.cc | 32 +++++++++ .../selectors_actions/qdq_selectors.h | 70 ++++++++++++------- .../optimizer/graph_transform_test_builder.h | 28 ++++++++ onnxruntime/test/optimizer/qdq_test_utils.h | 21 ++++-- .../test/optimizer/qdq_transformer_test.cc | 18 +++-- onnxruntime/test/util/compare_ortvalue.cc | 43 ++++++++++++ 6 files changed, 176 insertions(+), 36 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index d4879376b34ad..09705f61c82ce 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -139,6 +139,10 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + const Node& dq_node = *dq_nodes.front(); const Node& q_node = *q_nodes.front(); @@ -172,6 +176,10 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { return graph_viewer.GetConstantInitializer(initializer_name, true); }; @@ -198,6 +206,10 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + return true; } @@ -223,6 +235,10 @@ bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input_1)) { + return false; + } + return true; } 
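// The same gate appears in each Check() below: when allow_4bit_ is false, node groups whose DQ input
// element type is a 4-bit int (Is4BitIntType) are rejected, so the surrounding Q/DQ ops are kept in the
// graph instead of being fused.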
@@ -258,6 +274,10 @@ bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + return true; } @@ -280,6 +300,10 @@ bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& dq_node = *dq_nodes.front(); int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + // All Q outputs should have same data type and (optionally) equal quantization parameters as the input. for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) { const Node& q_node = *q_nodes[q_idx]; @@ -421,6 +445,10 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } + if (!allow_4bit_ && (Is4BitIntType(dt_A) || Is4BitIntType(dt_B))) { + return false; + } + if (dq_nodes.size() < 3) { // no bias return true; } @@ -459,6 +487,10 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& return false; } + if (!allow_4bit_ && Is4BitIntType(dt_input_1)) { + return false; + } + return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 5a40f0fbde595..1a2a620acb480 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -48,7 +48,8 @@ class NodeGroupSelector { // Zero point and scale are constant scalars and must match class DropQDQNodeGroupSelector : public NodeGroupSelector { public: - explicit DropQDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -56,12 +57,14 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // Single DQ -> node. class DropDQNodeGroupSelector : public NodeGroupSelector { public: - explicit DropDQNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit DropDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -69,12 +72,14 @@ class DropDQNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // single input. default is to only support uint8. 
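// 16-bit and 4-bit support is opt-in via the allow_16bit/allow_4bit constructor arguments.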
class UnaryNodeGroupSelector : public NodeGroupSelector { public: - explicit UnaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit UnaryNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -82,12 +87,14 @@ class UnaryNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // 2 DQ nodes providing input -> node -> Q class BinaryNodeGroupSelector : public NodeGroupSelector { public: - explicit BinaryNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit BinaryNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -95,12 +102,14 @@ class BinaryNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // Variadic DQ nodes -> node -> Q class VariadicNodeGroupSelector : public NodeGroupSelector { public: - explicit VariadicNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit VariadicNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -108,6 +117,7 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // DQ node -> Split -> multiple Q nodes with equal quantization types. @@ -115,8 +125,8 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { // equal and constant. class SplitNodeGroupSelector : public NodeGroupSelector { public: - explicit SplitNodeGroupSelector(bool req_equal_quant_params = false) - : req_equal_quant_params_(req_equal_quant_params) {} + explicit SplitNodeGroupSelector(bool req_equal_quant_params = false, bool allow_4bit = true) + : req_equal_quant_params_(req_equal_quant_params), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -126,6 +136,7 @@ class SplitNodeGroupSelector : public NodeGroupSelector { bool req_equal_quant_params_; // If true, only selects a node group if the input and output // quantization parameters are all equal/constant, which enables the // optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP. 
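  // If false, node groups with int4/uint4 quantized data are not selected, so their Q/DQ ops are preserved.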
+ bool allow_4bit_; }; // DQ nodes for X, W and optionally B -> node -> Q @@ -147,8 +158,8 @@ class ConvNodeGroupSelector : public NodeGroupSelector { class WhereNodeGroupSelector : public NodeGroupSelector { public: - explicit WhereNodeGroupSelector(bool allow_16bit = true) - : allow_16bit_(allow_16bit) {} + explicit WhereNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -156,6 +167,7 @@ class WhereNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; class PadNodeGroupSelector : public NodeGroupSelector { @@ -196,7 +208,8 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { // Output: optional Q node for Y class GemmNodeGroupSelector : public NodeGroupSelector { public: - explicit GemmNodeGroupSelector(bool allow_16bit = true) : allow_16bit_(allow_16bit) {} + explicit GemmNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -204,6 +217,7 @@ class GemmNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; bool allow_16bit_; + bool allow_4bit_; }; // Input: DQ nodes for input, scale, and B @@ -278,33 +292,35 @@ class BaseSelector : public NodeSelector { class DropQDQNodesSelector : public BaseSelector { public: - explicit DropQDQNodesSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} }; class DropDQNodesSelector : public BaseSelector { public: - explicit DropDQNodesSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit DropDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} }; class UnarySelector : public BaseSelector { public: - explicit UnarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit UnarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} }; class BinarySelector : public BaseSelector { public: - explicit BinarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit BinarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} }; // Variadic DQ nodes -> node -> Q class InputVariadicSelector : public BaseSelector { public: - explicit InputVariadicSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit InputVariadicSelector(bool allow_16bit = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; @@ -312,8 +328,8 @@ class InputVariadicSelector : public BaseSelector { // DQ -> Split -> variadic Q nodes class SplitSelector : public BaseSelector { public: - 
SplitSelector(bool req_equal_quant_params = false) - : BaseSelector(std::make_unique(req_equal_quant_params)) {} + SplitSelector(bool req_equal_quant_params = false, bool allow_4bit = false) + : BaseSelector(std::make_unique(req_equal_quant_params, allow_4bit)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; @@ -329,8 +345,9 @@ class ConvSelector : public BaseSelector { class WhereSelector : public BaseSelector { public: - explicit WhereSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit WhereSelector(gsl::span compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} }; // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not @@ -345,8 +362,9 @@ class MatMulSelector : public BaseSelector { // Output: optional Q node for Y class GemmSelector : public BaseSelector { public: - explicit GemmSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} + explicit GemmSelector(gsl::span compatible_providers = {}, bool allow_16bit = false, + bool allow_4bit = false) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit), compatible_providers) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 57f10d9a4eb69..4d4d576930f10 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -9,6 +9,7 @@ #include "core/common/type_utils.h" #include "core/graph/graph.h" #include "core/framework/framework_common.h" +#include "core/framework/int4.h" #include "core/optimizer/graph_transformer_level.h" #include "core/graph/onnx_protobuf.h" #include "test/framework/test_utils.h" @@ -45,6 +46,12 @@ struct IsTypeQuantLinearCompatible : std::true_type {}; template <> struct IsTypeQuantLinearCompatible : std::true_type {}; +template <> +struct IsTypeQuantLinearCompatible : std::true_type {}; + +template <> +struct IsTypeQuantLinearCompatible : std::true_type {}; + template struct IsTypeDequantLinearCompatible : utils::IsByteType {}; @@ -57,6 +64,12 @@ struct IsTypeDequantLinearCompatible : std::true_type {}; template <> struct IsTypeDequantLinearCompatible : std::true_type {}; +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + +template <> +struct IsTypeDequantLinearCompatible : std::true_type {}; + class ModelTestBuilder { public: ModelTestBuilder(Graph& graph) : graph_(graph) { @@ -102,6 +115,21 @@ class ModelTestBuilder { return MakeInput(shape, data); } + template + typename std::enable_if< + std::is_same_v || std::is_same_v, + NodeArg*>::type + MakeInputInt4(const std::vector& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) { + std::vector data_int8 = rand_gen_.Uniform(shape, min, max); + std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); + for (size_t i = 0; i < data_int8.size(); i++) { + size_t r = i >> 1; + size_t c = i & 0x1; + data[r].SetElem(c, data_int8[i]); + } + return MakeInput(shape, data); + } + template NodeArg* MakeInput(const std::optional>& shape, std::optional input_name = std::nullopt) { diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h 
b/onnxruntime/test/optimizer/qdq_test_utils.h index 414a0fbeb78f5..e94f296bc10c9 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -5,9 +5,11 @@ #include #include +#include #include "graph_transform_test_builder.h" +#include "core/framework/int4.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/session/inference_session.h" @@ -488,12 +490,21 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, bool use_diff_output_scale, bool use_contrib_qdq = false) { return [input_shape, axis, use_diff_output_scale, use_contrib_qdq](ModelTestBuilder& builder) { - auto* input_arg = builder.MakeInput(input_shape, - std::numeric_limits::min(), - std::numeric_limits::max()); + InputType dq_zp{}; + OutputType q_zp{}; + NodeArg* input_arg = nullptr; + + if constexpr (std::is_same_v || std::is_same_v) { + input_arg = builder.MakeInputInt4(input_shape, InputType::min_val, InputType::max_val); + dq_zp = InputType(static_cast(InputType::max_val / 2)); + q_zp = OutputType(static_cast(OutputType::max_val / 2)); + } else { + input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), + std::numeric_limits::max()); + dq_zp = std::numeric_limits::max() / 2; + q_zp = std::numeric_limits::max() / 2; + } - InputType dq_zp = std::numeric_limits::max() / 2; - OutputType q_zp = std::numeric_limits::max() / 2; auto* dq_output = builder.MakeIntermediate(); constexpr float input_scale = 0.003f; builder.AddDequantizeLinearNode(input_arg, input_scale, dq_zp, dq_output, use_contrib_qdq); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index ae263a7ca7d35..924d8a46e14c0 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -4,6 +4,7 @@ #include #include "core/framework/compute_capability.h" #include "core/framework/node_unit.h" +#include "core/framework/int4.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" @@ -1236,19 +1237,21 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { // Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split. template static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis, - bool all_same_quant_params, bool use_contrib_qdq = false) { - auto check_graph = [all_same_quant_params, use_contrib_qdq](InferenceSessionWrapper& session) { + bool all_same_quant_params, bool use_contrib_qdq = false, + bool should_not_drop = false) { + auto check_graph = [all_same_quant_params, use_contrib_qdq, should_not_drop](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - int expected_q_ops = all_same_quant_params ? 0 : 3; - int expected_dq_ops = all_same_quant_params ? 0 : 1; + int expected_q_ops = all_same_quant_params && !should_not_drop ? 0 : 3; + int expected_dq_ops = all_same_quant_params && !should_not_drop ? 
0 : 1; EXPECT_EQ(op_to_count["Split"], 1); EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_q_ops); EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_dq_ops); }; std::vector opsets = {12, 13, 18, 19, 21}; - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { opsets = std::vector{21}; } @@ -1276,6 +1279,11 @@ TEST(QDQTransformerTests, Split) { RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS); RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS); + + // Do not yet support int4 Split, so should not drop + constexpr bool SHOULD_NOT_DROP = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS, SHOULD_NOT_DROP); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, !USE_CONTRIB_QDQ_OPS, SHOULD_NOT_DROP); } // Test cases that DO NOT drop Q/DQ ops from DQ -> Split -> Q (many) diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index 64ebe24188762..cc4c0440d26d9 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -29,6 +29,7 @@ #pragma GCC diagnostic pop #endif +#include "core/framework/int4.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/framework/TensorSeq.h" @@ -202,6 +203,44 @@ std::pair IsResultExactlyMatch(const Tensor& outval return std::make_pair(COMPARE_RESULT::SUCCESS, ""); } +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const Int4x2* expected_output = expected_value.Data(); + const Int4x2* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 1; + size_t c = di & 0x1; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << expected_output[r].GetElem(c) << ", got " << real_output[r].GetElem(c); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const UInt4x2* expected_output = expected_value.Data(); + const UInt4x2* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 1; + size_t c = di & 0x1; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << expected_output[r].GetElem(c) << ", got " << real_output[r].GetElem(c); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + std::pair CompareFloat16Result(const Tensor& outvalue, const Tensor& expected_value, double per_sample_tolerance, double relative_per_sample_tolerance, @@ -313,6 +352,10 @@ std::pair CompareTwoTensors(const Tensor& outvalue, return IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return IsResultExactlyMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return IsResultExactlyMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return 
IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return CompareFloat16Result(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, post_processing); From 0efc3523399df28029f5ca0a466fdd3349d8961b Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 16:00:24 -0700 Subject: [PATCH 65/72] Add typename --- onnxruntime/test/optimizer/graph_transform_test_builder.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 4d4d576930f10..5d9488aef68bd 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -120,7 +120,8 @@ class ModelTestBuilder { std::is_same_v || std::is_same_v, NodeArg*>::type MakeInputInt4(const std::vector& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) { - std::vector data_int8 = rand_gen_.Uniform(shape, min, max); + using UnpackedType = typename TInt4::UnpackedType; + std::vector data_int8 = rand_gen_.Uniform(shape, min, max); std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); for (size_t i = 0; i < data_int8.size(); i++) { size_t r = i >> 1; From 7452329cf37d29aa939877c50afcc66e04b7b32a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 7 May 2024 16:46:41 -0700 Subject: [PATCH 66/72] Add python quantization unit test for int4 qdq --- .../test/python/quantization/test_qdq.py | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index db4ab7e8a412c..89267c3320ed8 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -522,7 +522,9 @@ def setUpClass(cls): def tearDownClass(cls): cls._tmp_model_dir.cleanup() - def construct_model_conv_relu(self, output_model_path, input_shape, weight_shape, output_shape): + def construct_model_conv_relu( + self, output_model_path, input_shape, weight_shape, output_shape, opset=13, ir_version=7 + ): # (input) # | # Conv @@ -557,19 +559,31 @@ def construct_model_conv_relu(self, output_model_path, input_shape, weight_shape [output_tensor], initializer=initializers, ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)]) + model.ir_version = ir_version onnx.save(model, output_model_path) - def verify_qdq(self, per_channel, activation_type, weight_type, extra_options=None): + def verify_qdq( + self, + per_channel, + activation_type, + weight_type, + extra_options=None, + opset=13, + ir_version=7, + rtol=1e-2, + atol=0.05, + ): np.random.seed(1) model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx") model_qdq_path = str( Path(self._tmp_model_dir.name) / f"conv_relu_quant_qdq.{activation_type}.{weight_type}.{per_channel}.onnx" ) data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) - self.construct_model_conv_relu(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31]) + self.construct_model_conv_relu( + model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31], opset=opset, ir_version=ir_version + ) quantize_static( model_fp32_path, model_qdq_path, @@ -595,7 +609,7 @@ def verify_qdq(self, per_channel, 
activation_type, weight_type, extra_options=No "DequantizeLinear", ], ) - check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next(), rtol=rtol, atol=atol) # If the model uses Q/DQ ops with "com.microsoft" domain (e.g., for int16 support), # then ensure the model has the appropriate opset import. @@ -648,6 +662,16 @@ def test_quantize_conv_without_bias(self): self.verify_qdq(True, QuantType.QUInt16, QuantType.QUInt8, {"UseQDQContribOps": True}) self.verify_qdq(True, QuantType.QInt16, QuantType.QInt8, {"UseQDQContribOps": True}) + # 4-bit QDQ + self.verify_qdq(False, QuantType.QInt16, QuantType.QInt4, opset=21, ir_version=10, atol=0.4) # per-tensor + self.verify_qdq(True, QuantType.QInt16, QuantType.QInt4, opset=21, ir_version=10) # per-channel + self.verify_qdq( + False, QuantType.QInt16, QuantType.QInt4, {"UseQDQContribOps": True}, opset=21, ir_version=10, atol=0.4 + ) # per-tensor + self.verify_qdq( + True, QuantType.QInt16, QuantType.QInt4, {"UseQDQContribOps": True}, opset=21, ir_version=10 + ) # per-channel + def test_quantize_relu_conv(self): float_model_path = str(Path(self._tmp_model_dir.name) / "float_relu_convs_model.onnx") construct_relu_conv_model(float_model_path) From 6c06bfbf292fa7fd0f713fdb86188536010b0942 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 30 May 2024 01:29:38 -0700 Subject: [PATCH 67/72] Review comments --- include/onnxruntime/core/framework/int4.h | 48 ++--- include/onnxruntime/core/framework/tensor.h | 44 ++-- onnxruntime/core/framework/tensor.cc | 38 ++-- .../core/framework/tensorprotoutils.cc | 26 +-- .../cpu/quantization/quantize_linear.cc | 193 +++++++++--------- onnxruntime/core/util/qmath.h | 6 +- onnxruntime/test/onnx/tensorprotoutils.cc | 2 +- onnxruntime/test/optimizer/qdq_test_utils.h | 4 +- 8 files changed, 170 insertions(+), 191 deletions(-) diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index 89ffbb96231bc..228c1e4e872de 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -37,64 +37,46 @@ struct Int4x2Base { static constexpr UnpackedType min_val = Int4Traits::min_val; static constexpr UnpackedType max_val = Int4Traits::max_val; - uint8_t bits_{}; + std::byte bits_{}; Int4x2Base() = default; - explicit Int4x2Base(uint8_t bits) { + explicit Int4x2Base(std::byte bits) { bits_ = bits; } Int4x2Base(UnpackedType val0, UnpackedType val1) { - bits_ = static_cast(((val1 & 0xF) << 4) | (val0 & 0xF)); + bits_ = static_cast(((val1 & 0xF) << 4) | (val0 & 0xF)); } - static inline int8_t SignExtendLower4Bits(uint8_t bits) { + static inline int8_t SignExtendLower4Bits(std::byte bits) { // Sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift. 
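Aside (not part of this patch): the shift trick described in the comment above is easy to verify in isolation. The helper name below is made up for illustration; it mirrors the expression that SignExtendLower4Bits ends up using.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>

    // Move the low nibble into the top 4 bits of an int32, then arithmetic-shift it
    // back down so the nibble's high bit becomes the sign bit.
    static int8_t SignExtendNibble(std::byte bits) {
      constexpr uint8_t shift = (sizeof(int32_t) * 8) - 4;  // 28
      return static_cast<int8_t>((static_cast<int32_t>(bits) << shift) >> shift);
    }

    int main() {
      std::cout << static_cast<int>(SignExtendNibble(std::byte{0x7})) << "\n";  // 0b0111 ->  7
      std::cout << static_cast<int>(SignExtendNibble(std::byte{0x8})) << "\n";  // 0b1000 -> -8
      std::cout << static_cast<int>(SignExtendNibble(std::byte{0xF})) << "\n";  // 0b1111 -> -1
    }
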
- constexpr uint8_t Shift = (sizeof(int32_t) * 8) - 4; - return static_cast((static_cast(bits) << Shift) >> Shift); - } - - inline UnpackedType GetElem0() const { - if constexpr (Signed) { - return SignExtendLower4Bits(bits_); - } else { - return static_cast(bits_ & 0xF); - } - } - - inline UnpackedType GetElem1() const { - const uint8_t val = static_cast((bits_ >> 4) & 0xF); - - if constexpr (Signed) { - return SignExtendLower4Bits(val); - } else { - return val; - } + constexpr uint8_t shift = (sizeof(int32_t) * 8) - 4; + return static_cast((static_cast(bits) << shift) >> shift); } inline UnpackedType GetElem(size_t index) const { assert(index <= 1); const uint8_t shift = 4 * static_cast(index); - const uint8_t val = static_cast((bits_ >> shift) & 0xF); + const std::byte val = (bits_ >> shift) & std::byte{0xF}; if constexpr (Signed) { return SignExtendLower4Bits(val); } else { - return val; + return static_cast(val); } } inline void SetElem(size_t index, UnpackedType val) { assert(index <= 1); const uint8_t shift = 4 * static_cast(index); - const uint8_t mask = 0xF << shift; + const std::byte mask = std::byte{0xF0} >> shift; - bits_ &= ~mask; // Clear 4-bit element to 0 - bits_ |= static_cast((val & 0xF) << shift); // Set 4-bit element to val + bits_ &= mask; // Clear 4-bit element to 0 + bits_ |= static_cast((val & 0xF) << shift); // Set 4-bit element to val } - inline uint8_t ToBits() const { + inline std::byte ToBits() const { return bits_; } @@ -117,7 +99,7 @@ struct Int4x2Base { } static bool Pack(gsl::span> dst, gsl::span src) { - if (CalcNumInt4Pairs(src.size()) != dst.size()) { + if (src.empty() || (CalcNumInt4Pairs(src.size()) != dst.size())) { return false; } @@ -138,6 +120,6 @@ struct Int4x2Base { using Int4x2 = Int4x2Base; using UInt4x2 = Int4x2Base; -static_assert(sizeof(Int4x2) == sizeof(uint8_t)); -static_assert(sizeof(UInt4x2) == sizeof(uint8_t)); +static_assert(sizeof(Int4x2) == sizeof(std::byte)); +static_assert(sizeof(UInt4x2) == sizeof(std::byte)); } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 6804b340f0ac6..96725aa103064 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -156,29 +156,6 @@ class Tensor final { /// Status indicating success or failure. static Status CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment, size_t& storage_size); - - /// - /// Get the number of elements for a Tensor of the given element type and shape. - /// For element types smaller than 1 byte (e.g., int4), a single Tensor element stores multiple sub-byte elements. - /// So, this function returns the number of Tensor elements, each of which may contain multiple sub-byte elements. - /// - /// Data type of the tensor elements. - /// Tensor shape. - /// Number of Tensor elements. Returns -1 if shape has negative dims. - static inline int64_t GetNumTensorElems(MLDataType elt_type, const TensorShape& shape) { - return GetNumTensorElems(elt_type, shape.Size()); - } - - /// - /// Get the number of elements for a Tensor of the given element type and shape size. - /// For element types smaller than 1 byte (e.g., int4), a single Tensor element stores multiple sub-byte elements. - /// So, this function returns the number of Tensor elements, each of which may contain multiple sub-byte elements. - /// - /// Data type of the tensor elements. - /// The number of elements indicated by the shape (i.e., shape.Size()). 
- /// Number of Tensor elements. Returns -1 if shape_size is negative. - static int64_t GetNumTensorElems(MLDataType elt_type, int64_t shape_size); - /** Returns the data type. */ @@ -234,7 +211,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast(NumElements())); + return gsl::make_span(data, static_cast(NumStorageElements())); } template @@ -251,7 +228,7 @@ class Tensor final { ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", "T ", "!=", dtype_); const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, static_cast::size_type>(NumElements())); + return gsl::make_span(data, static_cast::size_type>(NumStorageElements())); } void* MutableDataRaw(MLDataType type) { @@ -305,11 +282,18 @@ class Tensor final { byte_offset_ = byte_offset; } - /** - The number of Tensor elements. A single Tensor element may contain multiple sub-elements for - subbyte data types (e.g., int4). - */ - int64_t NumElements() const; + /// + /// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for + /// sub-byte data types (e.g., int4). + /// + /// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements. + /// Example: Tensor of shape (4,) has 2 storage elements. + /// + /// For element types >= 1 byte, this function returns the product of the shape. + /// Example: Tensor of shape (4,) has 4 storage elements. + /// + /// Number of tensor storage elements + int64_t NumStorageElements() const; /** The number of bytes of data. diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index 68cf73955de24..60d768cc59a5d 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -27,7 +27,19 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st } // namespace #endif -int64_t Tensor::GetNumTensorElems(MLDataType elt_type, int64_t shape_size) { +/// +/// Get the number of elements for a Tensor of the given element type and shape size. +/// +/// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements. +/// Example: Tensor of shape_size 4 has 2 storage elements. +/// +/// For element types >= 1 byte, this function returns the product of the shape. +/// Example: Tensor of shape_size 4 has 4 storage elements. +/// +/// Data type of the tensor elements. +/// The number of elements indicated by the shape (i.e., shape.Size()). +/// Number of Tensor elements. Returns -1 if shape_size is negative. 
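Aside (not part of this patch): the storage-element count described in the comment above is a plain ceiling division, which can be checked on its own. The function name and values below are illustrative only; for the int4 types, num_sub_elems is 2.

    #include <cstdint>
    #include <iostream>

    // Storage elements needed to hold shape_size logical values when num_sub_elems
    // logical values are packed into each storage element.
    static int64_t CeilDivElems(int64_t shape_size, int64_t num_sub_elems) {
      return (shape_size + (num_sub_elems - 1)) / num_sub_elems;
    }

    int main() {
      std::cout << CeilDivElems(5, 2) << "\n";  // 3: five int4 values fit in three Int4x2 bytes
      std::cout << CeilDivElems(4, 2) << "\n";  // 2
      std::cout << CeilDivElems(4, 1) << "\n";  // 4: one-byte-or-wider types are unchanged
    }
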
+static int64_t GetNumTensorStorageElems(MLDataType elt_type, int64_t shape_size) { int64_t num_elems = shape_size; auto prim_type = elt_type->AsPrimitiveDataType(); @@ -41,7 +53,7 @@ int64_t Tensor::GetNumTensorElems(MLDataType elt_type, int64_t shape_size) { Status Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape, size_t alignment, /*out*/ size_t& storage_size) { - int64_t num_elems = GetNumTensorElems(elt_type, shape.Size()); + int64_t num_elems = GetNumTensorStorageElems(elt_type, shape.Size()); ORT_RETURN_IF(num_elems < 0, "Tensor shape.Size() must be >= 0"); if (num_elems > 0) { @@ -120,27 +132,19 @@ void Tensor::InitOrtValue(Tensor&& tensor, OrtValue& ort_value) { ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -int64_t Tensor::NumElements() const { - int64_t num_elems = shape_.Size(); - - if (dtype_ != nullptr && num_elems > 0 && dtype_->HasSubElems()) { - const int64_t num_sub_elems = dtype_->GetNumSubElems(); - num_elems = (num_elems + (num_sub_elems - 1)) / num_sub_elems; - } - - return num_elems; -} - -size_t Tensor::SizeInBytes() const { +int64_t Tensor::NumStorageElements() const { #ifdef ENABLE_STRIDED_TENSORS int64_t size = IsContiguous() ? shape_.Size() : GetSizeFromStrides(shape_, strides_); #else int64_t size = shape_.Size(); #endif - size_t ret = 0; - const int64_t num_elems = GetNumTensorElems(dtype_, size); - if (!IAllocator::CalcMemSizeForArray(SafeInt(num_elems), dtype_->Size(), &ret)) { + return GetNumTensorStorageElems(dtype_, size); +} + +size_t Tensor::SizeInBytes() const { + size_t ret = 0; + if (!IAllocator::CalcMemSizeForArray(SafeInt(NumStorageElements()), dtype_->Size(), &ret)) { ORT_THROW("tensor size overflow"); } return ret; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index ad9a32a461561..6af78f18fb82f 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -59,18 +59,18 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) { return t; \ } -#define TO_TENSOR_ORT_TYPE_INT4(TYPE) \ - template <> \ - TensorProto ToTensor(const onnxruntime::TYPE& value) { \ - return ToScalarTensor(ToTensorProtoElementType(), value.ToBits()); \ - } \ - template <> \ - TensorProto ToTensor(const std::vector& values) { \ - TensorProto t = ToTensorInitialize(ToTensorProtoElementType()); \ - for (const onnxruntime::TYPE& val : values) { \ - t.add_int32_data(val.ToBits()); \ - } \ - return t; \ +#define TO_TENSOR_ORT_TYPE_INT4(TYPE) \ + template <> \ + TensorProto ToTensor(const onnxruntime::TYPE& value) { \ + return ToScalarTensor(ToTensorProtoElementType(), static_cast(value.ToBits())); \ + } \ + template <> \ + TensorProto ToTensor(const std::vector& values) { \ + TensorProto t = ToTensorInitialize(ToTensorProtoElementType()); \ + for (const onnxruntime::TYPE& val : values) { \ + t.add_int32_data(static_cast(val.ToBits())); \ + } \ + return t; \ } namespace ONNX_NAMESPACE { @@ -689,7 +689,7 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d "UnpackTensor: the pre-allocated size does not match the size in proto"); \ \ for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ - p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ + p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ } \ \ return Status::OK(); \ diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc 
b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 36f5f4d859b4e..05dea2a05c97b 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -71,30 +71,31 @@ static void PrepareForQDQ(const TensorShape& input_shape, const Tensor& scale, const Tensor* zero_point_ptr, int64_t axis, - int64_t& block_count, - int64_t& broadcast_dim, - int64_t& block_size) { + int64_t& quant_block_count, // A "quant block" is a block of elems with the same scale/zp + int64_t& axis_dim_val, + int64_t& quant_block_size) { if (IsScalarOr1ElementVector(&scale)) { // per-tensor QuantizeLinear/DequantizeLinear - block_count = 1; - broadcast_dim = 1; - block_size = static_cast(input_shape.Size()); + quant_block_count = 1; + axis_dim_val = 1; + quant_block_size = static_cast(input_shape.Size()); // enforce that zero point are scalars ORT_ENFORCE(zero_point_ptr == nullptr || IsScalarOr1ElementVector(zero_point_ptr), "x_zero_point must be null or a scalar or 1D tensor or size 1."); } else { // per-channel QuantizeLinear/DequantizeLinear const int64_t axis_no_neg = HandleNegativeAxis(axis, input_shape.NumDimensions()); - block_count = input_shape.SizeToDimension(onnxruntime::narrow(axis_no_neg)); - broadcast_dim = input_shape[onnxruntime::narrow(axis_no_neg)]; - block_size = input_shape.SizeFromDimension(SafeInt(axis_no_neg) + 1); + quant_block_count = input_shape.SizeToDimension(onnxruntime::narrow(axis_no_neg)); + axis_dim_val = input_shape[onnxruntime::narrow(axis_no_neg)]; + quant_block_size = input_shape.SizeFromDimension(SafeInt(axis_no_neg) + 1); // if an axis was specified, ensure the scale and zero point are compatible - ORT_ENFORCE(scale.Shape().NumDimensions() == 1 && scale.Shape()[0] == broadcast_dim, + ORT_ENFORCE(scale.Shape().NumDimensions() == 1 && scale.Shape()[0] == axis_dim_val, "scale must be 1D tensor with size ", - broadcast_dim); - ORT_ENFORCE(zero_point_ptr == nullptr || (zero_point_ptr->Shape().NumDimensions() == 1 && zero_point_ptr->Shape()[0] == broadcast_dim), + axis_dim_val); + ORT_ENFORCE(zero_point_ptr == nullptr || + (zero_point_ptr->Shape().NumDimensions() == 1 && zero_point_ptr->Shape()[0] == axis_dim_val), "x_zero_point must be null or 1D tensor with size ", - broadcast_dim); + axis_dim_val); } } @@ -245,12 +246,13 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( template struct DequantizeLinearApply { - void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T* zero_point) { + void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, OutT* output, + const T* zero_point) { for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { auto zp = zero_point ? 
static_cast(zero_point[bd]) : 0; auto sc = static_cast(scale[bd]); - for (size_t bs = 0; bs < static_cast(block_size); bs++) { + for (size_t bs = 0; bs < static_cast(quant_block_size); bs++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } } @@ -258,29 +260,29 @@ struct DequantizeLinearApply { } }; -#define DEQUANTIZE_LINEAR_APPLY_INT4(T) \ - template \ - struct DequantizeLinearApply { \ - void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, \ - OutT* output, const T* zero_point) { \ - size_t input_index = 0; \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - auto zp = zero_point ? static_cast(zero_point[bd_i].GetElem(bd_j)) : 0; \ - auto sc = static_cast(scale[bd]); \ - for (size_t bs = 0; bs < static_cast(block_size); bs++) { \ - size_t input_i = input_index >> 1; \ - size_t input_j = input_index & 0x1; \ - int32_t val = static_cast(input[input_i].GetElem(input_j)); \ - *output++ = static_cast(static_cast(val - zp) * sc); \ - input_index += 1; \ - } \ - } \ - } \ - assert(input_index == static_cast(N * broadcast_dim * block_size)); \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_INT4(T) \ + template \ + struct DequantizeLinearApply { \ + void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, \ + OutT* output, const T* zero_point) { \ + size_t input_index = 0; \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + auto zp = zero_point ? static_cast(zero_point[bd_i].GetElem(bd_j)) : 0; \ + auto sc = static_cast(scale[bd]); \ + for (size_t bs = 0; bs < static_cast(quant_block_size); bs++) { \ + size_t input_i = input_index >> 1; \ + size_t input_j = input_index & 0x1; \ + int32_t val = static_cast(input[input_i].GetElem(input_j)); \ + *output++ = static_cast(static_cast(val - zp) * sc); \ + input_index += 1; \ + } \ + } \ + } \ + assert(input_index == static_cast(N * axis_dim_val * quant_block_size)); \ + } \ }; DEQUANTIZE_LINEAR_APPLY_INT4(Int4x2); @@ -288,19 +290,20 @@ DEQUANTIZE_LINEAR_APPLY_INT4(UInt4x2); #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - void op(int64_t N, int64_t broadcast_dim, int64_t block_size, const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < static_cast(block_size); bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + void op(int64_t N, int64_t axis_dim_val, int64_t quant_block_size, const T* input, const OutT* scale, \ + OutT* output, const T*) { \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < static_cast(quant_block_size); bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -321,10 +324,10 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { auto& y = 
*ctx->Output(0, x_shape); int64_t N; - int64_t broadcast_dim; - int64_t block_size; + int64_t axis_dim_val; + int64_t quant_block_size; - PrepareForQDQ(x.Shape(), x_scale, x_zero_point, axis_, N, broadcast_dim, block_size); + PrepareForQDQ(x.Shape(), x_scale, x_zero_point, axis_, N, axis_dim_val, quant_block_size); const T* zero_point = x_zero_point ? x_zero_point->Data() : nullptr; @@ -346,11 +349,11 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); float* output = y.MutableData(); - DequantizeLinearApply().op(N, broadcast_dim, block_size, input, scale, output, zero_point); + DequantizeLinearApply().op(N, axis_dim_val, quant_block_size, input, scale, output, zero_point); } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); MLFloat16* output = y.MutableData(); - DequantizeLinearApply().op(N, broadcast_dim, block_size, input, scale, output, zero_point); + DequantizeLinearApply().op(N, axis_dim_val, quant_block_size, input, scale, output, zero_point); } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); } else { @@ -512,42 +515,46 @@ void ParQuantizeLinear(const InputType* Input, ParQuantizeLinearStd(Input, Output, N, Scale, ZeroPoint != nullptr ? ZeroPoint[bd] : (OutputType)0, thread_pool); #if !defined(DISABLE_FLOAT8_TYPES) } else { - ParQuantizeLinearSat(Input, Output, N, Scale, ZeroPoint != nullptr ? ZeroPoint[bd] : OutputType(static_cast(static_cast(0)), true), saturate, thread_pool); + ParQuantizeLinearSat(Input, Output, N, Scale, + ZeroPoint != nullptr ? ZeroPoint[bd] + : OutputType(static_cast(static_cast(0)), true), + saturate, thread_pool); } #endif } template void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t N, - int64_t broadcast_dim, int64_t block_size, bool saturate) { + int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { for (size_t n = 0; n < static_cast(N); n++) { - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { - ParQuantizeLinear(input, output, static_cast(block_size), scale[bd], bd, zero_point, saturate, ctx->GetOperatorThreadPool()); - input += block_size; - output += block_size; + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { + ParQuantizeLinear(input, output, static_cast(quant_block_size), scale[bd], bd, zero_point, saturate, + ctx->GetOperatorThreadPool()); + input += quant_block_size; + output += quant_block_size; } } } // Quantizes float32 to INT4 (in-place) using MLAS kernel. -#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \ - template <> \ - void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \ - INT4_TYPE* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) { \ - ORT_UNUSED_PARAMETER(saturate); \ - size_t output_index = 0; \ - for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::UnpackedType zp = zero_point ? 
zero_point[bd_i].GetElem(bd_j) : 0; \ - QUANT_FUNC(input, output, output_index, output_index + static_cast(block_size), \ - scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ - input += block_size; \ - output_index += static_cast(block_size); \ - } \ - } \ - assert(output_index == static_cast(N * broadcast_dim * block_size)); \ +#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \ + template <> \ + void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \ + INT4_TYPE* output, int64_t N, int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + size_t output_index = 0; \ + for (size_t n = 0; n < static_cast(N); n++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ + size_t bd_i = bd >> 1; /*bd / 2*/ \ + size_t bd_j = bd & 0x1; /*bd % 2*/ \ + INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ + QUANT_FUNC(input, output, output_index, output_index + static_cast(quant_block_size), \ + scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ + input += quant_block_size; \ + output_index += static_cast(quant_block_size); \ + } \ + } \ + assert(output_index == static_cast(N * axis_dim_val * quant_block_size)); \ } DEFINE_COMPUTE_LOOP_FP32_TO_INT4(Int4x2, ParQuantizeLinearStdS4) @@ -560,23 +567,23 @@ DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4) template <> \ void ComputeLoop(OpKernelContext * ctx, const MLFloat16* input, const MLFloat16* scale, \ const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t N, \ - int64_t broadcast_dim, int64_t block_size, bool saturate) { \ + int64_t axis_dim_val, int64_t quant_block_size, bool saturate) { \ ORT_UNUSED_PARAMETER(saturate); \ \ - size_t total_size = static_cast(N * broadcast_dim * block_size); \ + size_t total_size = static_cast(N * axis_dim_val * quant_block_size); \ auto tmp_buf = std::make_unique(total_size); \ size_t tmp_buf_index = 0; \ \ for (size_t n = 0; n < static_cast(N); n++) { \ - for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { \ + for (size_t bd = 0; bd < static_cast(axis_dim_val); bd++) { \ size_t bd_i = bd >> 1; /*bd / 2*/ \ size_t bd_j = bd & 0x1; /*bd % 2*/ \ INT4_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ - static_cast(block_size), scale[bd], \ + static_cast(quant_block_size), scale[bd], \ zp, ctx->GetOperatorThreadPool()); \ - input += block_size; \ - tmp_buf_index += static_cast(block_size); \ + input += quant_block_size; \ + tmp_buf_index += static_cast(quant_block_size); \ } \ } \ \ @@ -605,17 +612,19 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { auto& y = *ctx->Output(0, x_shape); int64_t N; - int64_t broadcast_dim; - int64_t block_size; - PrepareForQDQ(x.Shape(), y_scale, y_zero_point, axis_, N, broadcast_dim, block_size); + int64_t axis_dim_val; + int64_t quant_block_size; + PrepareForQDQ(x.Shape(), y_scale, y_zero_point, axis_, N, axis_dim_val, quant_block_size); const T* zero_point = y_zero_point != nullptr ? 
y_zero_point->Data() : nullptr; T* output = y.MutableData(); if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, broadcast_dim, block_size, saturate_); + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, axis_dim_val, + quant_block_size, saturate_); } else if (x.IsDataType()) { - ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, broadcast_dim, block_size, saturate_); + ComputeLoop(ctx, x.Data(), y_scale.Data(), zero_point, output, N, + axis_dim_val, quant_block_size, saturate_); } else { ORT_THROW("Unsupported input type."); } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 15b7eb1063976..235ecfde0954a 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -163,7 +163,7 @@ ParQuantizeLinearStd(const float* Input, /* If starting at an int4 element in the middle of a byte, quantize it by itself. */ \ if (out_start & 0x1) { \ int32_t ival = static_cast(std::nearbyintf(Input[inp_start] / Scale)) + \ - static_cast(ZeroPoint.GetElem0()); \ + static_cast(ZeroPoint.GetElem(0)); \ size_t output_index = out_start >> 1; \ \ INT4_TYPE::UnpackedType quant_val = static_cast( \ @@ -178,7 +178,7 @@ ParQuantizeLinearStd(const float* Input, /* If ending at element that ends in the middle of a byte, quantize it by itself. */ \ if (out_end & 0x1) { \ int32_t ival = static_cast(std::nearbyintf(Input[inp_end - 1] / Scale)) + \ - static_cast(ZeroPoint.GetElem0()); \ + static_cast(ZeroPoint.GetElem(0)); \ size_t output_index = (out_end - 1) >> 1; \ \ INT4_TYPE::UnpackedType quant_val = static_cast( \ @@ -220,7 +220,7 @@ ParQuantizeLinearStd(const float* Input, reinterpret_cast(&(Output[out_idx >> 1])), \ end_idx - begin_idx, \ Scale, \ - static_cast(ZeroPoint.GetElem0())); \ + static_cast(ZeroPoint.GetElem(0))); \ }); \ } diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index b98717116280d..5df055f862a86 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -338,7 +338,7 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) } \ \ for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ - p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ + p_data[i] = INT4_TYPE(static_cast(tensor.int32_data()[i])); \ } \ } diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index e993307482cd3..862408f31f004 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -518,8 +518,8 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, if constexpr (std::is_same_v || std::is_same_v) { input_arg = builder.MakeInputInt4(input_shape, InputType::min_val, InputType::max_val); - dq_zp = InputType(static_cast(InputType::max_val / 2)); - q_zp = OutputType(static_cast(OutputType::max_val / 2)); + dq_zp = InputType(static_cast(InputType::max_val / 2)); + q_zp = OutputType(static_cast(OutputType::max_val / 2)); } else { input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), std::numeric_limits::max()); From 09c11c6444132ca85ceab66453dc1ea3713664a5 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 30 May 2024 10:03:21 -0700 Subject: [PATCH 68/72] Save one instruction in MlasSetInt4Element() --- onnxruntime/core/mlas/lib/mlasi.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 403460c310fb7..83200187963e1 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -2552,11 +2552,11 @@ MlasSetInt4Element(uint8_t* Output, size_t ElemIndex, UnpackedType Value) const size_t OutputIndex = ElemIndex >> 1; // which byte const size_t NibbleIndex = ElemIndex & 0x1; // which 4-bit elem in the byte - const uint8_t Shift = static_cast(4 * NibbleIndex); - const uint8_t Mask = static_cast(0xF << Shift); + const uint8_t Shift = static_cast(NibbleIndex << 2); // Either 0 or 4 + const uint8_t Mask = static_cast(0xF0 >> Shift); uint8_t* Dst = &Output[OutputIndex]; - *Dst &= ~Mask; // Clear 4-bit lane + *Dst &= Mask; // Clear 4-bit lane *Dst |= static_cast((Value & 0xF) << Shift); // Set 4-bit lane } From 12d7d0ec9867b43ed97cd3adacd3d7edf5ffc14a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 30 May 2024 12:07:49 -0700 Subject: [PATCH 69/72] Use workaround to ensure quant tool stores negative INT4 weights packed in onnx model (onnx bug) --- .../tools/quantization/base_quantizer.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index aff18a8b361c3..74e213fa61362 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -340,7 +340,13 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa f"\nraw={str(q_weight_initializer)[:200]}." ) elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data) + # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed + # within int32_data is fixed. + # q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data) + packed_data = onnx.helper.pack_float32_to_4bit(q_weight_data.flatten(), qType == onnx.TensorProto.INT4) + q_weight_initializer = onnx.helper.make_tensor( + q_weight_name, qType, weight.dims, packed_data.tobytes(), raw=True + ) else: q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape( weight.dims @@ -477,8 +483,16 @@ def quantize_weight_per_channel_impl( if not keep_float_weight: if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed + # within int32_data is fixed. 
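Aside (not part of this patch): the packing workaround described in the comment above can be exercised on its own, assuming onnx >= 1.16 (where onnx.helper.pack_float32_to_4bit is available). The values and tensor name below are illustrative only.

    import numpy as np
    import onnx

    # Pack eight signed int4 values into four raw bytes (two nibbles per byte) and
    # build an INT4 initializer from the packed buffer.
    vals = np.array([-8, -1, 0, 1, 2, 3, 7, -3], dtype=np.float32)
    packed = onnx.helper.pack_float32_to_4bit(vals, True)  # True -> signed (INT4)
    init = onnx.helper.make_tensor("w_q", onnx.TensorProto.INT4, [8], packed.tobytes(), raw=True)
    print(packed.shape)  # (4,)
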
+ # q_weight_initializer = onnx.helper.make_tensor( + # q_weight_name, weight_qType, weights_shape, quantized_weights + # ) + packed_data = onnx.helper.pack_float32_to_4bit( + quantized_weights.flatten(), weight_qType == onnx.TensorProto.INT4 + ) q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, weight_qType, weights_shape, quantized_weights + q_weight_name, weight_qType, weights_shape, packed_data.tobytes(), raw=True ) self.model.initializer_extend([q_weight_initializer]) else: From 43c7bf17b3a0def48230679b6b09b0e1e014752f Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 30 May 2024 14:22:39 -0700 Subject: [PATCH 70/72] Add int4 qdq quantization tool test --- .../test/python/quantization/test_qdq.py | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 93488a03d497c..efe21915978ef 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -1515,5 +1515,183 @@ def test_16bit_subgraph(self): check_model_correctness(self, f32_model_path, qdq_model_path, data_reader.get_next()) +class TestQDQ4bit(TestQDQFormat): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.4bit_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_conv_test_model( + self, + inp_shape: list[int], + weight_data: np.ndarray, + bias_data: np.ndarray, + ): + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + weight = onnx.numpy_helper.from_array(weight_data, "weight") + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + return onnx.shape_inference.infer_shapes(model) + + def test_int4_qdq_conv(self): + """ + Test quantization of int4 conv weight. 
+ """ + float_model_path = os.path.join(self._tmp_dir_path, "conv_int4.f32.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "conv_int4.qdq.onnx") + + inp_shape = [1, 2, 100, 100] + weight_shape = [2, 2, 20, 20] + + # range = 3.0, scale = 3/15, zp = 0 + weight_data = np.linspace(-1.5, 1.5, num=1600, dtype=np.float32).reshape(weight_shape) + bias_data = np.array([-10.0, 10.0], dtype=np.float32) + float_model = self.build_conv_test_model(inp_shape, weight_data, bias_data) + + onnx.checker.check_model(float_model, True) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": inp_shape}, np.float32) + + tensor_quant_overrides = { + "weight": [{"quant_type": QuantType.QInt4}], # Quantize weights to INT4 + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt16, + weight_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": tensor_quant_overrides, + }, + ) + + qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check the the weight's zero-point data type is INT4 and has expected value + zp_val = 0 + weight_zp_init = initializers["weight_zero_point"] + self.assertEqual(weight_zp_init.data_type, onnx.TensorProto.INT4) + self.assertEqual(weight_zp_init.int32_data[0], zp_val) + + # Check for the expected scale value + weight_scale_init = initializers["weight_scale"] + scale_val = np.float32(3.0 / 15) + self.assertEqual(weight_scale_init.data_type, onnx.TensorProto.FLOAT) + self.assertEqual(weight_scale_init.float_data[0], scale_val) + + # Check that INT4 weights take up approximately 50% the size of INT8 weights. + # Using protobuf's ByteSize() is not exact because it includes other fields in the proto message. + unpacked_size = 1 + for dim in weight_shape: + unpacked_size *= dim + + weight_quant_init = initializers["weight_quantized"] + size_ratio = weight_quant_init.ByteSize() / unpacked_size + self.assertLess(size_ratio, 0.55) + + # Check that the quantized weight values are correct. + if weight_quant_init.HasField("raw_data"): + float_data = weight_data.flatten().tolist() + for index, float_val in enumerate(float_data): + expected_int4_val = np.clip(np.float32(float_val / scale_val).round() + zp_val, -8, 7) + int4_pair = onnx.subbyte.unpack_single_4bitx2(weight_quant_init.raw_data[index >> 1], True) + int4_val = int4_pair[index & 0x1] + + self.assertEqual(np.float32(int4_val), expected_int4_val) + + def test_int4_qdq_per_channel_conv(self): + """ + Test per-channel quantization of int4 conv weight. 
+ """ + float_model_path = os.path.join(self._tmp_dir_path, "conv_int4_per_chan.f32.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "conv_int4_per_chan.qdq.onnx") + + inp_shape = [1, 2, 100, 100] + weight_shape = [2, 2, 20, 20] + + weight_data = np.linspace(-1.5, 1.5, num=1600, dtype=np.float32).reshape(weight_shape) + bias_data = np.array([-10.0, 10.0], dtype=np.float32) + float_model = self.build_conv_test_model(inp_shape, weight_data, bias_data) + + onnx.checker.check_model(float_model, True) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": inp_shape}, np.float32) + + per_chan_axis = 0 + tensor_quant_overrides = { + "weight": [{"quant_type": QuantType.QInt4, "axis": per_chan_axis}], # Quantize weight to INT4 (per-channel) + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt16, + weight_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": tensor_quant_overrides, + }, + ) + + qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check that the weight's zero-point data type is INT4 and has 2 elems + weight_zp_init = initializers["weight_zero_point"] + self.assertEqual(weight_zp_init.data_type, onnx.TensorProto.INT4) + self.assertEqual(weight_zp_init.dims[0], 2) + + # Check that the weight's scale data type is FLOAT and has 2 elems + weight_scale_init = initializers["weight_scale"] + self.assertEqual(weight_scale_init.data_type, onnx.TensorProto.FLOAT) + self.assertEqual(weight_scale_init.dims[0], 2) + + # Check that INT4 weights take up approximately 50% the size of INT8 weights. + # Using protobuf's ByteSize() is not exact because it includes other fields in the proto message. + unpacked_size = 1 + for dim in weight_shape: + unpacked_size *= dim + + weight_quant_init = initializers["weight_quantized"] + size_ratio = weight_quant_init.ByteSize() / unpacked_size + self.assertLess(size_ratio, 0.55) + + if __name__ == "__main__": unittest.main() From d4a05b7ac4a664063ba898441db4bc4dc2cfd95c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 30 May 2024 14:35:48 -0700 Subject: [PATCH 71/72] Check opset when using int4 types with quant tool --- onnxruntime/python/tools/quantization/qdq_quantizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index c368d887fda22..36244dc5df128 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -195,14 +195,14 @@ def __init__( # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types # are 16-bit integers. 
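Aside (not part of this patch): self.opset_version checked below comes from the model's declared default-domain opset. A minimal way to read it from a ModelProto, using the same lookup the QNN quant_config helper in this series performs, might look like this; the helper name and file path are placeholders.

    import onnx

    def default_domain_opset(model: onnx.ModelProto) -> int:
        # The default ONNX domain can be recorded as "" or "ai.onnx".
        return next(x.version for x in model.opset_import if x.domain in ("", "ai.onnx"))

    model = onnx.load("model.onnx")
    print(default_domain_opset(model))  # int4/int16 Q/DQ in the default domain needs >= 21
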
if self.opset_version < 21: - int16_types = (TensorProto.UINT16, TensorProto.INT16) - overrides_have_int16 = any(t.tensor_type in int16_types for t in self.tensor_quant_override_qtypes) + new_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4) + overrides_have_new_types = any(t.tensor_type in new_types for t in self.tensor_quant_override_qtypes) if not self.qdq_op_domain and ( - self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 + self.activation_qType in new_types or self.weight_qType in new_types or overrides_have_new_types ): logging.warning( "ONNX QuantizeLinear and DequantizeLinear operators do not support " - "16-bit integer quantization types prior to opset 21. " + "16-bit/4-bit integer quantization types prior to opset 21. " f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to " "enable support." ) From 27301e63e10a8a73e938aa0599e599b30f11e88e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 30 May 2024 14:51:38 -0700 Subject: [PATCH 72/72] Check opset version when creating qdq config for QNN --- .../execution_providers/qnn/quant_config.py | 8 +++++--- .../python/tools/quantization/qdq_quantizer.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 3b857c991951c..1ad56dc3ac455 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -21,6 +21,7 @@ Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16} Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8} +Q4_TYPES = {QuantType.QInt4, QuantType.QUInt4} OP_TYPES_TO_EXCLUDE = {"Cast"} MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB @@ -173,11 +174,12 @@ def get_qnn_qdq_config( } # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain - # on Q/DQ operators if using 16-bit quantization. + # on Q/DQ operators if using 16-bit or 4-bit quantization. onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") if onnx_opset.version < 21: - overrides_have_int16 = any(t in Q16_TYPES for t in overrides_helper.get_quant_types()) - if activation_type in Q16_TYPES or weight_type in Q16_TYPES or overrides_have_int16: + opset21_types = Q16_TYPES.union(Q4_TYPES) + overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types()) + if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types: extra_options["UseQDQContribOps"] = True return StaticQuantConfig( diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 36244dc5df128..ac61f4779d389 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -193,12 +193,16 @@ def __init__( # The ONNX spec did not support 16-bit Q/DQ ops before opset 21. # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types - # are 16-bit integers. + # are 16-bit or 4-bit integers. 
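Aside (not part of this patch): when a model's opset is below 21, an alternative to the 'com.microsoft' fallback is to upgrade the model before quantizing. A rough sketch, assuming the ONNX version converter supports the model's ops; file names are placeholders.

    import onnx
    from onnx import version_converter

    model = onnx.load("model_fp32.onnx")
    opset = next(x.version for x in model.opset_import if x.domain in ("", "ai.onnx"))
    if opset < 21:
        model = version_converter.convert_version(model, 21)  # re-export at opset 21
        model.ir_version = 10  # opset 21 pairs with IR version 10, as in the unit tests earlier in this series
        onnx.save_model(model, "model_fp32_opset21.onnx")
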
if self.opset_version < 21: - new_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4) - overrides_have_new_types = any(t.tensor_type in new_types for t in self.tensor_quant_override_qtypes) + opset21_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4) + overrides_have_opset21_types = any( + t.tensor_type in opset21_types for t in self.tensor_quant_override_qtypes + ) if not self.qdq_op_domain and ( - self.activation_qType in new_types or self.weight_qType in new_types or overrides_have_new_types + self.activation_qType in opset21_types + or self.weight_qType in opset21_types + or overrides_have_opset21_types ): logging.warning( "ONNX QuantizeLinear and DequantizeLinear operators do not support "