diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index 3e70f848675cb..d5b8961cf8c5a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -96,79 +96,6 @@ void ComputeJob( } } -void ComputeJob( - const MLFloat16* input_data, - const MLFloat16* skip_data, - const float* prepacked_skip_fp32_data, - const float* gamma_float_ptr, - const float* beta_float_ptr, - const float* bias_float_ptr, - float* output_float_ptr, - ptrdiff_t task_idx, - int hidden_size, - int64_t skip_size, - float epsilon, - bool simplified, - MLFloat16* output_data, - MLFloat16* skip_input_bias_add_output_data, - AllocatorPtr alloc) { - auto offset = task_idx * hidden_size; - const MLFloat16* p_input = input_data + offset; - MLFloat16* p_output = output_data + offset; - MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset; - - float mean(0.0f); - float mean_square(0.0f); - const size_t num_elems = static_cast(hidden_size); - - IAllocatorUniquePtr input_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems); - - IAllocatorUniquePtr skip_float_uptr = nullptr; - if (prepacked_skip_fp32_data == nullptr && skip_data) { - const MLFloat16* p_skip = skip_data + (offset % skip_size); - skip_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems); - } - - const float* input_float_ptr = input_float_uptr.get(); - const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data : skip_float_uptr.get(); - for (size_t h = 0; h < num_elems; h++) { - float val = input_float_ptr[h] + skip_float_ptr[h]; - - if (bias_float_ptr) { - val += bias_float_ptr[h]; - } - - output_float_ptr[h] = val; - mean += val; - mean_square += val * val; - } - - if (nullptr != p_skip_input_bias_add_output) { - MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems); - } - - mean = mean / hidden_size; - if (simplified) { - mean_square = sqrt(mean_square / hidden_size + epsilon); - } else { - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon); - } - - for (size_t h = 0; h < num_elems; h++) { - if (simplified) { - output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h]; - } else if (nullptr == beta_float_ptr) { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h]; - } else { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h]; - } - } - - MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems); -} - void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr& dest, bool& is_packed) { if (tensor.GetElementType() == utils::ToTensorProtoElementType()) { auto tensor_data_ptr = tensor.Data(); @@ -200,8 +127,8 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input(1); const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input(2); - const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input(3); - const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(4); + const Tensor* beta = simplified ? nullptr : (prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input(3)); + const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(simplified ? 3 : 4); Tensor* output = p_ctx->Output(0, input->Shape()); // For inferencing, we support one more optional output which is the sum of the input and skip tensors Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape()); @@ -232,56 +159,93 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { // For inferencing, we support one more optional output which is the sum of the input and skip tensors T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData(); - const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_; - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); - - IAllocatorUniquePtr output_fp32; - IAllocatorUniquePtr gamma_fp32; - IAllocatorUniquePtr beta_fp32; - IAllocatorUniquePtr bias_fp32; - if constexpr (std::is_same_v) { + const size_t total_data_size = static_cast(input->Shape().Size()); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + + IAllocatorUniquePtr input_fp32; + IAllocatorUniquePtr output_fp32; + IAllocatorUniquePtr skip_input_bias_add_output_fp32; + IAllocatorUniquePtr skip_fp32; + IAllocatorUniquePtr gamma_fp32; + IAllocatorUniquePtr beta_fp32; + IAllocatorUniquePtr bias_fp32; + + const float* input_data_f = nullptr; + const float* skip_data_f = nullptr; + const float* gamma_data_f = nullptr; + const float* beta_data_f = nullptr; + const float* bias_data_f = nullptr; + float* output_data_f = nullptr; + float* skip_input_bias_add_output_data_f = nullptr; + const size_t num_elems = static_cast(hidden_size); - output_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + input_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size); + input_data_f = input_fp32.get(); + + output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + output_data_f = output_fp32.get(); + + skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get(); - if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) { + if (skip_data) { + skip_fp32 = IAllocator::MakeUniquePtr(alloc, static_cast(skip_size)); + MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), static_cast(skip_size)); + skip_data_f = skip_fp32.get(); + } else if (prepacked_skip_fp32_data_) { + skip_data_f = prepacked_skip_fp32_data_.get(); + } + + if (gamma_data) { gamma_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems); + gamma_data_f = gamma_fp32.get(); + } else if (prepacked_gamma_fp32_data_) { + gamma_data_f = prepacked_gamma_fp32_data_.get(); } - if (prepacked_beta_fp32_data_ == nullptr && beta_data) { + if (beta_data) { beta_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems); + beta_data_f = beta_fp32.get(); + } else if (prepacked_beta_fp32_data_) { + beta_data_f = prepacked_beta_fp32_data_.get(); } - if (prepacked_bias_fp32_data_ == nullptr && bias_data) { + if (bias_data) { bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + bias_data_f = bias_fp32.get(); + } else if (prepacked_bias_fp32_data_) { + bias_data_f = prepacked_bias_fp32_data_.get(); } - } - concurrency::ThreadPool::TryBatchParallelFor( - p_ctx->GetOperatorThreadPool(), static_cast(task_count), - [&](ptrdiff_t task_idx) { - if constexpr (std::is_same_v) { - ComputeJob(input_data, skip_data, - prepacked_skip_fp32_data_.get(), - prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(), - prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(), - prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(), - output_fp32.get(), - task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, - skip_input_bias_add_output_data, alloc); - } else { + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { + ComputeJob(input_data_f, skip_data_f, gamma_data_f, beta_data_f, bias_data_f, task_idx, hidden_size, skip_size, + epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f); + }, + 0); + MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size); + if (skip_input_bias_add_output_data != nullptr) + MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size); + } else { + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, skip_input_bias_add_output_data); - } - }, - 0); + }, + 0); + } return Status::OK(); } @@ -290,16 +254,22 @@ template Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); - is_packed = false; if (input_idx == 1) { // skip prepacked_skip_fp32_size_ = tensor.Shape().Size(); ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed); } else if (input_idx == 2) { // gamma ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed); - } else if (input_idx == 3) { // beta - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); + } else if (input_idx == 3) { + if constexpr (simplified) { + // bias + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); + } else { + // beta + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); + } } else if (input_idx == 4) { // bias + ORT_ENFORCE(!simplified, "SkipSimplifiedLayerNormalization should only has 4 inputs (input, skip, gamma, and beta). Got 5."); ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); }