
ORT 1.20.0 Release: Cherry pick round 1 (#22526)
ORT 1.20.0 release preparation: Cherry pick round 1

Approved cherry pick comments

---------

Co-authored-by: Edward Chen <[email protected]>
Co-authored-by: Hector Li <[email protected]>
Co-authored-by: Adrian Lizarraga <[email protected]>
Co-authored-by: Patrice Vignola <[email protected]>
Co-authored-by: Changming Sun <[email protected]>
Co-authored-by: Yulong Wang <[email protected]>
7 people authored Oct 22, 2024
1 parent f9e623e commit 2d00351
Showing 63 changed files with 664 additions and 205 deletions.
9 changes: 7 additions & 2 deletions cmake/onnxruntime_unittests.cmake
@@ -134,9 +134,14 @@ function(AddTest)

if (IOS)
# target_sources(${_UT_TARGET} PRIVATE ${TEST_SRC_DIR}/xctest/orttestmain.m)

set(_UT_IOS_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET})
# replace any characters that are not valid in a bundle identifier with '-'
string(REGEX REPLACE "[^a-zA-Z0-9\\.-]" "-" _UT_IOS_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER})

set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest"
MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET}
MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET}
MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER}
MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION}
MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION}
MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION}
@@ -163,7 +168,7 @@ function(AddTest)

set_target_properties(${_UT_TARGET}_xc PROPERTIES FOLDER "ONNXRuntimeXCTest"
MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET}_xc
MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET}
MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER}
MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION}
MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION}
MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION}
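
The new _UT_IOS_BUNDLE_GUI_IDENTIFIER sanitization keeps only characters that are valid in an Apple bundle identifier (alphanumerics, '.', and '-'), so test target names containing underscores no longer yield invalid identifiers. A minimal C++ sketch of the equivalent replacement, on an illustrative target name not taken from this commit:

// Sketch only: mirrors the CMake REGEX REPLACE above on an illustrative target name.
#include <iostream>
#include <regex>
#include <string>

int main() {
  std::string id = "com.onnxruntime.utest.onnxruntime_test_all";  // '_' is not valid in a bundle id
  std::string sanitized = std::regex_replace(id, std::regex("[^a-zA-Z0-9.-]"), "-");
  std::cout << sanitized << "\n";  // prints com.onnxruntime.utest.onnxruntime-test-all
  return 0;
}
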
18 changes: 11 additions & 7 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3651,13 +3651,17 @@ struct OrtApi {
* - "73"
* - "75"
* "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
"enable_htp_fp16_precision": Used for float32 model for HTP backend.
Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
- "0": With fp32 precision.
- "1": Default. With fp16 precision.
"enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context.
- "0": Default. Disabled.
- "1": Enabled.
* "enable_htp_fp16_precision": Used for float32 model for HTP backend.
* Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
* - "0": With fp32 precision.
* - "1": Default. With fp16 precision.
* "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context.
* - "0": Default. Disabled.
* - "1": Enabled.
* "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another
* execution provider (typically CPU EP).
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
* - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
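
The newly documented "offload_graph_io_quantization" key joins the existing QNN session option strings. As a point of reference, these options are passed as string key/value pairs when appending the QNN execution provider; a hedged C++ sketch (backend path, model path, and option values below are illustrative assumptions, not part of this commit):

// Sketch: enabling the QNN EP with the options documented above, via the C++ API.
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn-example");
  Ort::SessionOptions session_options;

  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "QnnHtp.dll"},          // HTP backend library (illustrative path)
      {"enable_htp_fp16_precision", "1"},      // run float32 models at fp16 precision on HTP
      {"offload_graph_io_quantization", "1"},  // let another EP (e.g. CPU EP) handle graph I/O (de)quantization
  };
  session_options.AppendExecutionProvider("QNN", qnn_options);

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}
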
@@ -145,6 +145,20 @@ bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderI
LOGS(logger, VERBOSE) << "DepthToSpace: CRD mode requires static shape";
return false;
}

if (mode == "DCR" && input_params.coreml_version < 7) {
int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
GetType(*input_defs[0], input_type, logger);

if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
// In CoreML version 6 (e.g., on an iOS 16 simulator) with DCR mode and float16 input, the output is all zeros
// in this unit test: TensorOpTest/1.DepthToSpaceTest_4.
// However, CoreML version 7 is fine.
// Don't support CoreML version < 7, DCR mode, and float16 input.
LOGS(logger, VERBOSE) << "DepthToSpace: DCR mode with float16 input requires at least CoreML version 7.";
return false;
}
}
} else {
if (mode != "DCR") {
LOGS(logger, VERBOSE) << "DepthToSpace: " << mode << " mode is not supported";
82 changes: 48 additions & 34 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -24,16 +24,16 @@ void ComputeJob(
const T* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
IAllocatorUniquePtr<float>& scale_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
const float* scale_float_ptr,
const float* bias_float_ptr,
float epsilon,
bool simplified,
T* Y_data,
U* mean_data,
U* inv_std_dev_data,
AllocatorPtr alloc) {
ORT_UNUSED_PARAMETER(scale_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(scale_float_ptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(alloc);

const T* p_input = X_data + task_idx * norm_size;
@@ -82,14 +82,17 @@ void ComputeJob(
const MLFloat16* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
IAllocatorUniquePtr<float>& scale_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
const float* scale_float_ptr,
const float* bias_float_ptr,
float epsilon,
bool simplified,
MLFloat16* Y_data,
U* mean_data,
U* inv_std_dev_data,
AllocatorPtr alloc) {
ORT_UNUSED_PARAMETER(scale_data); // only used in float/double overload
ORT_UNUSED_PARAMETER(bias_data); // only used in float/double overload

const MLFloat16* p_input = X_data + task_idx * norm_size;
MLFloat16* p_output = Y_data + task_idx * norm_size;

@@ -117,22 +120,10 @@ void ComputeJob(
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

if (!scale_float_uptr) {
scale_float_uptr = std::move(input_float_uptr); // overwrite input with scale values, since they have the same size
MlasConvertHalfToFloatBuffer(scale_data, scale_float_uptr.get(), num_elems);
}

if (bias_data && !bias_float_uptr) {
bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
}

const float* scale_float_ptr = scale_float_uptr.get();
const float* bias_float_ptr = bias_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
} else if (nullptr == bias_data) {
} else if (nullptr == bias_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
@@ -166,7 +157,13 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I
} // namespace

LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
: OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op}, scale_fp32_(nullptr), bias_fp32_(nullptr) {
: OpKernel(op_kernel_info),
simplified_{simplified},
contrib_op_{contrib_op},
prepacked_scale_fp32_data_(nullptr),
prepacked_scale_fp32_size_(0),
prepacked_bias_fp32_data_(nullptr),
prepacked_bias_fp32_size_(0) {
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
}
@@ -175,15 +172,15 @@ template <typename T, typename U>
Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const {
// Inputs
const Tensor* X = p_ctx->Input<Tensor>(0);
const Tensor* scale = p_ctx->Input<Tensor>(1);
const Tensor* bias = p_ctx->Input<Tensor>(2);
const Tensor* scale = prepacked_scale_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
const T* X_data = X->Data<T>();
const T* scale_data = scale->Data<T>();
const T* scale_data = scale ? scale->Data<T>() : nullptr;
const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

const TensorShape& x_shape = X->Shape();
const TensorShape& scale_shape = scale->Shape();
const TensorShape& bias_shape = bias->Shape();
size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
Tensor* Y = p_ctx->Output(0, x_shape);
T* Y_data = Y->MutableData<T>();

@@ -218,7 +215,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
}

@@ -237,9 +234,11 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr

is_packed = false;
if (input_idx == 1) { // scale
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, scale_fp32_, is_packed);
prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
} else if (input_idx == 2) { // bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

return Status::OK();
@@ -250,9 +249,9 @@ Status LayerNormImpl::ComputeWithoutContext(
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
const TensorShape& scale_shape,
size_t scale_size,
const T* bias_data,
const TensorShape& bias_shape,
size_t bias_size,
T* Y_data,
U* mean_data,
U* inv_std_dev_data,
@@ -264,19 +263,34 @@ Status LayerNormImpl::ComputeWithoutContext(
int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));

const auto scale_size = scale_shape.Size();
const auto bias_size = (bias_data) ? bias_shape.Size() : 0;
if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", norm_size,
". Size of scale and bias (if provided) must match this. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}

IAllocatorUniquePtr<float> scale_fp32;
IAllocatorUniquePtr<float> bias_fp32;
if constexpr (std::is_same_v<T, MLFloat16>) {
if (prepacked_scale_fp32_data_ == nullptr) {
const size_t num_elems = static_cast<size_t>(norm_size);
scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
}
if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
const size_t num_elems = static_cast<size_t>(norm_size);
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
}
}

concurrency::ThreadPool::TryBatchParallelFor(
thread_pool, static_cast<int32_t>(norm_count),
[&](ptrdiff_t task_idx) {
ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, scale_fp32_, bias_fp32_,
ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
},
0);
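
The net effect of this file's changes: the fp16-to-fp32 conversion of scale/bias now happens once, either in PrePack or just before the parallel loop, and each per-row job receives const float pointers instead of lazily filling mutable cached buffers from inside ComputeJob. A simplified stand-alone sketch of that pattern in plain C++ (not ORT code; half_to_float below is only a placeholder for MlasConvertHalfToFloatBuffer):

#include <cstddef>
#include <cstdint>
#include <thread>
#include <vector>

static float half_to_float(uint16_t h) {
  // Placeholder conversion only; ORT uses MlasConvertHalfToFloatBuffer for real fp16 data.
  return static_cast<float>(h);
}

static void NormalizeRows(const uint16_t* scale_fp16, std::size_t norm_size,
                          std::size_t norm_count) {
  // Convert the shared scale buffer once, BEFORE the parallel loop. The previous code
  // converted lazily inside ComputeJob and cached the result in a mutable member,
  // which is not safe when jobs run concurrently.
  std::vector<float> scale_fp32(norm_size);
  for (std::size_t i = 0; i < norm_size; ++i) {
    scale_fp32[i] = half_to_float(scale_fp16[i]);
  }
  const float* scale_ptr = scale_fp32.data();  // jobs only read through this pointer

  std::vector<std::thread> jobs;
  for (std::size_t task_idx = 0; task_idx < norm_count; ++task_idx) {
    jobs.emplace_back([scale_ptr, task_idx] {
      // ... per-row normalization would read scale_ptr[h] here; nothing shared is written ...
      (void)scale_ptr;
      (void)task_idx;
    });
  }
  for (auto& t : jobs) {
    t.join();
  }
}

int main() {
  std::vector<uint16_t> scale(8, 1);
  NormalizeRows(scale.data(), scale.size(), 4);
  return 0;
}
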
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
@@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel {
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
const TensorShape& scale_shape,
size_t scale_size,
const T* bias_data,
const TensorShape& bias_shape,
size_t bias_size,
T* Y_data,
U* mean_data,
U* inv_std_dev,
@@ -63,8 +63,10 @@ class LayerNormImpl : public OpKernel {
float epsilon_;
const bool simplified_;
const bool contrib_op_;
mutable IAllocatorUniquePtr<float> scale_fp32_;
mutable IAllocatorUniquePtr<float> bias_fp32_;
IAllocatorUniquePtr<float> prepacked_scale_fp32_data_;
size_t prepacked_scale_fp32_size_;
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
size_t prepacked_bias_fp32_size_;
};

} // namespace onnxruntime
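
The renamed members gain explicit size fields, and the mutable qualifier is dropped because Compute no longer writes them. A small illustrative sketch (not ORT code) of why the element count is cached alongside the prepacked buffer: once an input has been pre-packed, the kernel may not receive that input tensor again at compute time, so the size needed for the later shape check against norm_size has to be stored with the converted data.

#include <cstddef>
#include <memory>

// Illustrative only: a prepacked parameter keeps both the converted data and its size.
struct PrepackedParam {
  std::unique_ptr<float[]> data;  // fp32 copy made at PrePack time (null if not prepacked)
  std::size_t size = 0;           // element count, cached because the source tensor is gone
};

// At compute time, the size comes from the live tensor when present, else from the cache.
inline std::size_t EffectiveSize(const PrepackedParam& packed, std::size_t live_tensor_size) {
  return packed.data ? packed.size : live_tensor_size;
}
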